More configurability
This commit is contained in:
		
							parent
							
								
									5afa8b22b2
								
							
						
					
					
						commit
						3bb3ffa3a0
					
				| @ -19,6 +19,7 @@ class Entity(object): | ||||
|         self.col = (255, 255, 255) | ||||
|         self.solid = False | ||||
|         self.movable = False  # False = Non movable, True = Movable, x>1: lighter movable | ||||
|         self.void_collidable = True | ||||
|         self.elasticity = 1 | ||||
|         self.collision_changes_speed = self.env.controll_type == 'ACC' | ||||
|         self.collision_elasticity = self.env.default_collision_elasticity | ||||
| @ -31,6 +32,8 @@ class Entity(object): | ||||
|         self.draw_path_harm = False | ||||
|         self.draw_path_harm_col = [c for c in self.draw_path_col] | ||||
|         self.draw_path_harm_col[0] += int(255/3) | ||||
|         self.min_speed = 0 | ||||
|         self.max_speed = math.inf | ||||
| 
 | ||||
|     def __post_init__(self): | ||||
|         pass | ||||
| @ -40,6 +43,11 @@ class Entity(object): | ||||
|         vx, vy = self.speed | ||||
|         ax, ay = self.acc | ||||
|         vx, vy = vx+ax*self.env.acc_fac,  vy+ay*self.env.acc_fac | ||||
|         speeds = math.sqrt(vx**2 + vy**2) | ||||
|         if speeds < self.min_speed: | ||||
|             vx, vy = vx/speeds*self.min_speed, vy/speeds*self.min_speed | ||||
|         if speeds > self.max_speed: | ||||
|             vx, vy = vx/speeds*self.max_speed, vy/speeds*self.max_speed | ||||
|         x, y = x+vx*self.env.speed_fac, y+vy*self.env.speed_fac | ||||
|         if not self.env.torus_topology: | ||||
|             if x > 1 or x < 0: | ||||
| @ -73,7 +81,7 @@ class Entity(object): | ||||
|             pygame.draw.line(self.env.path_overlay, col, | ||||
|                              (self.last_pos[0]*self.env.width, self.last_pos[1]*self.env.height), (self.pos[0]*self.env.width, self.pos[1]*self.env.height), self.draw_path_width) | ||||
|             pygame.draw.circle(self.env.path_overlay, col, | ||||
|                                (self.pos[0]*self.env.width, self.pos[1]*self.env.height), max(0, self.draw_path_width/2-1)) | ||||
|                                (self.pos[0]*self.env.width, self.pos[1]*self.env.height), max(0, self.draw_path_width/2-3)) | ||||
|         self.last_pos = self.pos[0], self.pos[1] | ||||
| 
 | ||||
|     def on_collision(self, other, depth): | ||||
| @ -92,9 +100,16 @@ class Entity(object): | ||||
|             return | ||||
|         force_dir = force_dir[0]/force_dir_len, force_dir[1]/force_dir_len | ||||
|         if not self.env.torus_topology: | ||||
|             if self.env.agent.pos[0] > 0.99 or self.env.agent.pos[0] < 0.01: | ||||
|             if self == self.env.agent: | ||||
|                 agent = self | ||||
|             elif other == self.env.agent: | ||||
|                 agent = other | ||||
|             else: | ||||
|                 agent = None | ||||
|             if agent: | ||||
|                 if agent.pos[0] > 0.99 or agent.pos[0] < 0.01: | ||||
|                     force_dir = force_dir[0], force_dir[1] * 2 | ||||
|             if self.env.agent.pos[1] > 0.99 or self.env.agent.pos[1] < 0.01: | ||||
|                 if agent.pos[1] > 0.99 or agent.pos[1] < 0.01: | ||||
|                     force_dir = force_dir[0] * 2, force_dir[1] | ||||
|         depth *= 1.0*self.movable/(self.movable + other.movable)/2 | ||||
|         depth /= other.elasticity | ||||
| @ -139,6 +154,24 @@ class Entity(object): | ||||
|     def kill(self): | ||||
|         self.env.kill_entity(self) | ||||
| 
 | ||||
