randomized cup
This commit is contained in:
parent e1617c34a6
commit be5b287ae1
@@ -213,16 +213,26 @@ register(id='TableTennis4DCtxt-v0',
          kwargs={'ctxt_dim': 4})
 
 ## BeerPong
-difficulties = ["simple", "intermediate", "hard", "hardest"]
-
-for v, difficulty in enumerate(difficulties):
-    register(
-        id='ALRBeerPong-v{}'.format(v),
-        entry_point='alr_envs.alr.mujoco:ALRBeerBongEnv',
-        max_episode_steps=600,
-        kwargs={
-            "difficulty": difficulty,
-            "reward_type": "staged",
+# fixed goal cup position
+register(
+        id='ALRBeerPong-v0',
+        entry_point='alr_envs.alr.mujoco:ALRBeerBongEnv',
+        max_episode_steps=600,
+        kwargs={
+            "rndm_goal": False,
+            "cup_goal_pos": [-0.3, -1.2]
+        }
+    )
+
+
+# random goal cup position
+register(
+        id='ALRBeerPong-v1',
+        entry_point='alr_envs.alr.mujoco:ALRBeerBongEnv',
+        max_episode_steps=600,
+        kwargs={
+            "rndm_goal": True,
+            "cup_goal_pos": [-0.3, -1.2]
         }
     )
 
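For reference, a minimal usage sketch (not part of this commit) for the two variants registered above; it assumes, as is common for gym-style registries, that importing alr_envs executes these register() calls:

    import gym
    import alr_envs  # noqa: F401 -- assumption: importing runs the register() calls above

    env_fixed = gym.make('ALRBeerPong-v0')   # cup stays at the kwargs position
    env_random = gym.make('ALRBeerPong-v1')  # rndm_goal=True: new cup position per reset

    obs = env_random.reset()
    obs, reward, done, info = env_random.step(env_random.action_space.sample())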
@@ -4,11 +4,16 @@ import os
 import numpy as np
 from gym import utils
 from gym.envs.mujoco import MujocoEnv
+from alr_envs.alr.mujoco.beerpong.beerpong_reward_staged import BeerPongReward
+
+
+CUP_POS_MIN = np.array([-0.32, -2.2])
+CUP_POS_MAX = np.array([0.32, -1.2])
 
 
 class ALRBeerBongEnv(MujocoEnv, utils.EzPickle):
-    def __init__(self, frame_skip=1, apply_gravity_comp=True, reward_type: str = "staged", noisy=False,
-                 context: np.ndarray = None, difficulty='simple'):
+    def __init__(self, frame_skip=1, apply_gravity_comp=True, noisy=False,
+                 rndm_goal=False, cup_goal_pos=[-0.3, -1.2]):
         self._steps = 0
 
         self.xml_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets",
@@ -17,7 +22,7 @@ class ALRBeerBongEnv(MujocoEnv, utils.EzPickle):
         self.j_min = np.array([-2.6, -1.985, -2.8, -0.9, -4.55, -1.5707, -2.7])
         self.j_max = np.array([2.6, 1.985, 2.8, 3.14159, 1.25, 1.5707, 2.7])
 
-        self.context = context
+        self.rndm_goal = rndm_goal
         self.apply_gravity_comp = apply_gravity_comp
         self.add_noise = noisy
 
@@ -38,23 +43,8 @@ class ALRBeerBongEnv(MujocoEnv, utils.EzPickle):
         else:
             self.noise_std = 0
 
-        if difficulty == 'simple':
-            self.cup_goal_pos = np.array([0, -1.7, 0.840])
-        elif difficulty == 'intermediate':
-            self.cup_goal_pos = np.array([0.3, -1.5, 0.840])
-        elif difficulty == 'hard':
-            self.cup_goal_pos = np.array([-0.3, -2.2, 0.840])
-        elif difficulty == 'hardest':
-            self.cup_goal_pos = np.array([-0.3, -1.2, 0.840])
-
-        if reward_type == "no_context":
-            from alr_envs.alr.mujoco.beerpong.beerpong_reward import BeerPongReward
-            reward_function = BeerPongReward
-        elif reward_type == "staged":
-            from alr_envs.alr.mujoco.beerpong.beerpong_reward_staged import BeerPongReward
-            reward_function = BeerPongReward
-        else:
-            raise ValueError("Unknown reward type: {}".format(reward_type))
+        self.cup_goal_pos = np.array(cup_goal_pos + [0.840])  # concatenate, not append: list.append returns None
+        reward_function = BeerPongReward
         self.reward_function = reward_function()
 
         MujocoEnv.__init__(self, self.xml_path, frame_skip)
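A quick sketch of what the goal-position line above computes for the default arguments (0.840 is the table height used throughout this commit); list concatenation is required here, since list.append mutates in place and returns None:

    import numpy as np

    cup_goal_pos = [-0.3, -1.2]               # xy goal passed via kwargs
    goal = np.array(cup_goal_pos + [0.840])   # -> array([-0.3 , -1.2 ,  0.84])
    # np.array(cup_goal_pos.append(0.840)) would instead yield np.array(None)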
@@ -94,6 +84,12 @@ class ALRBeerBongEnv(MujocoEnv, utils.EzPickle):
         self.sim.model.body_pos[self.cup_table_id] = self.cup_goal_pos
         start_pos[7::] = self.sim.data.site_xpos[self.ball_site_id, :].copy()
         self.set_state(start_pos, init_vel)
+        if self.rndm_goal:
+            xy = np.random.uniform(CUP_POS_MIN, CUP_POS_MAX)
+            xyz = np.zeros(3)
+            xyz[:2] = xy
+            xyz[-1] = 0.840
+            self.sim.model.body_pos[self.cup_table_id] = xyz
         return self._get_obs()
 
     def step(self, a):
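The randomized goal is drawn uniformly from the axis-aligned box spanned by CUP_POS_MIN and CUP_POS_MAX; a standalone sketch of that sampling step:

    import numpy as np

    CUP_POS_MIN = np.array([-0.32, -2.2])
    CUP_POS_MAX = np.array([0.32, -1.2])

    # np.random.uniform broadcasts over array bounds, so x is drawn from
    # [-0.32, 0.32] and y from [-2.2, -1.2] independently.
    xy = np.random.uniform(CUP_POS_MIN, CUP_POS_MAX)
    xyz = np.array([*xy, 0.840])  # z fixed at table height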
@@ -153,13 +149,12 @@ class ALRBeerBongEnv(MujocoEnv, utils.EzPickle):
     def check_traj_in_joint_limits(self):
         return any(self.current_pos > self.j_max) or any(self.current_pos < self.j_min)
 
-    # TODO: extend observation space
     def _get_obs(self):
         theta = self.sim.data.qpos.flat[:7]
         return np.concatenate([
             np.cos(theta),
             np.sin(theta),
-            # self.get_body_com("target"),  # only return target to make problem harder
+            self.sim.model.body_pos[self.cup_table_id][:2].copy(),
             [self._steps],
         ])
 
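With the cup's xy position appended, the observation for the 7-DoF arm has 7 + 7 + 2 + 1 = 17 entries. A hypothetical helper mirroring _get_obs, just to make the layout explicit:

    import numpy as np

    def make_obs(theta, cup_xy, steps):
        # layout: 7 cos | 7 sin | cup xy | step counter
        return np.concatenate([np.cos(theta), np.sin(theta), cup_xy, [steps]])

    obs = make_obs(np.zeros(7), np.array([-0.3, -1.2]), 0)
    assert obs.shape == (17,)  # cup xy sits at indices 14:16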
@@ -169,25 +164,26 @@ class ALRBeerBongEnv(MujocoEnv, utils.EzPickle):
         return np.hstack([
             [False] * 7,  # cos
             [False] * 7,  # sin
-            # [True] * 2,  # x-y coordinates of target distance
+            [True] * 2,  # xy position of cup
             [False]  # env steps
         ])
 
 
 if __name__ == "__main__":
-    env = ALRBeerBongEnv(reward_type="staged", difficulty='hardest')
-
-    # env.configure(ctxt)
+    env = ALRBeerBongEnv(rndm_goal=True)
+    import time
     env.reset()
     env.render("human")
-    for i in range(800):
-        ac = 10 * env.action_space.sample()[0:7]
+    for i in range(1500):
+        # ac = 10 * env.action_space.sample()[0:7]
+        ac = np.zeros(7)
         obs, rew, d, info = env.step(ac)
         env.render("human")
 
         print(rew)
 
         if d:
-            break
-
+            print('RESETTING')
+            env.reset()
+            time.sleep(1)
     env.close()
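Because the goal is sampled from NumPy's global RNG, seeding it before reset() should reproduce a cup placement; a hedged extension of the demo above, assuming the 17-dimensional observation layout with cup xy at indices 14:16:

    np.random.seed(42)
    obs_a = env.reset()   # cup sampled at some position p
    np.random.seed(42)
    obs_b = env.reset()   # same seed, same position
    assert np.allclose(obs_a[14:16], obs_b[14:16])  # cup xy entries match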
@@ -9,11 +9,10 @@ class MPWrapper(MPEnvWrapper):
 
     @property
     def active_obs(self):
-        # TODO: @Max Filter observations correctly
         return np.hstack([
             [False] * 7,  # cos
             [False] * 7,  # sin
-            # [True] * 2,  # x-y coordinates of target distance
+            [True] * 2,  # xy position of cup
             [False]  # env steps
         ])
 
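active_obs is a boolean mask over the observation vector; after this change only the cup's xy position counts as context for the MP wrapper. A sketch of how such a mask filters an observation:

    import numpy as np

    active_obs = np.hstack([[False] * 7, [False] * 7, [True] * 2, [False]])
    obs = np.arange(17.0)        # stand-in for a real observation
    context = obs[active_obs]    # keeps entries 14 and 15: the cup's xy
    assert context.tolist() == [14.0, 15.0]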
@@ -31,7 +30,6 @@ class MPWrapper(MPEnvWrapper):
 
     @property
     def goal_pos(self):
-        # TODO: @Max I think the default value of returning to the start is reasonable here
         raise ValueError("Goal position is not available and has to be learnt based on the environment.")
 
     @property