added ABC to reacher base envs

This commit is contained in:
Maximilian Huettenrauch 2021-11-30 10:33:04 +01:00
parent 66aa0ab9e5
commit cfa49a04ba
3 changed files with 10 additions and 9 deletions

View File

@ -1,5 +1,5 @@
from typing import Iterable, Union
from abc import ABCMeta, abstractmethod
from abc import ABC, abstractmethod
import gym
import matplotlib.pyplot as plt
import numpy as np
@ -8,7 +8,7 @@ from gym.utils import seeding
from alr_envs.alr.classic_control.utils import intersect
class BaseReacherEnv(gym.Env):
class BaseReacherEnv(gym.Env, ABC):
"""
Base class for all reaching environments.
"""

View File

@ -1,9 +1,11 @@
from abc import ABC
from gym import spaces
import numpy as np
from alr_envs.alr.classic_control.base_reacher.base_reacher import BaseReacherEnv
class BaseReacherDirectEnv(BaseReacherEnv):
class BaseReacherDirectEnv(BaseReacherEnv, ABC):
"""
Base class for directly controlled reaching environments
"""
@ -11,7 +13,7 @@ class BaseReacherDirectEnv(BaseReacherEnv):
allow_self_collision: bool = False):
super().__init__(n_links, random_start, allow_self_collision)
self.max_vel = 10 * np.pi
self.max_vel = 2 * np.pi
action_bound = np.ones((self.n_links,)) * self.max_vel
self.action_space = spaces.Box(low=-action_bound, high=action_bound, shape=action_bound.shape)

View File

@ -1,9 +1,11 @@
from abc import ABC
from gym import spaces
import numpy as np
from alr_envs.alr.classic_control.base_reacher.base_reacher import BaseReacherEnv
class BaseReacherTorqueEnv(BaseReacherEnv):
class BaseReacherTorqueEnv(BaseReacherEnv, ABC):
"""
Base class for torque controlled reaching environments
"""
@ -20,10 +22,7 @@ class BaseReacherTorqueEnv(BaseReacherEnv):
A single step with action in torque space
"""
# action = self._add_action_noise(action)
ac = np.clip(action, -self.max_torque, self.max_torque)
self._angle_velocity = self._angle_velocity + self.dt * ac
self._angle_velocity = self._angle_velocity + self.dt * action
self._joint_angles = self._joint_angles + self.dt * self._angle_velocity
self._update_joints()