|     def getQuasiRadius(self): | ||||
|         raise Exception() | ||||
| 
 | ||||
|     def getTop(self): | ||||
|         raise Exception() | ||||
| 
 | ||||
|     def getBottom(self): | ||||
|         raise Exception() | ||||
| 
 | ||||
|     def getLeft(self): | ||||
|         raise Exception() | ||||
| 
 | ||||
|     def getRight(self): | ||||
|         raise Exception() | ||||
| 
 | ||||
|     def getCenter(self): | ||||
|         raise Exception() | ||||
| 
 | ||||
| 
 | ||||
| class CircularEntity(Entity): | ||||
|     def __init__(self, env): | ||||
| @ -208,6 +241,24 @@ class CircularEntity(Entity): | ||||
|             raise Exception( | ||||
|                 '[!] Shape "circle" does not know how to collide with shape "'+str(other.shape)+'"') | ||||
| 
 | ||||
|     def getQuasiRadius(self): | ||||
|         return self.radius | ||||
| 
 | ||||
|     def getTop(self): | ||||
|         return self.pos[1]*self.env.height - self.radius | ||||
| 
 | ||||
|     def getBottom(self): | ||||
|         return self.pos[1]*self.env.height + self.radius | ||||
| 
 | ||||
|     def getLeft(self): | ||||
|         return self.pos[0]*self.env.width - self.radius | ||||
| 
 | ||||
|     def getRight(self): | ||||
|         return self.pos[0]*self.env.width + self.radius | ||||
| 
 | ||||
|     def getCenter(self): | ||||
|         return self.pos[0]*self.env.width, self.pos[1]*self.env.height | ||||
| 
 | ||||
| 
 | ||||
| class RectangularEntity(Entity): | ||||
|     def __init__(self, env): | ||||
| @ -228,6 +279,58 @@ class RectangularEntity(Entity): | ||||
|         raise Exception( | ||||
|             '[!] Collisions in this direction not implemented for shape "rectangle"') | ||||
| 
 | ||||
|     def physics_step(self): | ||||
|         x, y = self.pos | ||||
|         vx, vy = self.speed | ||||
|         ax, ay = self.acc | ||||
|         vx, vy = vx+ax*self.env.acc_fac,  vy+ay*self.env.acc_fac | ||||
|         speeds = math.sqrt(vx**2 + vy**2) | ||||
|         if speeds < self.min_speed: | ||||
|             vx, vy = vx/speeds*self.min_speed, vy/speeds*self.min_speed | ||||
|         if speeds > self.max_speed: | ||||
|             vx, vy = vx/speeds*self.max_speed, vy/speeds*self.max_speed | ||||
|         x, y = x+vx*self.env.speed_fac, y+vy*self.env.speed_fac | ||||
|         if not self.env.torus_topology and self.void_collidable: | ||||
|             if x+(self.width/self.env.width) > 1 or x < 0: | ||||
|                 if x < 0: | ||||
|                     x, y, vx, vy = self.calc_void_collision( | ||||
|                         x < 0, x, y, vx, vy) | ||||
|                 else: | ||||
|                     x, y, vx, vy = self.calc_void_collision( | ||||
|                         x < 0, x+(self.width/self.env.width), y, vx, vy) | ||||
|                     x -= (self.width/self.env.width) | ||||
|             if y+(self.height/self.env.height) > 1 or y < 0: | ||||
|                 if y < 0: | ||||
|                     x, y, vx, vy = self.calc_void_collision( | ||||
|                         2 + (x < 0), x, y, vx, vy) | ||||
|                 else: | ||||
|                     x, y, vx, vy = self.calc_void_collision( | ||||
|                         2 + (x < 0), x, y+(self.height/self.env.height), vx, vy) | ||||
|                     y -= (self.height/self.env.height) | ||||
|         else: | ||||
|             x = x % 1 | ||||
|             y = y % 1 | ||||
|         self.speed = vx/(1+self.drag), vy/(1+self.drag) | ||||
|         self.pos = x, y | ||||
| 
 | ||||
