diff --git a/fancy_gym/black_box/black_box_wrapper.py b/fancy_gym/black_box/black_box_wrapper.py index 16e4017..5650c71 100644 --- a/fancy_gym/black_box/black_box_wrapper.py +++ b/fancy_gym/black_box/black_box_wrapper.py @@ -166,7 +166,7 @@ class BlackBoxWrapper(gym.ObservationWrapper): t + 1 + self.current_traj_steps): break - infos.update({k: v[:t] for k, v in infos.items()}) + infos.update({k: v[:t + 1] for k, v in infos.items()}) self.current_traj_steps += t + 1 if self.verbose >= 2: diff --git a/fancy_gym/meta/__init__.py b/fancy_gym/meta/__init__.py index 4fb23b2..98935dc 100644 --- a/fancy_gym/meta/__init__.py +++ b/fancy_gym/meta/__init__.py @@ -13,7 +13,8 @@ DEFAULT_BB_DICT_ProMP = { "name": 'EnvName', "wrappers": [], "trajectory_generator_kwargs": { - 'trajectory_generator_type': 'promp' + 'trajectory_generator_type': 'promp', + 'weights_scale': 10, }, "phase_generator_kwargs": { 'phase_generator_type': 'linear' @@ -32,17 +33,22 @@ DEFAULT_BB_DICT_ProDMP = { "name": 'EnvName', "wrappers": [], "trajectory_generator_kwargs": { - 'trajectory_generator_type': 'prodmp' + 'trajectory_generator_type': 'prodmp', + 'auto_scale_basis': True, + 'weights_scale': 10, + 'goal_scale': 0. }, "phase_generator_kwargs": { - 'phase_generator_type': 'exp' + 'phase_generator_type': 'exp', + # 'alpha_phase' : 3, }, "controller_kwargs": { 'controller_type': 'metaworld', }, "basis_generator_kwargs": { 'basis_generator_type': 'prodmp', - 'num_basis': 5 + 'num_basis': 3, + 'alpha': 10 } } @@ -148,7 +154,6 @@ for _task in _goal_and_object_change_envs: ) ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProDMP"].append(_env_id) - _goal_and_endeffector_change_envs = ["basketball-v2"] for _task in _goal_and_endeffector_change_envs: task_id_split = _task.split("-") diff --git a/test/test_metaworld_envs.py b/test/test_metaworld_envs.py index 768958d..ed300f4 100644 --- a/test/test_metaworld_envs.py +++ b/test/test_metaworld_envs.py @@ -9,7 +9,6 @@ from test.utils import run_env, run_env_determinism METAWORLD_IDS = [f'metaworld:{env.split("-goal-observable")[0]}' for env, _ in ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE.items()] METAWORLD_MP_IDS = chain(*fancy_gym.ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values()) -print(METAWORLD_MP_IDS) SEED = 1