# Movement Primitives Examples

  1import gymnasium as gym
  2import fancy_gym
  3
  4
  5def example_mp(env_name="fancy_ProMP/HoleReacher-v0", seed=1, iterations=1, render=True):
  6    """
  7    Example for running a black box based environment, which is already registered
  8    Args:
  9        env_name: Black box env_id
 10        seed: seed for deterministic behaviour
 11        iterations: Number of rollout steps to run
 12        render: Render the episode
 13
 14    Returns:
 15
 16    """
 17    # Equivalent to gym, we have a make function which can be used to create environments.
 18    # It takes care of seeding and enables the use of a variety of external environments using the gym interface.
 19    env = gym.make(env_name, render_mode='human' if render else None)
 20
 21    returns = 0
 22    # env.render(mode=None)
 23    obs = env.reset(seed=seed)
 24
 25    # number of samples/full trajectories (multiple environment steps)
 26    for i in range(iterations):
 27
 28        if render and i % 1 == 0:
 29            env.render()
 30
 31        # Now the action space is not the raw action but the parametrization of the trajectory generator,
 32        # such as a ProMP
 33        ac = env.action_space.sample()
 34        # This executes a full trajectory and gives back the context (obs) of the last step in the trajectory, or the
 35        # full observation space of the last step, if replanning/sub-trajectory learning is used. The 'reward' is equal
 36        # to the return of a trajectory. Default is the sum over the step-wise rewards.
 37        obs, reward, terminated, truncated, info = env.step(ac)
 38        # Aggregated returns
 39        returns += reward
 40
 41        if terminated or truncated:
 42            print(reward)
 43            obs = env.reset()
 44    env.close()
 45
 46
 47def example_custom_mp(env_name="fancy_ProMP/Reacher5d-v0", seed=1, iterations=1, render=True):
 48    """
 49    Example for running a custom movement primitive based environments.
 50    Our already registered environments follow the same structure.
 51    Hence, this also allows to adjust hyperparameters of the movement primitives.
 52    Yet, we recommend the method above if you are just interested in changing those parameters for existing tasks.
 53    We appreciate PRs for custom environments (especially MP wrappers of existing tasks) 
 54    for our repo: https://github.com/ALRhub/fancy_gym/
 55    Args:
 56        seed: seed
 57        iterations: Number of rollout steps to run
 58        render: Render the episode
 59
 60    Returns:
 61
 62    """
 63    # Changing the arguments of the black box env is possible by providing them to gym through mp_config_override.
 64    # E.g. here for way to many basis functions
 65    env = gym.make(env_name, seed, mp_config_override={'basis_generator_kwargs': {'num_basis': 1000}}, render_mode='human' if render else None)
 66
 67    returns = 0
 68    obs = env.reset()
 69
 70    # This time rendering every trajectory
 71    if render:
 72        env.render()
 73
 74    # number of samples/full trajectories (multiple environment steps)
 75    for i in range(iterations):
 76        ac = env.action_space.sample()
 77        obs, reward, terminated, truncated, info = env.step(ac)
 78        returns += reward
 79
 80        if terminated or truncated:
 81            print(i, reward)
 82            obs = env.reset()
 83
 84    env.close()
 85    return obs
 86
 87class Custom_MPWrapper(fancy_gym.envs.mujoco.reacher.MPWrapper):
 88    mp_config = {
 89        'ProMP': {
 90                'trajectory_generator_kwargs':  {
 91                    'trajectory_generator_type': 'promp',
 92                    'weights_scale': 2
 93                },
 94                'phase_generator_kwargs': {
 95                    'phase_generator_type': 'linear'
 96                },
 97                'controller_kwargs': {
 98                    'controller_type': 'velocity'
 99                },
100                'basis_generator_kwargs': {
101                    'basis_generator_type': 'zero_rbf',
102                    'num_basis': 5,
103                    'num_basis_zero_start': 1
104                }
105        },
106        'DMP': {
107            'trajectory_generator_kwargs': {
108                'trajectory_generator_type': 'dmp',
109                'weights_scale': 500
110            },
111            'phase_generator_kwargs': {
112                'phase_generator_type': 'exp',
113                'alpha_phase': 2.5
114            },
115            'controller_kwargs': {
116                'controller_type': 'velocity'
117            },
118            'basis_generator_kwargs': {
119                'basis_generator_type': 'rbf',
120                'num_basis': 5
121            }
122        }
123    }
124
125
def example_fully_custom_mp(seed=1, iterations=1, render=True):
    """
    Example for running a custom movement primitive based environments.
    Our already registered environments follow the same structure.
    Hence, this also allows to adjust hyperparameters of the movement primitives.
    Yet, we recommend the method above if you are just interested in changing those parameters for existing tasks.
    We appreciate PRs for custom environments (especially MP wrappers of existing tasks)
    for our repo: https://github.com/ALRhub/fancy_gym/

    Args:
        seed: seed for deterministic behaviour
        iterations: Number of rollout steps to run
        render: Render the episode

    Returns:

    """

    base_env_id = "fancy/Reacher5d-v0"
    custom_env_id = "fancy/Reacher5d-Custom-v0"
    custom_env_id_DMP = "fancy_DMP/Reacher5d-Custom-v0"
    custom_env_id_ProMP = "fancy_ProMP/Reacher5d-Custom-v0"

    # Register the custom wrapper on top of the existing base env for both MP types.
    fancy_gym.upgrade(custom_env_id, mp_wrapper=Custom_MPWrapper, add_mp_types=['ProMP', 'DMP'], base_id=base_env_id)

    env = gym.make(custom_env_id_ProMP, render_mode='human' if render else None)

    rewards = 0
    # Seed on the first reset; reset() returns an (observation, info) tuple.
    obs, info = env.reset(seed=seed)

    if render:
        env.render()

    # number of samples/full trajectories (multiple environment steps)
    for i in range(iterations):
        ac = env.action_space.sample()
        obs, reward, terminated, truncated, info = env.step(ac)
        rewards += reward

        if terminated or truncated:
            print(rewards)
            rewards = 0
            obs, info = env.reset()

    try:  # Some mujoco-based envs don't correctly implement .close
        env.close()
    except Exception:
        # Best-effort cleanup only; never mask KeyboardInterrupt/SystemExit.
        pass