|     def getQuasiRadius(self): | ||||
|         return self.width + self.height | ||||
| 
 | ||||
|     def getTop(self): | ||||
|         return self.pos[1]*self.env.height | ||||
| 
 | ||||
|     def getBottom(self): | ||||
|         return self.pos[1]*self.env.height + self.height | ||||
| 
 | ||||
|     def getLeft(self): | ||||
|         return self.pos[0]*self.env.width | ||||
| 
 | ||||
|     def getRight(self): | ||||
|         return self.pos[0]*self.env.width*self.env.height + self.width | ||||
| 
 | ||||
|     def getCenter(self): | ||||
|         return self.pos[0]*self.env.width+self.width/2, self.pos[1]*self.env.height+self.height/2 | ||||
| 
 | ||||
| 
 | ||||
| class Agent(CircularEntity): | ||||
|     def __init__(self, env): | ||||
| @ -252,6 +355,30 @@ class Agent(CircularEntity): | ||||
|             raise Exception('Unsupported controll_type') | ||||
| 
 | ||||
| 
 | ||||
| # Does not work! Don't use! | ||||
| class PongAgent(RectangularEntity): | ||||
|     def __init__(self, env): | ||||
|         super(PongAgent, self).__init__(env) | ||||
|         self.pos = (0.5, 0.5) | ||||
|         self.col = (0, 0, 255) | ||||
|         self.drag = self.env.agent_drag | ||||
|         self.controll_type = self.env.controll_type | ||||
|         self.solid = True | ||||
|         self.movable = True | ||||
| 
 | ||||
|     def controll_step(self): | ||||
|         self._read_input() | ||||
|         self.env.check_collisions_for(self) | ||||
| 
 | ||||
|     def _read_input(self): | ||||
|         if self.controll_type == 'SPEED': | ||||
|             self.speed = 0, self.env.inp[1] - 0.5 | ||||
|         elif self.controll_type == 'ACC': | ||||
|             self.acc = 0, self.env.inp[1] - 0.5 | ||||
|         else: | ||||
|             raise Exception('Unsupported controll_type') | ||||
| 
 | ||||
| 
 | ||||
| class Enemy(Entity): | ||||
|     def __init__(self, env): | ||||
|         super(Enemy, self).__init__(env) | ||||
| @ -349,6 +476,33 @@ class Collectable(CircularEntity): | ||||
|             self.env.check_collisions_for(self) | ||||
| 
 | ||||
| 
 | ||||
| class RectCollectable(RectangularEntity): | ||||
|     def __init__(self, env): | ||||
|         super(RectCollectable, self).__init__(env) | ||||
|         self.avaible = True | ||||
|         self.enforce_not_on_barrier = False | ||||
|         self.reward = 10 | ||||
|         self.collectors = [] | ||||
| 
 | ||||
|     def on_collision(self, other, depth): | ||||
|         super().on_collision(other, depth) | ||||
|         if isinstance(other, Barrier): | ||||
|             self.on_barrier_collision() | ||||
|         else: | ||||
|             for Col in self.collectors: | ||||
|                 if isinstance(other, Col): | ||||
|                     other.on_collect(self) | ||||
|                     self.on_collected() | ||||
| 
 | ||||
|     def on_collected(self): | ||||
|         self.env.new_reward += self.reward | ||||
| 
 | ||||
|     def on_barrier_collision(self): | ||||
|         if self.enforce_not_on_barrier: | ||||
|             self.pos = (self.env.random(), self.env.random()) | ||||
|             self.env.check_collisions_for(self) | ||||
| 
 | ||||
| 
 | ||||
| class Reward(Collectable): | ||||
|     def __init__(self, env): | ||||
|         super(Reward, self).__init__(env) | ||||
| @ -479,6 +633,14 @@ class Goal(Collectable): | ||||
|         self.collectors = [Ball] | ||||
| 
 | ||||
| 
 | ||||
