diff --git a/alr_envs/alr/__init__.py b/alr_envs/alr/__init__.py
index 0181f62..026f1a7 100644
--- a/alr_envs/alr/__init__.py
+++ b/alr_envs/alr/__init__.py
@@ -236,6 +236,17 @@ register(
}
)
+# Beerpong devel big table
+register(
+ id='ALRBeerPong-v3',
+ entry_point='alr_envs.alr.mujoco:ALRBeerBongEnv',
+ max_episode_steps=600,
+ kwargs={
+ "rndm_goal": True,
+ "cup_goal_pos": [-0.3, -1.2]
+ }
+ )
+
# Motion Primitive Environments
## Simple Reacher
@@ -402,6 +413,32 @@ for _v in _versions:
)
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
+## Beerpong- Big table devel
+
+register(
+ id='BeerpongProMP-v3',
+ entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
+ kwargs={
+ "name": f"alr_envs:ALRBeerPong-v3",
+ "wrappers": [mujoco.beerpong.MPWrapper],
+ "mp_kwargs": {
+ "num_dof": 7,
+ "num_basis": 5,
+ "duration": 1,
+ "post_traj_time": 2,
+ "policy_type": "motor",
+ "weights_scale": 1,
+ "zero_start": True,
+ "zero_goal": False,
+ "policy_kwargs": {
+ "p_gains": np.array([ 1.5, 5, 2.55, 3, 2., 2, 1.25]),
+ "d_gains": np.array([0.02333333, 0.1, 0.0625, 0.08, 0.03, 0.03, 0.0125])
+ }
+ }
+ }
+ )
+ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append('BeerpongProMP-v3')
+
## Table Tennis
ctxt_dim = [2, 4]
for _v, cd in enumerate(ctxt_dim):
@@ -416,7 +453,7 @@ for _v, cd in enumerate(ctxt_dim):
"num_dof": 7,
"num_basis": 2,
"duration": 1.25,
- "post_traj_time": 4.5,
+ "post_traj_time": 1.5,
"policy_type": "motor",
"weights_scale": 1.0,
"zero_start": True,
diff --git a/alr_envs/alr/mujoco/beerpong/assets/beerpong_wo_cup.xml b/alr_envs/alr/mujoco/beerpong/assets/beerpong_wo_cup.xml
index e96d2bc..436b36c 100644
--- a/alr_envs/alr/mujoco/beerpong/assets/beerpong_wo_cup.xml
+++ b/alr_envs/alr/mujoco/beerpong/assets/beerpong_wo_cup.xml
@@ -132,18 +132,19 @@
-
+
-
+
+
diff --git a/alr_envs/alr/mujoco/beerpong/beerpong.py b/alr_envs/alr/mujoco/beerpong/beerpong.py
index 9092ef1..99d1a23 100644
--- a/alr_envs/alr/mujoco/beerpong/beerpong.py
+++ b/alr_envs/alr/mujoco/beerpong/beerpong.py
@@ -7,8 +7,11 @@ from gym.envs.mujoco import MujocoEnv
from alr_envs.alr.mujoco.beerpong.beerpong_reward_staged import BeerPongReward
-CUP_POS_MIN = np.array([-0.32, -2.2])
-CUP_POS_MAX = np.array([0.32, -1.2])
+# CUP_POS_MIN = np.array([-0.32, -2.2])
+# CUP_POS_MAX = np.array([0.32, -1.2])
+
+CUP_POS_MIN = np.array([-1.42, -4.05])
+CUP_POS_MAX = np.array([1.42, -1.25])
class ALRBeerBongEnv(MujocoEnv, utils.EzPickle):
diff --git a/alr_envs/alr/mujoco/table_tennis/mp_wrapper.py b/alr_envs/alr/mujoco/table_tennis/mp_wrapper.py
index 86d5a00..473583f 100644
--- a/alr_envs/alr/mujoco/table_tennis/mp_wrapper.py
+++ b/alr_envs/alr/mujoco/table_tennis/mp_wrapper.py
@@ -11,9 +11,9 @@ class MPWrapper(MPEnvWrapper):
def active_obs(self):
# TODO: @Max Filter observations correctly
return np.hstack([
- [True] * 7, # Joint Pos
- [True] * 3, # Ball pos
- [True] * 3 # goal pos
+ [False] * 7, # Joint Pos
+ [True] * 2, # Ball pos
+ [True] * 2 # goal pos
])
@property
diff --git a/alr_envs/alr/mujoco/table_tennis/tt_gym.py b/alr_envs/alr/mujoco/table_tennis/tt_gym.py
index d1c2dc3..e88bbc5 100644
--- a/alr_envs/alr/mujoco/table_tennis/tt_gym.py
+++ b/alr_envs/alr/mujoco/table_tennis/tt_gym.py
@@ -10,7 +10,8 @@ from alr_envs.alr.mujoco.table_tennis.tt_reward import TT_Reward
#TODO: Check for simulation stability. Make sure the code runs even for sim crash
-MAX_EPISODE_STEPS = 1750
+# MAX_EPISODE_STEPS = 1750
+MAX_EPISODE_STEPS = 1375
BALL_NAME_CONTACT = "target_ball_contact"
BALL_NAME = "target_ball"
TABLE_NAME = 'table_tennis_table'
@@ -76,10 +77,11 @@ class TTEnvGym(MujocoEnv, utils.EzPickle):
self._ids_set = True
def _get_obs(self):
- ball_pos = self.sim.data.body_xpos[self.ball_id]
+ ball_pos = self.sim.data.body_xpos[self.ball_id][:2].copy()
+ goal_pos = self.goal[:2].copy()
obs = np.concatenate([self.sim.data.qpos[:7].copy(), # 7 joint positions
ball_pos,
- self.goal.copy()])
+ goal_pos])
return obs
def sample_context(self):
diff --git a/setup.py b/setup.py
index 6122e90..796c569 100644
--- a/setup.py
+++ b/setup.py
@@ -7,7 +7,7 @@ setup(
install_requires=[
'gym',
'PyQt5',
- 'matplotlib',
+ #'matplotlib',
#'mp_env_api @ git+https://github.com/ALRhub/motion_primitive_env_api.git',
# 'mp_env_api @ git+ssh://git@github.com/ALRhub/motion_primitive_env_api.git',
'mujoco-py<2.1,>=2.0',