import copy
import os

import numpy as np
from gym.envs.mujoco.hopper_v3 import HopperEnv

# Step count from which HopperJumpEnv pays out its reward terms when running in sparse mode
# (compared against the environment's per-episode step counter in HopperJumpEnv.step).
MAX_EPISODE_STEPS_HOPPERJUMP = 250


class HopperJumpEnv(HopperEnv):
    """
    Hopper variant rewarded for torso height, for bringing the foot site close to a goal position and
    for the goal distance measured at the first floor contact after the foot has left the floor.

    Initialization changes to normal Hopper:
    - terminate_when_unhealthy: True -> False
    - healthy_reward: 1.0 -> 2.0
    - healthy_z_range: (0.7, float('inf')) -> (0.5, float('inf'))
    - healthy_angle_range: (-0.2, 0.2) -> (-float('inf'), float('inf'))
    - exclude_current_positions_from_observation: True -> False
    """

    # Distance substituted into the contact reward while no touchdown has been recorded yet.
    _NO_CONTACT_DIST = 5

    def __init__(
            self,
            xml_file='hopper_jump.xml',
            forward_reward_weight=1.0,
            ctrl_cost_weight=1e-3,
            healthy_reward=2.0,
            contact_weight=2.0,
            height_weight=10.0,
            dist_weight=3.0,
            terminate_when_unhealthy=False,
            healthy_state_range=(-100.0, 100.0),
            healthy_z_range=(0.5, float('inf')),
            healthy_angle_range=(-float('inf'), float('inf')),
            reset_noise_scale=5e-3,
            exclude_current_positions_from_observation=False,
            sparse=False,
    ):
        """
        Store the jump-specific reward settings and build the underlying Hopper environment.

        Args:
            xml_file: model file name, resolved relative to the "assets" directory next to this module.
            contact_weight: scale of the (negative) touchdown-distance reward term.
            height_weight: scale of the height reward term.
            dist_weight: scale of the (negative) foot-to-goal distance reward term.
            sparse: if True, the reward terms are only paid out once the step counter has reached
                MAX_EPISODE_STEPS_HOPPERJUMP; the control cost is still charged on every step.

        All remaining arguments are forwarded unchanged to HopperEnv.
        """
        # Everything read by step()/_get_obs() is set up *before* the base constructor runs
        # (presumably because the base-class constructor already calls step() -- confirm against gym).
        self.sparse = sparse
        self._height_weight = height_weight
        self._dist_weight = dist_weight
        self._contact_weight = contact_weight

        self.max_height = 0
        self.goal = np.zeros(3, )

        self._steps = 0
        self.contact_with_floor = False
        self.init_floor_contact = False
        self.has_left_floor = False
        self.contact_dist = None

        xml_file = os.path.join(os.path.dirname(__file__), "assets", xml_file)
        super().__init__(xml_file, forward_reward_weight, ctrl_cost_weight, healthy_reward, terminate_when_unhealthy,
                         healthy_state_range, healthy_z_range, healthy_angle_range, reset_noise_scale,
                         exclude_current_positions_from_observation)

        # increase initial height
        self.init_qpos[1] = 1.5

    @property
    def exclude_current_positions_from_observation(self):
        """Expose the private flag of the same name."""
        return self._exclude_current_positions_from_observation

    def step(self, action):
        """
        Advance the simulation by one environment step.

        Args:
            action: control input handed to the simulator and to the control-cost term.

        Returns:
            Tuple ``(observation, reward, done, info)``. ``done`` is always False here.
        """
        self._steps += 1

        self.do_simulation(action, self.frame_skip)

        height_after = self.get_body_com("torso")[2]
        site_pos_after = self.data.get_site_xpos('foot_site')
        self.max_height = max(height_after, self.max_height)

        self._update_floor_contact_state()

        costs = self.control_cost(action)
        done = False

        goal_dist = np.linalg.norm(site_pos_after - self.goal)
        # Remember the goal distance of the first touchdown only.
        if self.contact_dist is None and self.contact_with_floor:
            self.contact_dist = goal_dist

        rewards = 0
        # Dense mode pays on every step; sparse mode only once the step budget has been reached.
        if not self.sparse or self._steps >= MAX_EPISODE_STEPS_HOPPERJUMP:
            healthy_reward = self.healthy_reward
            distance_reward = -goal_dist * self._dist_weight
            # Sparse mode rewards the best height so far, dense mode the current height.
            rewarded_height = self.max_height if self.sparse else height_after
            height_reward = rewarded_height * self._height_weight
            # Explicit None check: a touchdown distance of exactly 0.0 is a valid (and the best
            # possible) value and must not be replaced by the no-contact fallback.
            if self.contact_dist is None:
                contact_dist = self._NO_CONTACT_DIST
            else:
                contact_dist = self.contact_dist
            contact_reward = -contact_dist * self._contact_weight
            rewards = self._forward_reward_weight * (distance_reward + height_reward + contact_reward + healthy_reward)

        observation = self._get_obs()
        reward = rewards - costs
        info = dict(
            height=height_after,
            x_pos=site_pos_after,
            max_height=self.max_height,
            goal=self.goal[:1],
            goal_dist=goal_dist,
            # NOTE(review): logs the raw maximum height, not the weighted height reward term -- confirm intent.
            height_rew=self.max_height,
            # NOTE(review): logged as twice the healthy reward that enters the reward sum -- confirm intent.
            healthy_reward=self.healthy_reward * 2,
            healthy=self.is_healthy,
            contact_dist=0 if self.contact_dist is None else self.contact_dist
        )
        return observation, reward, done, info

    def _update_floor_contact_state(self):
        """
        Advance the three-stage contact tracker: initial contact -> left the floor -> touched down again.

        Once ``contact_with_floor`` is set the tracker is frozen and the contact list is no longer scanned.
        """
        has_floor_contact = False if self.contact_with_floor else self._is_floor_foot_contact()

        if not self.init_floor_contact:
            self.init_floor_contact = has_floor_contact
        if self.init_floor_contact and not self.has_left_floor:
            self.has_left_floor = not has_floor_contact
        if not self.contact_with_floor and self.has_left_floor:
            self.contact_with_floor = has_floor_contact

    def _get_obs(self):
        """Return the base observation extended by the foot-site-to-goal offset and the first goal component."""
        goal_offset = self.data.get_site_xpos('foot_site') - self.goal
        return np.concatenate((super(HopperJumpEnv, self)._get_obs(), goal_offset.copy(), self.goal[:1]))

    def reset_model(self):
        """
        Reset the simulation, draw a new goal and a perturbed start pose, and clear the episode bookkeeping.

        Returns:
            The observation of the freshly reset state.
        """
        super(HopperJumpEnv, self).reset_model()

        # The goal varies only in its first component, sampled uniformly from [0.3, 1.35); the rest stays 0.
        self.goal = np.concatenate([self.np_random.uniform(0.3, 1.35, 1), np.zeros(2, )])
        self.sim.model.body_pos[self.sim.model.body_name2id('goal_site_body')] = self.goal
        self.max_height = 0
        self._steps = 0

        # Only qpos entries 3..5 are perturbed, each within its own one-sided interval
        # (presumably the thigh/leg/foot joints -- confirm against the model XML).
        noise_low = -np.zeros(self.model.nq)
        noise_low[3] = -0.5
        noise_low[4] = -0.2
        noise_low[5] = 0

        noise_high = np.zeros(self.model.nq)
        noise_high[3] = 0
        noise_high[4] = 0
        noise_high[5] = 0.785

        qpos = (
                self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nq) +
                self.init_qpos
        )
        # Velocities start from the unperturbed initial values.
        qvel = self.init_qvel

        self.set_state(qpos, qvel)

        observation = self._get_obs()
        self.has_left_floor = False
        self.contact_with_floor = False
        self.init_floor_contact = False
        self.contact_dist = None

        return observation

    def _is_floor_foot_contact(self):
        """Return True if any active contact pairs the 'floor' geom with the 'foot_geom' geom (in either order)."""
        floor_geom_id = self.model.geom_name2id('floor')
        foot_geom_id = self.model.geom_name2id('foot_geom')
        wanted_pairs = ((floor_geom_id, foot_geom_id), (foot_geom_id, floor_geom_id))
        # Only the first ``ncon`` entries of the contact buffer are inspected, hence the index loop.
        for i in range(self.data.ncon):
            contact = self.data.contact[i]
            if (contact.geom1, contact.geom2) in wanted_pairs:
                return True
        return False


