Implement AirHocKIT 2023 environments
This commit is contained in:
parent
dbd3caebb3
commit
6afb5880db
@ -291,7 +291,7 @@ register(
|
||||
)
|
||||
|
||||
# Air Hockey environments
|
||||
for env_mode in ["7dof-hit", "7dof-defend", "3dof-hit", "3dof-defend"]:
|
||||
for env_mode in ["7dof-hit", "7dof-defend", "3dof-hit", "3dof-defend", "7dof-hit-airhockit2023", "7dof-defend-airhockit2023"]:
|
||||
register(
|
||||
id=f'fancy/AirHockey-{env_mode}-v0',
|
||||
entry_point='fancy_gym.envs.mujoco:AirHockeyEnv',
|
||||
|
@ -30,7 +30,10 @@ class AirHockeyEnv(Environment):
|
||||
"7dof-defend": position.IiwaPositionDefend,
|
||||
|
||||
"3dof-hit": position.PlanarPositionHit,
|
||||
"3dof-defend": position.PlanarPositionDefend
|
||||
"3dof-defend": position.PlanarPositionDefend,
|
||||
|
||||
"7dof-hit-airhockit2023": position.IiwaPositionHitAirhocKIT2023,
|
||||
"7dof-defend-airhockit2023": position.IiwaPositionDefendAirhocKIT2023,
|
||||
}
|
||||
|
||||
if env_mode not in env_dict:
|
||||
@ -42,34 +45,39 @@ class AirHockeyEnv(Environment):
|
||||
self.base_env = env_dict[env_mode](interpolation_order=interpolation_order, **kwargs)
|
||||
self.env_name = env_mode
|
||||
self.env_info = self.base_env.env_info
|
||||
single_robot_obs_size = len(self.base_env.info.observation_space.low)
|
||||
if env_mode == "tournament":
|
||||
self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(2,single_robot_obs_size), dtype=np.float64)
|
||||
else:
|
||||
self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(single_robot_obs_size,), dtype=np.float64)
|
||||
robot_info = self.env_info["robot"]
|
||||
|
||||
if env_mode != "tournament":
|
||||
if interpolation_order in [1, 2]:
|
||||
self.action_space = spaces.Box(low=robot_info["joint_pos_limit"][0], high=robot_info["joint_pos_limit"][1])
|
||||
if interpolation_order in [3, 4, -1]:
|
||||
self.action_space = spaces.Box(low=np.vstack([robot_info["joint_pos_limit"][0], robot_info["joint_vel_limit"][0]]),
|
||||
high=np.vstack([robot_info["joint_pos_limit"][1], robot_info["joint_vel_limit"][1]]))
|
||||
if interpolation_order in [5]:
|
||||
self.action_space = spaces.Box(low=np.vstack([robot_info["joint_pos_limit"][0], robot_info["joint_vel_limit"][0], robot_info["joint_acc_limit"][0]]),
|
||||
high=np.vstack([robot_info["joint_pos_limit"][1], robot_info["joint_vel_limit"][1], robot_info["joint_acc_limit"][1]]))
|
||||
if hasattr(self.base_env, "wrapper_obs_space") and hasattr(self.base_env, "wrapper_act_space"):
|
||||
self.observation_space = self.base_env.wrapper_obs_space
|
||||
self.action_space = self.base_env.wrapper_act_space
|
||||
else:
|
||||
acts = [None, None]
|
||||
for i in range(2):
|
||||
if interpolation_order[i] in [1, 2]:
|
||||
acts[i] = spaces.Box(low=robot_info["joint_pos_limit"][0], high=robot_info["joint_pos_limit"][1])
|
||||
if interpolation_order[i] in [3, 4, -1]:
|
||||
acts[i] = spaces.Box(low=np.vstack([robot_info["joint_pos_limit"][0], robot_info["joint_vel_limit"][0]]),
|
||||
high=np.vstack([robot_info["joint_pos_limit"][1], robot_info["joint_vel_limit"][1]]))
|
||||
if interpolation_order[i] in [5]:
|
||||
acts[i] = spaces.Box(low=np.vstack([robot_info["joint_pos_limit"][0], robot_info["joint_vel_limit"][0], robot_info["joint_acc_limit"][0]]),
|
||||
high=np.vstack([robot_info["joint_pos_limit"][1], robot_info["joint_vel_limit"][1], robot_info["joint_acc_limit"][1]]))
|
||||
self.action_space = spaces.Tuple((acts[0], acts[1]))
|
||||
single_robot_obs_size = len(self.base_env.info.observation_space.low)
|
||||
if env_mode == "tournament":
|
||||
self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(2,single_robot_obs_size), dtype=np.float64)
|
||||
else:
|
||||
self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(single_robot_obs_size,), dtype=np.float64)
|
||||
robot_info = self.env_info["robot"]
|
||||
|
||||
if env_mode != "tournament":
|
||||
if interpolation_order in [1, 2]:
|
||||
self.action_space = spaces.Box(low=robot_info["joint_pos_limit"][0], high=robot_info["joint_pos_limit"][1])
|
||||
if interpolation_order in [3, 4, -1]:
|
||||
self.action_space = spaces.Box(low=np.vstack([robot_info["joint_pos_limit"][0], robot_info["joint_vel_limit"][0]]),
|
||||
high=np.vstack([robot_info["joint_pos_limit"][1], robot_info["joint_vel_limit"][1]]))
|
||||
if interpolation_order in [5]:
|
||||
self.action_space = spaces.Box(low=np.vstack([robot_info["joint_pos_limit"][0], robot_info["joint_vel_limit"][0], robot_info["joint_acc_limit"][0]]),
|
||||
high=np.vstack([robot_info["joint_pos_limit"][1], robot_info["joint_vel_limit"][1], robot_info["joint_acc_limit"][1]]))
|
||||
else:
|
||||
acts = [None, None]
|
||||
for i in range(2):
|
||||
if interpolation_order[i] in [1, 2]:
|
||||
acts[i] = spaces.Box(low=robot_info["joint_pos_limit"][0], high=robot_info["joint_pos_limit"][1])
|
||||
if interpolation_order[i] in [3, 4, -1]:
|
||||
acts[i] = spaces.Box(low=np.vstack([robot_info["joint_pos_limit"][0], robot_info["joint_vel_limit"][0]]),
|
||||
high=np.vstack([robot_info["joint_pos_limit"][1], robot_info["joint_vel_limit"][1]]))
|
||||
if interpolation_order[i] in [5]:
|
||||
acts[i] = spaces.Box(low=np.vstack([robot_info["joint_pos_limit"][0], robot_info["joint_vel_limit"][0], robot_info["joint_acc_limit"][0]]),
|
||||
high=np.vstack([robot_info["joint_pos_limit"][1], robot_info["joint_vel_limit"][1], robot_info["joint_acc_limit"][1]]))
|
||||
self.action_space = spaces.Tuple((acts[0], acts[1]))
|
||||
|
||||
constraint_list = constraints.ConstraintList()
|
||||
constraint_list.add(constraints.JointPositionConstraint(self.env_info))
|
||||
|
@ -261,10 +261,14 @@ class PlanarPositionDefend(PositionControlPlanar, three_dof.AirHockeyDefend):
|
||||
class IiwaPositionHit(PositionControlIIWA, seven_dof.AirHockeyHit):
|
||||
pass
|
||||
|
||||
class IiwaPositionHitAirhocKIT2023(PositionControlIIWA, seven_dof.AirHockeyHitAirhocKIT2023):
|
||||
pass
|
||||
|
||||
class IiwaPositionDefend(PositionControlIIWA, seven_dof.AirHockeyDefend):
|
||||
pass
|
||||
|
||||
class IiwaPositionDefendAirhocKIT2023(PositionControlIIWA, seven_dof.AirHockeyDefendAirhocKIT2023):
|
||||
pass
|
||||
|
||||
class IiwaPositionTournament(PositionControlIIWA, seven_dof.AirHockeyTournament):
|
||||
pass
|
||||
|
@ -1,4 +1,4 @@
|
||||
from .env_base import AirHockeyBase
|
||||
from .tournament import AirHockeyTournament
|
||||
from .hit import AirHockeyHit
|
||||
from .defend import AirHockeyDefend
|
||||
from .hit import AirHockeyHit, AirHockeyHitAirhocKIT2023
|
||||
from .defend import AirHockeyDefend, AirHockeyDefendAirhocKIT2023
|
105
fancy_gym/envs/mujoco/air_hockey/seven_dof/airhockit_base_env.py
Normal file
105
fancy_gym/envs/mujoco/air_hockey/seven_dof/airhockit_base_env.py
Normal file
@ -0,0 +1,105 @@
|
||||
import numpy as np
|
||||
from gymnasium import spaces
|
||||
|
||||
from fancy_gym.envs.mujoco.air_hockey.seven_dof.env_single import AirHockeySingle
|
||||
from fancy_gym.envs.mujoco.air_hockey.utils import inverse_kinematics, forward_kinematics, jacobian
|
||||
|
||||
class AirhocKIT2023BaseEnv(AirHockeySingle):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
obs_low = np.hstack([[-np.inf] * 37])
|
||||
obs_high = np.hstack([[np.inf] * 37])
|
||||
self.wrapper_obs_space = spaces.Box(low=obs_low, high=obs_high, dtype=np.float64)
|
||||
self.wrapper_act_space = spaces.Box(low=np.repeat(-100., 6), high=np.repeat(100., 6))
|
||||
|
||||
# We don't need puck yaw observations
|
||||
def filter_obs(self, obs):
|
||||
obs = np.hstack([obs[0:2], obs[3:5], obs[6:12], obs[13:19], obs[20:]])
|
||||
return obs
|
||||
|
||||
def reset(self):
|
||||
self.last_acceleration = np.repeat(0., 6)
|
||||
obs = super().reset()
|
||||
self.interp_pos = obs[self.env_info["joint_pos_ids"]][:-1]
|
||||
self.interp_vel = obs[self.env_info["joint_vel_ids"]][:-1]
|
||||
|
||||
self.last_planned_world_pos = self._fk(self.interp_pos)
|
||||
obs = np.hstack([
|
||||
obs, self.interp_pos, self.interp_vel, self.last_acceleration, self.last_planned_world_pos
|
||||
])
|
||||
return self.filter_obs(obs)
|
||||
|
||||
def step(self, action):
|
||||
action /= 10
|
||||
|
||||
new_vel = self.interp_vel + action
|
||||
|
||||
jerk = 2 * (new_vel - self.interp_vel - self.last_acceleration * 0.02) / (0.02 ** 2)
|
||||
new_pos = self.interp_pos + self.interp_vel * 0.02 + (1/2) * self.last_acceleration * (0.02 ** 2) + (1/6) * jerk * (0.02 ** 3)
|
||||
abs_action = np.vstack([np.hstack([new_pos, 0]), np.hstack([new_vel, 0])])
|
||||
|
||||
self.interp_pos = new_pos
|
||||
self.interp_vel = new_vel
|
||||
self.last_acceleration += jerk * 0.02
|
||||
|
||||
obs, rew, done, info = super().step(abs_action)
|
||||
self.last_planned_world_pos = self._fk(self.interp_pos)
|
||||
obs = np.hstack([
|
||||
obs, self.interp_pos, self.interp_vel, self.last_acceleration, self.last_planned_world_pos
|
||||
])
|
||||
|
||||
fatal_rew = self.check_fatal(obs)
|
||||
if fatal_rew != 0:
|
||||
return self.filter_obs(obs), fatal_rew, True, info
|
||||
|
||||
return self.filter_obs(obs), rew, done, info
|
||||
|
||||
def check_constraints(self, constraint_values):
|
||||
fatal_rew = 0
|
||||
|
||||
j_pos_constr = constraint_values["joint_pos_constr"]
|
||||
if j_pos_constr.max() > 0:
|
||||
fatal_rew += j_pos_constr.max()
|
||||
|
||||
j_vel_constr = constraint_values["joint_vel_constr"]
|
||||
if j_vel_constr.max() > 0:
|
||||
fatal_rew += j_vel_constr.max()
|
||||
|
||||
ee_constr = constraint_values["ee_constr"]
|
||||
if ee_constr.max() > 0:
|
||||
fatal_rew += ee_constr.max()
|
||||
|
||||
link_constr = constraint_values["link_constr"]
|
||||
if link_constr.max() > 0:
|
||||
fatal_rew += link_constr.max()
|
||||
|
||||
return -fatal_rew
|
||||
|
||||
def check_fatal(self, obs):
|
||||
fatal_rew = 0
|
||||
|
||||
q = obs[self.env_info["joint_pos_ids"]]
|
||||
qd = obs[self.env_info["joint_vel_ids"]]
|
||||
constraint_values_obs = self.env_info["constraints"].fun(q, qd)
|
||||
fatal_rew += self.check_constraints(constraint_values_obs)
|
||||
|
||||
return -fatal_rew
|
||||
|
||||
def _fk(self, pos):
|
||||
res, _ = forward_kinematics(self.env_info["robot"]["robot_model"],
|
||||
self.env_info["robot"]["robot_data"], pos)
|
||||
return res.astype(np.float32)
|
||||
|
||||
def _ik(self, world_pos, init_q=None):
|
||||
success, pos = inverse_kinematics(self.env_info["robot"]["robot_model"],
|
||||
self.env_info["robot"]["robot_data"],
|
||||
world_pos,
|
||||
initial_q=init_q)
|
||||
pos = pos.astype(np.float32)
|
||||
assert success
|
||||
return pos
|
||||
|
||||
def _jacobian(self, pos):
|
||||
return jacobian(self.env_info["robot"]["robot_model"],
|
||||
self.env_info["robot"]["robot_data"],
|
||||
pos).astype(np.float32)
|
@ -1,6 +1,7 @@
|
||||
import numpy as np
|
||||
|
||||
from fancy_gym.envs.mujoco.air_hockey.seven_dof.env_single import AirHockeySingle
|
||||
from fancy_gym.envs.mujoco.air_hockey.seven_dof.airhockit_base_env import AirhocKIT2023BaseEnv
|
||||
|
||||
|
||||
class AirHockeyDefend(AirHockeySingle):
|
||||
@ -10,9 +11,7 @@ class AirHockeyDefend(AirHockeySingle):
|
||||
"""
|
||||
def __init__(self, gamma=0.99, horizon=500, viewer_params={}):
|
||||
self.init_velocity_range = (1, 3)
|
||||
|
||||
self.start_range = np.array([[0.29, 0.65], [-0.4, 0.4]]) # Table Frame
|
||||
self.init_ee_range = np.array([[0.60, 1.25], [-0.4, 0.4]]) # Robot Frame
|
||||
super().__init__(gamma=gamma, horizon=horizon, viewer_params=viewer_params)
|
||||
|
||||
def setup(self, obs):
|
||||
@ -32,7 +31,7 @@ class AirHockeyDefend(AirHockeySingle):
|
||||
self._write_data("puck_y_vel", puck_vel[1])
|
||||
self._write_data("puck_yaw_vel", puck_vel[2])
|
||||
|
||||
super(AirHockeyDefend, self).setup(obs)
|
||||
super().setup(obs)
|
||||
|
||||
def reward(self, state, action, next_state, absorbing):
|
||||
return 0
|
||||
@ -46,6 +45,98 @@ class AirHockeyDefend(AirHockeySingle):
|
||||
return True
|
||||
return super().is_absorbing(state)
|
||||
|
||||
class AirHockeyDefendAirhocKIT2023(AirhocKIT2023BaseEnv):
|
||||
def __init__(self, gamma=0.99, horizon=200, viewer_params={}):
|
||||
super().__init__(gamma=gamma, horizon=horizon, viewer_params=viewer_params)
|
||||
self.init_velocity_range = (1, 3)
|
||||
self.start_range = np.array([[0.4, 0.75], [-0.4, 0.4]]) # Table Frame
|
||||
self._setup_metrics()
|
||||
|
||||
def setup(self, obs):
|
||||
self._setup_metrics()
|
||||
puck_pos = np.random.rand(2) * (self.start_range[:, 1] - self.start_range[:, 0]) + self.start_range[:, 0]
|
||||
|
||||
lin_vel = np.random.uniform(self.init_velocity_range[0], self.init_velocity_range[1])
|
||||
angle = np.random.uniform(-0.5, 0.5)
|
||||
|
||||
puck_vel = np.zeros(3)
|
||||
puck_vel[0] = -np.cos(angle) * lin_vel
|
||||
puck_vel[1] = np.sin(angle) * lin_vel
|
||||
puck_vel[2] = np.random.uniform(-10, 10)
|
||||
|
||||
self._write_data("puck_x_pos", puck_pos[0])
|
||||
self._write_data("puck_y_pos", puck_pos[1])
|
||||
self._write_data("puck_x_vel", puck_vel[0])
|
||||
self._write_data("puck_y_vel", puck_vel[1])
|
||||
self._write_data("puck_yaw_vel", puck_vel[2])
|
||||
|
||||
super().setup(obs)
|
||||
|
||||
def reset(self, *args):
|
||||
obs = super().reset()
|
||||
self.hit_step_flag = False
|
||||
self.hit_step = False
|
||||
self.received_hit_reward = False
|
||||
self.give_reward_next = False
|
||||
return obs
|
||||
|
||||
def _setup_metrics(self):
|
||||
self.episode_steps = 0
|
||||
self.has_hit = False
|
||||
|
||||
def _simulation_post_step(self):
|
||||
if not self.has_hit:
|
||||
self.has_hit = self._check_collision("puck", "robot_1/ee")
|
||||
|
||||
super()._simulation_post_step()
|
||||
|
||||
def _step_finalize(self):
|
||||
self.episode_steps += 1
|
||||
return super()._step_finalize()
|
||||
|
||||
def reward(self, state, action, next_state, absorbing):
|
||||
puck_pos, puck_vel = self.get_puck(next_state)
|
||||
ee_pos, _ = self.get_ee()
|
||||
rew = 0.01
|
||||
if -0.7 < puck_pos[0] <= -0.2 and np.linalg.norm(puck_vel[:2]) < 0.1:
|
||||
assert absorbing
|
||||
rew += 70
|
||||
|
||||
if self.has_hit and not self.hit_step_flag:
|
||||
self.hit_step_flag = True
|
||||
self.hit_step = True
|
||||
else:
|
||||
self.hit_step = False
|
||||
|
||||
f = lambda puck_vel: 30 + 100 * (100 ** (-0.25 * np.linalg.norm(puck_vel[:2])))
|
||||
if not self.give_reward_next and not self.received_hit_reward and self.hit_step and ee_pos[0] < puck_pos[0]:
|
||||
self.hit_this_step = True
|
||||
if np.linalg.norm(puck_vel[:2]) < 0.1:
|
||||
return rew + f(puck_vel)
|
||||
self.give_reward_next = True
|
||||
return rew
|
||||
|
||||
if not self.received_hit_reward and self.give_reward_next:
|
||||
self.received_hit_reward = True
|
||||
if puck_vel[0] >= -0.2:
|
||||
return rew + f(puck_vel)
|
||||
return rew
|
||||
else:
|
||||
return rew
|
||||
|
||||
def is_absorbing(self, obs):
|
||||
puck_pos, puck_vel = self.get_puck(obs)
|
||||
# If puck is over the middle line and moving towards opponent
|
||||
if puck_pos[0] > 0 and puck_vel[0] > 0:
|
||||
return True
|
||||
|
||||
if self.episode_steps == self._mdp_info.horizon:
|
||||
return True
|
||||
|
||||
if np.linalg.norm(puck_vel[:2]) < 0.1:
|
||||
return True
|
||||
return super().is_absorbing(obs)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
env = AirHockeyDefend()
|
||||
|
@ -1,6 +1,7 @@
|
||||
import numpy as np
|
||||
|
||||
from fancy_gym.envs.mujoco.air_hockey.seven_dof.env_single import AirHockeySingle
|
||||
from fancy_gym.envs.mujoco.air_hockey.seven_dof.airhockit_base_env import AirhocKIT2023BaseEnv
|
||||
|
||||
|
||||
class AirHockeyHit(AirHockeySingle):
|
||||
@ -14,9 +15,6 @@ class AirHockeyHit(AirHockeySingle):
|
||||
opponent_agent(Agent, None): Agent which controls the opponent
|
||||
moving_init(bool, False): If true, initialize the puck with inital velocity.
|
||||
"""
|
||||
self.hit_range = np.array([[-0.65, -0.25], [-0.4, 0.4]]) # Table Frame
|
||||
self.init_velocity_range = (0, 0.5) # Table Frame
|
||||
|
||||
super().__init__(gamma=gamma, horizon=horizon, viewer_params=viewer_params)
|
||||
|
||||
self.moving_init = moving_init
|
||||
@ -58,6 +56,93 @@ class AirHockeyHit(AirHockeySingle):
|
||||
return True
|
||||
return super(AirHockeyHit, self).is_absorbing(obs)
|
||||
|
||||
class AirHockeyHitAirhocKIT2023(AirhocKIT2023BaseEnv):
|
||||
def __init__(self, gamma=0.99, horizon=500, moving_init=True, viewer_params={}):
|
||||
super().__init__(gamma=gamma, horizon=horizon, viewer_params=viewer_params)
|
||||
|
||||
self.moving_init = moving_init
|
||||
hit_width = self.env_info['table']['width'] / 2 - self.env_info['puck']['radius'] - \
|
||||
self.env_info['mallet']['radius'] * 2
|
||||
self.hit_range = np.array([[-0.7, -0.2], [-hit_width, hit_width]]) # Table Frame
|
||||
self.init_velocity_range = (0, 0.5) # Table Frame
|
||||
self.init_ee_range = np.array([[0.60, 1.25], [-0.4, 0.4]]) # Robot Frame
|
||||
self._setup_metrics()
|
||||
|
||||
def reset(self, *args):
|
||||
obs = super().reset()
|
||||
self.last_ee_pos = self.last_planned_world_pos.copy()
|
||||
self.last_ee_pos[0] -= 1.51
|
||||
return obs
|
||||
|
||||
def setup(self, obs):
|
||||
self._setup_metrics()
|
||||
puck_pos = np.random.rand(2) * (self.hit_range[:, 1] - self.hit_range[:, 0]) + self.hit_range[:, 0]
|
||||
|
||||
self._write_data("puck_x_pos", puck_pos[0])
|
||||
self._write_data("puck_y_pos", puck_pos[1])
|
||||
|
||||
if self.moving_init:
|
||||
lin_vel = np.random.uniform(self.init_velocity_range[0], self.init_velocity_range[1])
|
||||
angle = np.random.uniform(-np.pi / 2 - 0.1, np.pi / 2 + 0.1)
|
||||
puck_vel = np.zeros(3)
|
||||
puck_vel[0] = -np.cos(angle) * lin_vel
|
||||
puck_vel[1] = np.sin(angle) * lin_vel
|
||||
puck_vel[2] = np.random.uniform(-2, 2)
|
||||
|
||||
self._write_data("puck_x_vel", puck_vel[0])
|
||||
self._write_data("puck_y_vel", puck_vel[1])
|
||||
self._write_data("puck_yaw_vel", puck_vel[2])
|
||||
|
||||
super().setup(obs)
|
||||
|
||||
def _setup_metrics(self):
|
||||
self.episode_steps = 0
|
||||
self.has_scored = False
|
||||
|
||||
def _step_finalize(self):
|
||||
cur_obs = self._create_observation(self.obs_helper._build_obs(self._data))
|
||||
puck_pos, _ = self.get_puck(cur_obs) # world frame [x, y, z] and [x', y', z']
|
||||
|
||||
if not self.has_scored:
|
||||
boundary = np.array([self.env_info['table']['length'], self.env_info['table']['width']]) / 2
|
||||
self.has_scored = np.any(np.abs(puck_pos[:2]) > boundary) and puck_pos[0] > 0
|
||||
|
||||
self.episode_steps += 1
|
||||
return super()._step_finalize()
|
||||
|
||||
def reward(self, state, action, next_state, absorbing):
|
||||
rew = 0
|
||||
puck_pos, puck_vel = self.get_puck(next_state)
|
||||
ee_pos, _ = self.get_ee()
|
||||
ee_vel = (ee_pos - self.last_ee_pos) / 0.02
|
||||
self.last_ee_pos = ee_pos
|
||||
|
||||
if puck_vel[0] < 0.25 and puck_pos[0] < 0:
|
||||
ee_puck_dir = (puck_pos - ee_pos)[:2]
|
||||
ee_puck_dir = ee_puck_dir / np.linalg.norm(ee_puck_dir)
|
||||
rew += 1 * max(0, np.dot(ee_puck_dir, ee_vel[:2]))
|
||||
else:
|
||||
rew += 10 * np.linalg.norm(puck_vel[:2])
|
||||
|
||||
if self.has_scored:
|
||||
rew += 2000 + 5000 * np.linalg.norm(puck_vel[:2])
|
||||
|
||||
return rew
|
||||
|
||||
def is_absorbing(self, obs):
|
||||
puck_pos, puck_vel = self.get_puck(obs)
|
||||
# Stop if the puck bounces back on the opponents wall
|
||||
if puck_pos[0] > 0 and puck_vel[0] < 0:
|
||||
return True
|
||||
|
||||
if self.has_scored:
|
||||
return True
|
||||
|
||||
if self.episode_steps == self._mdp_info.horizon:
|
||||
return True
|
||||
|
||||
return super().is_absorbing(obs)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
env = AirHockeyHit(moving_init=True)
|
||||
|
Loading…
Reference in New Issue
Block a user