| class RectGoal(RectCollectable): | ||||
|     def __init__(self, env): | ||||
|         super(RectGoal, self).__init__(env) | ||||
|         self.col = (0, 200, 0) | ||||
|         self.reward = 500 | ||||
|         self.collectors = [Ball] | ||||
| 
 | ||||
| 
 | ||||
| class TeleportingGoal(Goal): | ||||
|     def __init__(self, env): | ||||
|         super(TeleportingGoal, self).__init__(env) | ||||
|  | ||||
| @ -15,7 +15,7 @@ from columbus.utils import soft_int, parseObs | ||||
| class ColumbusEnv(gym.Env): | ||||
|     metadata = {'render.modes': ['human']} | ||||
| 
 | ||||
|     def __init__(self, observable=observables.Observable(), fps=60, env_seed=3.1, master_seed=None, start_pos=(0.5, 0.5), start_score=0, speed_fac=0.01, acc_fac=0.04, die_on_zero=False, return_on_score=-1, reward_mult=1, agent_drag=0, controll_type='SPEED', aux_reward_max=1, aux_penalty_max=0, aux_reward_discretize=0, void_is_type_barrier=True, void_damage=1, torus_topology=False, default_collision_elasticity=1, terminate_on_reward=False, agent_draw_path=False, clear_path_on_reset=True, max_steps=-1, value_color_mapper='tanh', width=720, height=720, agent_attrs={}): | ||||
|     def __init__(self, observable=observables.Observable(), fps=60, env_seed=3.1, master_seed=None, start_pos=(0.5, 0.5), start_score=0, speed_fac=0.01, acc_fac=0.04, die_on_zero=False, return_on_score=-1, reward_mult=1, agent_drag=0, controll_type='SPEED', aux_reward_max=1, aux_penalty_max=0, aux_reward_discretize=0, void_is_type_barrier=True, void_damage=1, torus_topology=False, default_collision_elasticity=1, terminate_on_reward=False, agent_draw_path=False, clear_path_on_reset=True, max_steps=-1, value_color_mapper='tanh', width=720, height=720, agent_attrs={}, agent_cls=entities.Agent, exception_for_unsupported_collision=True): | ||||
|         super(ColumbusEnv, self).__init__() | ||||
|         self.action_space = spaces.Box( | ||||
|             low=-1, high=1, shape=(2,), dtype=np.float32) | ||||
| @ -63,8 +63,13 @@ class ColumbusEnv(gym.Env): | ||||
|         self.clear_path_on_reset = clear_path_on_reset | ||||
|         self.path_decay = 0.1 | ||||
| 
 | ||||
|         if isinstance(agent_cls, str): | ||||
|             agent_cls = getattr(entities, agent_cls) | ||||
|         self.Agent_cls = agent_cls | ||||
|         self.agent_attrs = agent_attrs | ||||
| 
 | ||||
|         self.exception_for_unsupported_collision = exception_for_unsupported_collision | ||||
| 
 | ||||
|         if value_color_mapper == 'atan': | ||||
|             def value_color_mapper(x): return th.atan(x*2)/0.786/2 | ||||
|         elif value_color_mapper == 'tanh': | ||||
| @ -217,8 +222,8 @@ class ColumbusEnv(gym.Env): | ||||
|         reward, self.new_reward, self.new_abs_reward = self.new_reward / \ | ||||
|             self.fps + self.new_abs_reward, 0, 0 | ||||
|         if not self.torus_topology: | ||||
|             if self.agent.pos[0] < 0.001 or self.agent.pos[0] > 0.999 \ | ||||
|                     or self.agent.pos[1] < 0.001 or self.agent.pos[1] > 0.999: | ||||
|             if self.agent.getTop() < 1 or self.agent.getBottom() > self.height-1 \ | ||||
|                     or self.agent.getLeft() < 1 or self.agent.getRight() > self.width-1: | ||||
|                 reward -= self.void_damage/self.fps | ||||
|         self.score += reward  # aux_reward does not count towards the score | ||||
|         if self.aux_reward_max or self.aux_penalty_max: | ||||
| @ -252,8 +257,10 @@ class ColumbusEnv(gym.Env): | ||||
|         elif shapes == ['circle', 'rect']: | ||||
|             return sum([abs(d) for d in e1._get_crash_force_dir(e2)]) | ||||
|         else: | ||||
|             if self.exception_for_unsupported_collision: | ||||
|                 raise Exception( | ||||
|                     'Checking for collision between unsupported shapes: '+str(shapes)) | ||||
|             return 0.0 | ||||
| 
 | ||||