173
174
def example_fully_custom_mp_alternative(seed=1, iterations=1, render=True):
    """
    Instead of defining the mp_args in a new custom MP_Wrapper, they can also be provided during registration.

    Args:
        seed: seed for deterministic behaviour
        iterations: Number of rollout steps to run
        render: Render the episode

    Returns:

    """

    base_env_id = "fancy/Reacher5d-v0"
    custom_env_id = "fancy/Reacher5d-Custom-v0"
    custom_env_id_ProMP = "fancy_ProMP/Reacher5d-Custom-v0"

    # Same ProMP configuration as Custom_MPWrapper, but supplied inline via
    # mp_config_override at registration time instead of a wrapper subclass.
    fancy_gym.upgrade(custom_env_id, mp_wrapper=fancy_gym.envs.mujoco.reacher.MPWrapper, add_mp_types=['ProMP'], base_id=base_env_id, mp_config_override={'ProMP': {
                'trajectory_generator_kwargs':  {
                    'trajectory_generator_type': 'promp',
                    'weights_scale': 2
                },
                'phase_generator_kwargs': {
                    'phase_generator_type': 'linear'
                },
                'controller_kwargs': {
                    'controller_type': 'velocity'
                },
                'basis_generator_kwargs': {
                    'basis_generator_type': 'zero_rbf',
                    'num_basis': 5,
                    'num_basis_zero_start': 1
                }
        }})

    env = gym.make(custom_env_id_ProMP, render_mode='human' if render else None)

    def _rollout(reset_seed=None):
        # One batch of `iterations` full-trajectory samples; prints the
        # accumulated reward whenever an episode ends.
        rewards = 0
        # reset() returns an (observation, info) tuple in Gymnasium.
        obs, info = env.reset(seed=reset_seed)

        if render:
            env.render()

        # number of samples/full trajectories (multiple environment steps)
        for i in range(iterations):
            ac = env.action_space.sample()
            obs, reward, terminated, truncated, info = env.step(ac)
            rewards += reward

            if terminated or truncated:
                print(rewards)
                rewards = 0
                obs, info = env.reset()

    # Two identical rollout batches, as in the original example; only the first
    # one is seeded.
    _rollout(reset_seed=seed)
    _rollout()

    try:  # Some mujoco-based envs don't correctly implement .close
        env.close()
    except Exception:
        # Best-effort cleanup only; never mask KeyboardInterrupt/SystemExit.
        pass
249
250
def main():
    """Run every movement-primitive example once (rendering disabled)."""
    render = False
    # DMP
    example_mp("fancy_DMP/HoleReacher-v0", seed=10, iterations=5, render=render)

    # ProMP
    example_mp("fancy_ProMP/HoleReacher-v0", seed=10, iterations=5, render=render)
    example_mp("fancy_ProMP/BoxPushingTemporalSparse-v0", seed=10, iterations=1, render=render)
    example_mp("fancy_ProMP/TableTennis4D-v0", seed=10, iterations=20, render=render)

    # ProDMP with Replanning
    example_mp("fancy_ProDMP/BoxPushingDenseReplan-v0", seed=10, iterations=4, render=render)
    example_mp("fancy_ProDMP/TableTennis4DReplan-v0", seed=10, iterations=20, render=render)
    example_mp("fancy_ProDMP/TableTennisWindReplan-v0", seed=10, iterations=20, render=render)

    # Altered basis functions (return value intentionally discarded)
    example_custom_mp("fancy_ProMP/Reacher5d-v0", seed=10, iterations=1, render=render)

    # Custom MP
    example_fully_custom_mp(seed=10, iterations=1, render=render)
    example_fully_custom_mp_alternative(seed=10, iterations=1, render=render)


if __name__ == '__main__':
    main()