From b0aeb94cd71b281b97ac423f9019e8f774d3bd5c Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Tue, 21 Jun 2022 21:38:18 +0200 Subject: [PATCH] Fixed Bug RayObservable did not detect lower Void --- columbus/observables.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/columbus/observables.py b/columbus/observables.py index cbd4547..bd9407b 100644 --- a/columbus/observables.py +++ b/columbus/observables.py @@ -109,7 +109,7 @@ class RayObservable(Observable): for entity in entities_l: if isinstance(entity, entity_type) or (self.env.void_barrier and isinstance(entity, entities.Void) and entity_type == entities.Enemy): if isinstance(entity, entities.Void): - if 0 >= pos[0] or pos[0] >= self.env.width or 0 >= pos[1] or pos[0] >= self.env.height: + if 0 >= pos[0] or pos[0] >= self.env.width or 0 >= pos[1] or pos[1] >= self.env.width: return True else: if entity.shape != 'circle': @@ -169,7 +169,7 @@ class RayObservable(Observable): class StateObservable(Observable): - def __init__(self, coordsAgent=False, speedAgent=False, coordsRelativeToAgent=True, coordsRewards=True, rewardsWhitelist=None, coordsEnemys=True, enemysWhitelist=None, enemysNoBarriers=True, rewardsTimeouts=True): + def __init__(self, coordsAgent=False, speedAgent=False, coordsRelativeToAgent=True, coordsRewards=True, rewardsWhitelist=None, coordsEnemys=True, enemysWhitelist=None, enemysNoBarriers=True, rewardsTimeouts=True, include_rand=True): super(StateObservable, self).__init__() self._entities = None self._timeoutEntities = [] @@ -182,6 +182,7 @@ class StateObservable(Observable): self.enemysWhitelist = enemysWhitelist self.enemysNoBarriers = enemysNoBarriers self.rewardsTimeouts = rewardsTimeouts + self.include_rand = include_rand @property def entities(self): @@ -210,7 +211,7 @@ class StateObservable(Observable): def get_observation_space(self): return spaces.Box(low=0-1*self.coordsRelativeToAgent, high=1, - shape=(len(self.entities)*2+len(self._timeoutEntities) + self.speedAgent,), dtype=np.float32) + shape=(len(self.entities)*2+len(self._timeoutEntities) + self.speedAgent + self.include_rand,), dtype=np.float32) def get_observation(self): obs = [] @@ -232,6 +233,8 @@ class StateObservable(Observable): if self.speedAgent: obs.append(self.env.speed[0]) obs.append(self.env.speed[1]) + if self.include_rand: + obs.append(self.env.random()) self.obs = obs return np.array(obs)