mp wrappers added
This commit is contained in:
		
							parent
							
								
									9b2c330ebf
								
							
						
					
					
						commit
						78d48c4300
					
				
							
								
								
									
										23
									
								
								alr_envs/alr/mujoco/ant_jump/mp_wrapper.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										23
									
								
								alr_envs/alr/mujoco/ant_jump/mp_wrapper.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,23 @@
 | 
				
			|||||||
 | 
					from typing import Union, Tuple
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class MPWrapper(RawInterfaceWrapper):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def context_mask(self):
 | 
				
			||||||
 | 
					        return np.hstack([
 | 
				
			||||||
 | 
					            [False] * 111,  # ant has 111 dimensional observation space !!
 | 
				
			||||||
 | 
					            [True]  # goal height
 | 
				
			||||||
 | 
					        ])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def current_pos(self) -> Union[float, int, np.ndarray]:
 | 
				
			||||||
 | 
					        return self.env.sim.data.qpos[7:15].copy()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
 | 
				
			||||||
 | 
					        return self.env.sim.data.qvel[6:14].copy()
 | 
				
			||||||
							
								
								
									
										23
									
								
								alr_envs/alr/mujoco/walker_2d_jump/mp_wrapper.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										23
									
								
								alr_envs/alr/mujoco/walker_2d_jump/mp_wrapper.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,23 @@
 | 
				
			|||||||
 | 
					from typing import Tuple, Union
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class MPWrapper(RawInterfaceWrapper):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def context_mask(self):
 | 
				
			||||||
 | 
					        return np.hstack([
 | 
				
			||||||
 | 
					            [False] * 17,
 | 
				
			||||||
 | 
					            [True]  # goal pos
 | 
				
			||||||
 | 
					        ])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def current_pos(self) -> Union[float, int, np.ndarray]:
 | 
				
			||||||
 | 
					        return self.env.data.qpos[3:9].copy()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
 | 
				
			||||||
 | 
					        return self.env.data.qvel[3:9].copy()
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user