class HopperJumpStepEnv(HopperJumpEnv):
    """
    HopperJumpEnv variant with a dense per-step reward made of the current torso height, the
    foot-site-to-goal distance, the healthy reward and the control cost.

    The touchdown distance is still tracked, but only reported in ``info``; it does not enter the reward.
    """

    def __init__(self,
                 xml_file='hopper_jump.xml',
                 forward_reward_weight=1.0,
                 ctrl_cost_weight=1e-3,
                 healthy_reward=1.0,
                 height_weight=3,
                 dist_weight=3,
                 terminate_when_unhealthy=False,
                 healthy_state_range=(-100.0, 100.0),
                 healthy_z_range=(0.5, float('inf')),
                 healthy_angle_range=(-float('inf'), float('inf')),
                 reset_noise_scale=5e-3,
                 exclude_current_positions_from_observation=False
                 ):
        """
        Build the environment with this variant's default weights.

        Args:
            height_weight: scale of the current-height reward term.
            dist_weight: scale of the (negative) foot-to-goal distance reward term.

        All remaining arguments are forwarded unchanged to HopperJumpEnv.
        """
        # Arguments are passed by keyword on purpose: HopperJumpEnv.__init__ takes contact_weight,
        # height_weight and dist_weight right after healthy_reward, so a positional call in the plain
        # Hopper order would shift every following argument into the wrong parameter.
        # The parent stores height_weight / dist_weight as self._height_weight / self._dist_weight.
        super().__init__(
            xml_file=xml_file,
            forward_reward_weight=forward_reward_weight,
            ctrl_cost_weight=ctrl_cost_weight,
            healthy_reward=healthy_reward,
            height_weight=height_weight,
            dist_weight=dist_weight,
            terminate_when_unhealthy=terminate_when_unhealthy,
            healthy_state_range=healthy_state_range,
            healthy_z_range=healthy_z_range,
            healthy_angle_range=healthy_angle_range,
            reset_noise_scale=reset_noise_scale,
            exclude_current_positions_from_observation=exclude_current_positions_from_observation,
        )

    def step(self, action):
        """
        Advance the simulation by one environment step.

        Args:
            action: control input handed to the simulator and to the control-cost term.

        Returns:
            Tuple ``(observation, reward, done, info)``. ``done`` is always False here.
        """
        self._steps += 1

        self.do_simulation(action, self.frame_skip)

        height_after = self.get_body_com("torso")[2]
        site_pos_after = self.data.get_site_xpos('foot_site')
        self.max_height = max(height_after, self.max_height)

        ctrl_cost = self.control_cost(action)
        healthy_reward = self.healthy_reward
        height_reward = self._height_weight * height_after
        # self.goal is already a 3-vector, so it is subtracted directly from the site position.
        goal_dist = np.linalg.norm(site_pos_after - self.goal)
        goal_dist_reward = -self._dist_weight * goal_dist
        dist_reward = self._forward_reward_weight * (goal_dist_reward + height_reward)

        rewards = dist_reward + healthy_reward
        costs = ctrl_cost
        done = False

        # This is only for logging the distance to goal when first having the contact.
        # Tracker stages: initial contact -> left the floor -> touched down again (then frozen).
        has_floor_contact = False if self.contact_with_floor else self._is_floor_foot_contact()

        if not self.init_floor_contact:
            self.init_floor_contact = has_floor_contact
        if self.init_floor_contact and not self.has_left_floor:
            self.has_left_floor = not has_floor_contact
        if not self.contact_with_floor and self.has_left_floor:
            self.contact_with_floor = has_floor_contact

        if self.contact_dist is None and self.contact_with_floor:
            self.contact_dist = goal_dist

        observation = self._get_obs()
        reward = rewards - costs
        info = {
            'height': height_after,
            'x_pos': site_pos_after,
            'max_height': copy.copy(self.max_height),
            'goal': copy.copy(self.goal),
            'goal_dist': goal_dist,
            'height_rew': height_reward,
            'healthy_reward': healthy_reward,
            'healthy': copy.copy(self.is_healthy),
            # 0 is reported until a touchdown distance has been recorded.
            'contact_dist': 0 if self.contact_dist is None else copy.copy(self.contact_dist)
        }
        return observation, reward, done, info