New Observable: Compass
This commit is contained in:
parent
f18310ed5b
commit
8706462358
@ -10,6 +10,35 @@ from columbus import entities, observables
|
|||||||
import torch as th
|
import torch as th
|
||||||
|
|
||||||
|
|
||||||
|
def parseObs(obsConf):
|
||||||
|
if type(obsConf) == list:
|
||||||
|
obs = []
|
||||||
|
for i, c in enumerate(obsConf):
|
||||||
|
obs.append(parseObs(c))
|
||||||
|
if len(obs) == 1:
|
||||||
|
return obs[0]
|
||||||
|
else:
|
||||||
|
return observables.CompositionalObservable(obs)
|
||||||
|
|
||||||
|
if obsConf['type'] == 'State':
|
||||||
|
conf = {k: v for k, v in obsConf.items() if k not in ['type']}
|
||||||
|
return observables.StateObservable(conf)
|
||||||
|
elif obsConf['type'] == 'Compass':
|
||||||
|
conf = {k: v for k, v in obsConf.items() if k not in ['type']}
|
||||||
|
return observables.CompassObservable(**conf)
|
||||||
|
elif obsConf['type'] == 'RayCast':
|
||||||
|
chans = []
|
||||||
|
for chan in obsConf.get('chans', []):
|
||||||
|
chans.append(getattr(entities, chan))
|
||||||
|
conf = {k: v for k, v in obsConf.items() if k not in ['type', 'chans']}
|
||||||
|
return observables.RayObservable(chans=chans, **conf)
|
||||||
|
elif obsConf['type'] == 'CNN':
|
||||||
|
conf = {k: v for k, v in obsConf.items() if k not in ['type']}
|
||||||
|
return observables.CnnObservable(**conf)
|
||||||
|
else:
|
||||||
|
raise Exception('Unknown Observable selected')
|
||||||
|
|
||||||
|
|
||||||
class ColumbusEnv(gym.Env):
|
class ColumbusEnv(gym.Env):
|
||||||
metadata = {'render.modes': ['human']}
|
metadata = {'render.modes': ['human']}
|
||||||
|
|
||||||
@ -527,6 +556,30 @@ class ColumbusStateWithBarriers(ColumbusEnv):
|
|||||||
self.entities.append(reward)
|
self.entities.append(reward)
|
||||||
|
|
||||||
|
|
||||||
|
class ColumbusCompassWithBarriers(ColumbusEnv):
|
||||||
|
def __init__(self, observable=observables.CompassObservable(coordsRewards=True), fps=30, env_seed=3.141, num_enemys=0, num_barriers=3, aux_reward_max=10, **kw):
|
||||||
|
super().__init__(
|
||||||
|
observable=observable, fps=fps, env_seed=env_seed, aux_reward_max=aux_reward_max, **kw)
|
||||||
|
self.start_pos = (0.5, 0.5)
|
||||||
|
self.num_barriers = num_barriers
|
||||||
|
self.num_enemys = num_enemys
|
||||||
|
|
||||||
|
def setup(self):
|
||||||
|
self.agent.pos = self.start_pos
|
||||||
|
for i in range(self.num_barriers):
|
||||||
|
enemy = entities.CircleBarrier(self)
|
||||||
|
enemy.radius = self.random()*25+75
|
||||||
|
self.entities.append(enemy)
|
||||||
|
for i in range(self.num_enemys):
|
||||||
|
enemy = entities.FlyingChaser(self)
|
||||||
|
enemy.chase_acc = 0.55 # *0.6+0.5
|
||||||
|
self.entities.append(enemy)
|
||||||
|
for i in range(1):
|
||||||
|
reward = entities.TeleportingReward(self)
|
||||||
|
reward.radius = 30
|
||||||
|
self.entities.append(reward)
|
||||||
|
|
||||||
|
|
||||||
class ColumbusTrivialRay(ColumbusStateWithBarriers):
|
class ColumbusTrivialRay(ColumbusStateWithBarriers):
|
||||||
def __init__(self, observable=observables.RayObservable(num_rays=8, ray_len=512), hide_map=False, fps=30, **kw):
|
def __init__(self, observable=observables.RayObservable(num_rays=8, ray_len=512), hide_map=False, fps=30, **kw):
|
||||||
super(ColumbusTrivialRay, self).__init__(
|
super(ColumbusTrivialRay, self).__init__(
|
||||||
@ -558,32 +611,6 @@ class ColumbusFootball(ColumbusEnv):
|
|||||||
self.entities.append(entities.FlyingFootballPlayer(self, ball))
|
self.entities.append(entities.FlyingFootballPlayer(self, ball))
|
||||||
|
|
||||||
|
|
||||||
def parseObs(obsConf):
|
|
||||||
if type(obsConf) == list:
|
|
||||||
obs = []
|
|
||||||
for i, c in enumerate(obsConf):
|
|
||||||
obs.append(parseObs(c))
|
|
||||||
if len(obs) == 1:
|
|
||||||
return obs[0]
|
|
||||||
else:
|
|
||||||
return observables.CompositionalObservable(obs)
|
|
||||||
|
|
||||||
if obsConf['type'] == 'State':
|
|
||||||
conf = {k: v for k, v in obsConf.items() if k not in ['type']}
|
|
||||||
return observables.StateObservable(conf)
|
|
||||||
elif obsConf['type'] == 'RayCast':
|
|
||||||
chans = []
|
|
||||||
for chan in obsConf.get('chans', []):
|
|
||||||
chans.append(getattr(entities, chan))
|
|
||||||
conf = {k: v for k, v in obsConf.items() if k not in ['type', 'chans']}
|
|
||||||
return observables.RayObservable(chans=chans, **conf)
|
|
||||||
elif obsConf['type'] == 'CNN':
|
|
||||||
conf = {k: v for k, v in obsConf.items() if k not in ['type']}
|
|
||||||
return observables.CnnObservable(**conf)
|
|
||||||
else:
|
|
||||||
raise Exception('Unknown Observable selected')
|
|
||||||
|
|
||||||
|
|
||||||
class ColumbusConfigDefined(ColumbusEnv):
|
class ColumbusConfigDefined(ColumbusEnv):
|
||||||
def __init__(self, observable={}, env_seed=None, entities=[], fps=30, **kw):
|
def __init__(self, observable={}, env_seed=None, entities=[], fps=30, **kw):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@ -673,6 +700,12 @@ register(
|
|||||||
max_episode_steps=30*60*2,
|
max_episode_steps=30*60*2,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
register(
|
||||||
|
id='ColumbusCompassWithBarriers-v0',
|
||||||
|
entry_point=ColumbusCompassWithBarriers,
|
||||||
|
max_episode_steps=30*60*2,
|
||||||
|
)
|
||||||
|
|
||||||
register(
|
register(
|
||||||
id='ColumbusTrivialRay-v0',
|
id='ColumbusTrivialRay-v0',
|
||||||
entry_point=ColumbusTrivialRay,
|
entry_point=ColumbusTrivialRay,
|
||||||
|
@ -265,6 +265,71 @@ class StateObservable(Observable):
|
|||||||
(x*self.env.width+ofs[1], 0), 1, width=0)
|
(x*self.env.width+ofs[1], 0), 1, width=0)
|
||||||
|
|
||||||
|
|
||||||
|
class CompassObservable(Observable):
|
||||||
|
def __init__(self, coordsRewards=True, rewardsWhitelist=None, coordsEnemys=False, enemysWhitelist=None, enemysNoBarriers=True):
|
||||||
|
super().__init__()
|
||||||
|
self._entities = None
|
||||||
|
self._timeoutEntities = []
|
||||||
|
self.coordRewards = coordsRewards
|
||||||
|
self.rewardsWhitelist = rewardsWhitelist
|
||||||
|
self.coordsEnemys = coordsEnemys
|
||||||
|
self.enemysWhitelist = enemysWhitelist
|
||||||
|
self.enemysNoBarriers = enemysNoBarriers
|
||||||
|
|
||||||
|
@property
|
||||||
|
def entities(self):
|
||||||
|
if not self._entities == None:
|
||||||
|
return self._entities
|
||||||
|
rewardsWhitelist = self.rewardsWhitelist or self.env.entities
|
||||||
|
enemysWhitelist = self.enemysWhitelist or self.env.entities
|
||||||
|
self._entities = []
|
||||||
|
if self.coordRewards:
|
||||||
|
for entity in rewardsWhitelist:
|
||||||
|
if isinstance(entity, entities.Reward):
|
||||||
|
self._entities.append(entity)
|
||||||
|
if self.coordsEnemys:
|
||||||
|
for entity in enemysWhitelist:
|
||||||
|
if isinstance(entity, entities.Enemy):
|
||||||
|
if not self.enemysNoBarriers or not isinstance(entity, entities.Barrier):
|
||||||
|
self._entities.append(entity)
|
||||||
|
return self._entities
|
||||||
|
|
||||||
|
def get_observation_space(self):
|
||||||
|
self.env.reset()
|
||||||
|
num = len(self.entities)*2
|
||||||
|
return spaces.Box(low=-1, high=1,
|
||||||
|
shape=(num,), dtype=np.float32)
|
||||||
|
|
||||||
|
def get_observation(self):
|
||||||
|
obs = []
|
||||||
|
for entity in self.entities:
|
||||||
|
dx, dy = entity.pos[0] - \
|
||||||
|
self.env.agent.pos[0], entity.pos[1] - self.env.agent.pos[1]
|
||||||
|
l = math.sqrt(dx**2 + dy**2)*2
|
||||||
|
x, y = math.tanh(dx/l), math.tanh(dy/l)
|
||||||
|
obs.append(x)
|
||||||
|
obs.append(y)
|
||||||
|
|
||||||
|
self.obs = obs
|
||||||
|
return np.array(obs)
|
||||||
|
|
||||||
|
def draw(self):
|
||||||
|
ofs = (0 + self.env.height/2,
|
||||||
|
0 + self.env.width/2)
|
||||||
|
if True:
|
||||||
|
pygame.draw.circle(self.env.screen, self.env.agent.col,
|
||||||
|
(0, self.env.height/2), 3, width=0)
|
||||||
|
pygame.draw.circle(self.env.screen, self.env.agent.col,
|
||||||
|
(self.env.width/2, 0), 3, width=0)
|
||||||
|
for i in range(int(len(self.obs)/2)):
|
||||||
|
x, y = self.obs[i*2], self.obs[i*2+1]
|
||||||
|
col = self.entities[i].col
|
||||||
|
pygame.draw.circle(self.env.screen, col,
|
||||||
|
(0, y*self.env.height+ofs[0]), 1, width=0)
|
||||||
|
pygame.draw.circle(self.env.screen, col,
|
||||||
|
(x*self.env.width+ofs[1], 0), 1, width=0)
|
||||||
|
|
||||||
|
|
||||||
class CompositionalObservable(Observable):
|
class CompositionalObservable(Observable):
|
||||||
def __init__(self, observables):
|
def __init__(self, observables):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
Loading…
Reference in New Issue
Block a user