randomized cup
This commit is contained in:
parent
e1617c34a6
commit
be5b287ae1
@ -213,16 +213,26 @@ register(id='TableTennis4DCtxt-v0',
|
|||||||
kwargs={'ctxt_dim': 4})
|
kwargs={'ctxt_dim': 4})
|
||||||
|
|
||||||
## BeerPong
|
## BeerPong
|
||||||
difficulties = ["simple", "intermediate", "hard", "hardest"]
|
# fixed goal cup position
|
||||||
|
register(
|
||||||
for v, difficulty in enumerate(difficulties):
|
id='ALRBeerPong-v0',
|
||||||
register(
|
|
||||||
id='ALRBeerPong-v{}'.format(v),
|
|
||||||
entry_point='alr_envs.alr.mujoco:ALRBeerBongEnv',
|
entry_point='alr_envs.alr.mujoco:ALRBeerBongEnv',
|
||||||
max_episode_steps=600,
|
max_episode_steps=600,
|
||||||
kwargs={
|
kwargs={
|
||||||
"difficulty": difficulty,
|
"rndm_goal": False,
|
||||||
"reward_type": "staged",
|
"cup_goal_pos": [-0.3, -1.2]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# random goal cup position
|
||||||
|
register(
|
||||||
|
id='ALRBeerPong-v1',
|
||||||
|
entry_point='alr_envs.alr.mujoco:ALRBeerBongEnv',
|
||||||
|
max_episode_steps=600,
|
||||||
|
kwargs={
|
||||||
|
"rndm_goal": True,
|
||||||
|
"cup_goal_pos": [-0.3, -1.2]
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -4,11 +4,16 @@ import os
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from gym import utils
|
from gym import utils
|
||||||
from gym.envs.mujoco import MujocoEnv
|
from gym.envs.mujoco import MujocoEnv
|
||||||
|
from alr_envs.alr.mujoco.beerpong.beerpong_reward_staged import BeerPongReward
|
||||||
|
|
||||||
|
|
||||||
|
CUP_POS_MIN = np.array([-0.32, -2.2])
|
||||||
|
CUP_POS_MAX = np.array([0.32, -1.2])
|
||||||
|
|
||||||
|
|
||||||
class ALRBeerBongEnv(MujocoEnv, utils.EzPickle):
|
class ALRBeerBongEnv(MujocoEnv, utils.EzPickle):
|
||||||
def __init__(self, frame_skip=1, apply_gravity_comp=True, reward_type: str = "staged", noisy=False,
|
def __init__(self, frame_skip=1, apply_gravity_comp=True, noisy=False,
|
||||||
context: np.ndarray = None, difficulty='simple'):
|
rndm_goal=False, cup_goal_pos=[-0.3, -1.2]):
|
||||||
self._steps = 0
|
self._steps = 0
|
||||||
|
|
||||||
self.xml_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets",
|
self.xml_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets",
|
||||||
@ -17,7 +22,7 @@ class ALRBeerBongEnv(MujocoEnv, utils.EzPickle):
|
|||||||
self.j_min = np.array([-2.6, -1.985, -2.8, -0.9, -4.55, -1.5707, -2.7])
|
self.j_min = np.array([-2.6, -1.985, -2.8, -0.9, -4.55, -1.5707, -2.7])
|
||||||
self.j_max = np.array([2.6, 1.985, 2.8, 3.14159, 1.25, 1.5707, 2.7])
|
self.j_max = np.array([2.6, 1.985, 2.8, 3.14159, 1.25, 1.5707, 2.7])
|
||||||
|
|
||||||
self.context = context
|
self.rndm_goal = rndm_goal
|
||||||
self.apply_gravity_comp = apply_gravity_comp
|
self.apply_gravity_comp = apply_gravity_comp
|
||||||
self.add_noise = noisy
|
self.add_noise = noisy
|
||||||
|
|
||||||
@ -38,23 +43,8 @@ class ALRBeerBongEnv(MujocoEnv, utils.EzPickle):
|
|||||||
else:
|
else:
|
||||||
self.noise_std = 0
|
self.noise_std = 0
|
||||||
|
|
||||||
if difficulty == 'simple':
|
self.cup_goal_pos = np.array(cup_goal_pos.append(0.840))
|
||||||
self.cup_goal_pos = np.array([0, -1.7, 0.840])
|
reward_function = BeerPongReward
|
||||||
elif difficulty == 'intermediate':
|
|
||||||
self.cup_goal_pos = np.array([0.3, -1.5, 0.840])
|
|
||||||
elif difficulty == 'hard':
|
|
||||||
self.cup_goal_pos = np.array([-0.3, -2.2, 0.840])
|
|
||||||
elif difficulty == 'hardest':
|
|
||||||
self.cup_goal_pos = np.array([-0.3, -1.2, 0.840])
|
|
||||||
|
|
||||||
if reward_type == "no_context":
|
|
||||||
from alr_envs.alr.mujoco.beerpong.beerpong_reward import BeerPongReward
|
|
||||||
reward_function = BeerPongReward
|
|
||||||
elif reward_type == "staged":
|
|
||||||
from alr_envs.alr.mujoco.beerpong.beerpong_reward_staged import BeerPongReward
|
|
||||||
reward_function = BeerPongReward
|
|
||||||
else:
|
|
||||||
raise ValueError("Unknown reward type: {}".format(reward_type))
|
|
||||||
self.reward_function = reward_function()
|
self.reward_function = reward_function()
|
||||||
|
|
||||||
MujocoEnv.__init__(self, self.xml_path, frame_skip)
|
MujocoEnv.__init__(self, self.xml_path, frame_skip)
|
||||||
@ -94,6 +84,12 @@ class ALRBeerBongEnv(MujocoEnv, utils.EzPickle):
|
|||||||
self.sim.model.body_pos[self.cup_table_id] = self.cup_goal_pos
|
self.sim.model.body_pos[self.cup_table_id] = self.cup_goal_pos
|
||||||
start_pos[7::] = self.sim.data.site_xpos[self.ball_site_id, :].copy()
|
start_pos[7::] = self.sim.data.site_xpos[self.ball_site_id, :].copy()
|
||||||
self.set_state(start_pos, init_vel)
|
self.set_state(start_pos, init_vel)
|
||||||
|
if self.rndm_goal:
|
||||||
|
xy = np.random.uniform(CUP_POS_MIN, CUP_POS_MAX)
|
||||||
|
xyz = np.zeros(3)
|
||||||
|
xyz[:2] = xy
|
||||||
|
xyz[-1] = 0.840
|
||||||
|
self.sim.model.body_pos[self.cup_table_id] = xyz
|
||||||
return self._get_obs()
|
return self._get_obs()
|
||||||
|
|
||||||
def step(self, a):
|
def step(self, a):
|
||||||
@ -153,13 +149,12 @@ class ALRBeerBongEnv(MujocoEnv, utils.EzPickle):
|
|||||||
def check_traj_in_joint_limits(self):
|
def check_traj_in_joint_limits(self):
|
||||||
return any(self.current_pos > self.j_max) or any(self.current_pos < self.j_min)
|
return any(self.current_pos > self.j_max) or any(self.current_pos < self.j_min)
|
||||||
|
|
||||||
# TODO: extend observation space
|
|
||||||
def _get_obs(self):
|
def _get_obs(self):
|
||||||
theta = self.sim.data.qpos.flat[:7]
|
theta = self.sim.data.qpos.flat[:7]
|
||||||
return np.concatenate([
|
return np.concatenate([
|
||||||
np.cos(theta),
|
np.cos(theta),
|
||||||
np.sin(theta),
|
np.sin(theta),
|
||||||
# self.get_body_com("target"), # only return target to make problem harder
|
self.sim.model.body_pos[self.cup_table_id][:2].copy(),
|
||||||
[self._steps],
|
[self._steps],
|
||||||
])
|
])
|
||||||
|
|
||||||
@ -169,25 +164,26 @@ class ALRBeerBongEnv(MujocoEnv, utils.EzPickle):
|
|||||||
return np.hstack([
|
return np.hstack([
|
||||||
[False] * 7, # cos
|
[False] * 7, # cos
|
||||||
[False] * 7, # sin
|
[False] * 7, # sin
|
||||||
# [True] * 2, # x-y coordinates of target distance
|
[True] * 2, # xy position of cup
|
||||||
[False] # env steps
|
[False] # env steps
|
||||||
])
|
])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
env = ALRBeerBongEnv(reward_type="staged", difficulty='hardest')
|
env = ALRBeerBongEnv(rndm_goal=True)
|
||||||
|
import time
|
||||||
# env.configure(ctxt)
|
|
||||||
env.reset()
|
env.reset()
|
||||||
env.render("human")
|
env.render("human")
|
||||||
for i in range(800):
|
for i in range(1500):
|
||||||
ac = 10 * env.action_space.sample()[0:7]
|
# ac = 10 * env.action_space.sample()[0:7]
|
||||||
|
ac = np.zeros(7)
|
||||||
obs, rew, d, info = env.step(ac)
|
obs, rew, d, info = env.step(ac)
|
||||||
env.render("human")
|
env.render("human")
|
||||||
|
|
||||||
print(rew)
|
print(rew)
|
||||||
|
|
||||||
if d:
|
if d:
|
||||||
break
|
print('RESETTING')
|
||||||
|
env.reset()
|
||||||
|
time.sleep(1)
|
||||||
env.close()
|
env.close()
|
||||||
|
@ -9,11 +9,10 @@ class MPWrapper(MPEnvWrapper):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def active_obs(self):
|
def active_obs(self):
|
||||||
# TODO: @Max Filter observations correctly
|
|
||||||
return np.hstack([
|
return np.hstack([
|
||||||
[False] * 7, # cos
|
[False] * 7, # cos
|
||||||
[False] * 7, # sin
|
[False] * 7, # sin
|
||||||
# [True] * 2, # x-y coordinates of target distance
|
[True] * 2, # xy position of cup
|
||||||
[False] # env steps
|
[False] # env steps
|
||||||
])
|
])
|
||||||
|
|
||||||
@ -31,7 +30,6 @@ class MPWrapper(MPEnvWrapper):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def goal_pos(self):
|
def goal_pos(self):
|
||||||
# TODO: @Max I think the default value of returning to the start is reasonable here
|
|
||||||
raise ValueError("Goal position is not available and has to be learnt based on the environment.")
|
raise ValueError("Goal position is not available and has to be learnt based on the environment.")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
Loading…
Reference in New Issue
Block a user