diff --git a/metastable_baselines/ppo/ppo.py b/metastable_baselines/ppo/ppo.py index db37a33..3caec81 100644 --- a/metastable_baselines/ppo/ppo.py +++ b/metastable_baselines/ppo/ppo.py @@ -107,7 +107,6 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm): sde_sample_freq: int = -1, target_kl: Optional[float] = None, tensorboard_log: Optional[str] = None, - create_eval_env: bool = False, policy_kwargs: Optional[Dict[str, Any]] = {}, verbose: int = 0, seed: Optional[int] = None, @@ -141,7 +140,6 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm): projection, WassersteinProjectionLayer)}, verbose=verbose, device=device, - create_eval_env=create_eval_env, seed=seed, _init_setup_model=False, supported_action_spaces=( @@ -428,11 +426,7 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm): total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 1, - eval_env: Optional[GymEnv] = None, - eval_freq: int = -1, - n_eval_episodes: int = 5, tb_log_name: str = "PPO", - eval_log_path: Optional[str] = None, reset_num_timesteps: bool = True, ) -> "PPO": @@ -440,10 +434,6 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm): total_timesteps=total_timesteps, callback=callback, log_interval=log_interval, - eval_env=eval_env, - eval_freq=eval_freq, - n_eval_episodes=n_eval_episodes, tb_log_name=tb_log_name, - eval_log_path=eval_log_path, reset_num_timesteps=reset_num_timesteps, ) diff --git a/metastable_baselines/sac/sac.py b/metastable_baselines/sac/sac.py index 202d432..f733aec 100644 --- a/metastable_baselines/sac/sac.py +++ b/metastable_baselines/sac/sac.py @@ -121,7 +121,6 @@ class SAC(OffPolicyAlgorithm): sde_sample_freq: int = -1, use_sde_at_warmup: bool = False, tensorboard_log: Optional[str] = None, - create_eval_env: bool = False, action_coef: float = 0.0, policy_kwargs: Optional[Dict[str, Any]] = None, verbose: int = 0, @@ -153,7 +152,6 @@ class SAC(OffPolicyAlgorithm): tensorboard_log=tensorboard_log, verbose=verbose, device=device, - create_eval_env=create_eval_env, seed=seed, use_sde=use_sde, sde_sample_freq=sde_sample_freq,