mp wrapper beer pong
This commit is contained in:
		
							parent
							
								
									3a989e179b
								
							
						
					
					
						commit
						9b2c330ebf
					
				
							
								
								
									
										42
									
								
								alr_envs/alr/mujoco/beerpong/mp_wrapper.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										42
									
								
								alr_envs/alr/mujoco/beerpong/mp_wrapper.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,42 @@
 | 
				
			|||||||
 | 
					from typing import Union, Tuple
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class MPWrapper(RawInterfaceWrapper):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get_context_mask(self):
 | 
				
			||||||
 | 
					        return np.hstack([
 | 
				
			||||||
 | 
					            [False] * 7,  # cos
 | 
				
			||||||
 | 
					            [False] * 7,  # sin
 | 
				
			||||||
 | 
					            [False] * 7,  # joint velocities
 | 
				
			||||||
 | 
					            [False] * 3,  # cup_goal_diff_final
 | 
				
			||||||
 | 
					            [False] * 3,  # cup_goal_diff_top
 | 
				
			||||||
 | 
					            [True] * 2,  # xy position of cup
 | 
				
			||||||
 | 
					            [False]  # env steps
 | 
				
			||||||
 | 
					        ])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
 | 
				
			||||||
 | 
					        return self.env.sim.data.qpos[0:7].copy()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
 | 
				
			||||||
 | 
					        return self.env.sim.data.qvel[0:7].copy()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # TODO: Fix this
 | 
				
			||||||
 | 
					    def _episode_callback(self, action: np.ndarray, mp) -> Tuple[np.ndarray, Union[np.ndarray, None]]:
 | 
				
			||||||
 | 
					        if mp.learn_tau:
 | 
				
			||||||
 | 
					            self.env.env.release_step = action[0] / self.env.dt  # Tau value
 | 
				
			||||||
 | 
					            return action, None
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            return action, None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def set_context(self, context):
 | 
				
			||||||
 | 
					        xyz = np.zeros(3)
 | 
				
			||||||
 | 
					        xyz[:2] = context
 | 
				
			||||||
 | 
					        xyz[-1] = 0.840
 | 
				
			||||||
 | 
					        self.env.env.model.body_pos[self.env.env.cup_table_id] = xyz
 | 
				
			||||||
 | 
					        return self.get_observation_from_step(self.env.env._get_obs())
 | 
				
			||||||
							
								
								
									
										10
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								setup.py
									
									
									
									
									
								
							@ -5,7 +5,7 @@ from setuptools import setup
 | 
				
			|||||||
# Environment-specific dependencies for dmc and metaworld
 | 
					# Environment-specific dependencies for dmc and metaworld
 | 
				
			||||||
extras = {
 | 
					extras = {
 | 
				
			||||||
    "dmc": ["dm_control"],
 | 
					    "dmc": ["dm_control"],
 | 
				
			||||||
    "meta": ["mujoco_py<2.2,>=2.1, git+https://github.com/rlworkgroup/metaworld.git@master#egg=metaworld"],
 | 
					    "meta": ["metaworld @ git+https://github.com/rlworkgroup/metaworld.git@master#egg=metaworld"],
 | 
				
			||||||
    "mujoco": ["mujoco==2.2.0", "imageio>=2.14.1"],
 | 
					    "mujoco": ["mujoco==2.2.0", "imageio>=2.14.1"],
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -20,13 +20,7 @@ setup(
 | 
				
			|||||||
    packages=['alr_envs', 'alr_envs.alr', 'alr_envs.open_ai', 'alr_envs.dmc', 'alr_envs.meta', 'alr_envs.utils'],
 | 
					    packages=['alr_envs', 'alr_envs.alr', 'alr_envs.open_ai', 'alr_envs.dmc', 'alr_envs.meta', 'alr_envs.utils'],
 | 
				
			||||||
    install_requires=[
 | 
					    install_requires=[
 | 
				
			||||||
        'gym',
 | 
					        'gym',
 | 
				
			||||||
        'PyQt5',
 | 
					        "mujoco_py<2.2,>=2.1",
 | 
				
			||||||
        # 'matplotlib',
 | 
					 | 
				
			||||||
        # 'mp_env_api @ git+https://github.com/ALRhub/motion_primitive_env_api.git',
 | 
					 | 
				
			||||||
        #         'mp_env_api @ git+ssh://git@github.com/ALRhub/motion_primitive_env_api.git',
 | 
					 | 
				
			||||||
        'mujoco-py<2.1,>=2.0',
 | 
					 | 
				
			||||||
        'dm_control',
 | 
					 | 
				
			||||||
        'metaworld @ git+https://github.com/rlworkgroup/metaworld.git@master#egg=metaworld',
 | 
					 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
    url='https://github.com/ALRhub/alr_envs/',
 | 
					    url='https://github.com/ALRhub/alr_envs/',
 | 
				
			||||||
    # license='AGPL-3.0 license',
 | 
					    # license='AGPL-3.0 license',
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
		Reference in New Issue
	
	Block a user