Compare commits

...

2 Commits

2 changed files with 7 additions and 3 deletions

View File

@ -195,6 +195,7 @@ class TimeoutReward(OnceReward):
self.env.timers.append((self.timeout, self.set_avaible, True))
# Not a real entity. Is used in the config of RayObserver to reference the outer boundary of the environment.
class Void():
def __init__(self, env):
self.col = (50, 50, 50)

View File

@ -109,7 +109,7 @@ class RayObservable(Observable):
for entity in entities_l:
if isinstance(entity, entity_type) or (self.env.void_barrier and isinstance(entity, entities.Void) and entity_type == entities.Enemy):
if isinstance(entity, entities.Void):
if 0 >= pos[0] or pos[0] >= self.env.width or 0 >= pos[1] or pos[0] >= self.env.height:
if 0 >= pos[0] or pos[0] >= self.env.width or 0 >= pos[1] or pos[1] >= self.env.width:
return True
else:
if entity.shape != 'circle':
@ -169,7 +169,7 @@ class RayObservable(Observable):
class StateObservable(Observable):
def __init__(self, coordsAgent=False, speedAgent=False, coordsRelativeToAgent=True, coordsRewards=True, rewardsWhitelist=None, coordsEnemys=True, enemysWhitelist=None, enemysNoBarriers=True, rewardsTimeouts=True):
def __init__(self, coordsAgent=False, speedAgent=False, coordsRelativeToAgent=True, coordsRewards=True, rewardsWhitelist=None, coordsEnemys=True, enemysWhitelist=None, enemysNoBarriers=True, rewardsTimeouts=True, include_rand=True):
super(StateObservable, self).__init__()
self._entities = None
self._timeoutEntities = []
@ -182,6 +182,7 @@ class StateObservable(Observable):
self.enemysWhitelist = enemysWhitelist
self.enemysNoBarriers = enemysNoBarriers
self.rewardsTimeouts = rewardsTimeouts
self.include_rand = include_rand
@property
def entities(self):
@ -210,7 +211,7 @@ class StateObservable(Observable):
def get_observation_space(self):
return spaces.Box(low=0-1*self.coordsRelativeToAgent, high=1,
shape=(len(self.entities)*2+len(self._timeoutEntities) + self.speedAgent,), dtype=np.float32)
shape=(len(self.entities)*2+len(self._timeoutEntities) + self.speedAgent + self.include_rand,), dtype=np.float32)
def get_observation(self):
obs = []
@ -232,6 +233,8 @@ class StateObservable(Observable):
if self.speedAgent:
obs.append(self.env.speed[0])
obs.append(self.env.speed[1])
if self.include_rand:
obs.append(self.env.random())
self.obs = obs
return np.array(obs)