|     def kill_entity(self, target): | ||||
|         newEntities = [] | ||||
| @ -270,7 +277,7 @@ class ColumbusEnv(gym.Env): | ||||
|         # Expand this function | ||||
| 
 | ||||
|     def _spawnAgent(self): | ||||
|         self.agent = entities.Agent(self) | ||||
|         self.agent = self.Agent_cls(self) | ||||
|         self.agent.draw_path = self.agent_draw_path | ||||
|         for k, v in self.agent_attrs.items(): | ||||
|             setattr(self.agent, k, v) | ||||
| @ -497,11 +504,11 @@ class ColumbusConfigDefined(ColumbusEnv): | ||||
|     def is_unit(self, s): | ||||
|         if type(s) in [int, float]: | ||||
|             return True | ||||
|         if s.replace('.', '', 1).isdigit(): | ||||
|         if s.replace('.', '', 1).replace('-', '0', 1).isdigit(): | ||||
|             return True | ||||
|         num, unit = s[:-2], s[-2:] | ||||
|         if unit in ['px', 'em', 'rx', 'ry', 'ct', 'au']: | ||||
|             if num.replace('.', '', 1).isdigit(): | ||||
|             if num.replace('.', '', 1).replace('-', '0', 1).isdigit(): | ||||
|                 return True | ||||
|         return False | ||||
| 
 | ||||
| @ -556,11 +563,32 @@ class ColumbusConfigDefined(ColumbusEnv): | ||||
|                     else: | ||||
|                         v = v_raw | ||||
|                     if k.endswith('_rand'): | ||||
|                         if isinstance(v, int): | ||||
|                             n = k.replace('_rand', '') | ||||
|                             cur = getattr( | ||||
|                                 entity, n) | ||||
|                             inc = int((v+0.99)*self.random()) | ||||
|                             setattr(entity, n, cur + inc) | ||||
|                         elif isinstance(v, float): | ||||
|                             n = k.replace('_randf', '') | ||||
|                             cur = getattr( | ||||
|                                 entity, n) | ||||
|                             inc = v*self.random() | ||||
|                             setattr(entity, n, cur + inc) | ||||
|                         elif isinstance(v, list): | ||||
|                             for vi, ve in enumerate(v): | ||||
|                                 if isinstance(v, int): | ||||
|                                     n = k.replace('_rand', '') | ||||
|                                     cur = getattr( | ||||
|                                         entity, n) | ||||
|                                     cur[vi] = int((v+0.99)*self.random()) | ||||
|                                     setattr(entity, n, cur) | ||||
|                                 elif isinstance(v, float): | ||||
|                                     n = k.replace('_randf', '') | ||||
|                                     cur = getattr( | ||||
|                                         entity, n) | ||||
|                                     cur[vi] = v*self.random() | ||||
|                                     setattr(entity, n, cur) | ||||
|                     elif k.endswith('_randf'): | ||||
|                         n = k.replace('_randf', '') | ||||
|                         cur = getattr( | ||||
|  | ||||
| @ -151,7 +151,7 @@ class RayObservable(Observable): | ||||
|                     'Can only raycast circular and rectangular entities!') | ||||
|             sq_dist = ((self.env.agent.pos[0]-x)*self.env.width) ** 2 \ | ||||
|                 + ((self.env.agent.pos[1]-y)*self.env.height) ** 2 | ||||
|             if sq_dist <= (radius + self.env.agent.radius + self.ray_len)**2: | ||||
|             if sq_dist <= (radius + self.env.agent.getQuasiRadius() + self.ray_len)**2: | ||||
|                 entities_l.append(entity)  # cannot use yield here! | ||||
|         return entities_l | ||||
| 
 | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user