Fixed Bug RayObservable did not detect lower Void
This commit is contained in:
parent
b668cf5746
commit
b0aeb94cd7
@ -109,7 +109,7 @@ class RayObservable(Observable):
|
|||||||
for entity in entities_l:
|
for entity in entities_l:
|
||||||
if isinstance(entity, entity_type) or (self.env.void_barrier and isinstance(entity, entities.Void) and entity_type == entities.Enemy):
|
if isinstance(entity, entity_type) or (self.env.void_barrier and isinstance(entity, entities.Void) and entity_type == entities.Enemy):
|
||||||
if isinstance(entity, entities.Void):
|
if isinstance(entity, entities.Void):
|
||||||
if 0 >= pos[0] or pos[0] >= self.env.width or 0 >= pos[1] or pos[0] >= self.env.height:
|
if 0 >= pos[0] or pos[0] >= self.env.width or 0 >= pos[1] or pos[1] >= self.env.width:
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
if entity.shape != 'circle':
|
if entity.shape != 'circle':
|
||||||
@ -169,7 +169,7 @@ class RayObservable(Observable):
|
|||||||
|
|
||||||
|
|
||||||
class StateObservable(Observable):
|
class StateObservable(Observable):
|
||||||
def __init__(self, coordsAgent=False, speedAgent=False, coordsRelativeToAgent=True, coordsRewards=True, rewardsWhitelist=None, coordsEnemys=True, enemysWhitelist=None, enemysNoBarriers=True, rewardsTimeouts=True):
|
def __init__(self, coordsAgent=False, speedAgent=False, coordsRelativeToAgent=True, coordsRewards=True, rewardsWhitelist=None, coordsEnemys=True, enemysWhitelist=None, enemysNoBarriers=True, rewardsTimeouts=True, include_rand=True):
|
||||||
super(StateObservable, self).__init__()
|
super(StateObservable, self).__init__()
|
||||||
self._entities = None
|
self._entities = None
|
||||||
self._timeoutEntities = []
|
self._timeoutEntities = []
|
||||||
@ -182,6 +182,7 @@ class StateObservable(Observable):
|
|||||||
self.enemysWhitelist = enemysWhitelist
|
self.enemysWhitelist = enemysWhitelist
|
||||||
self.enemysNoBarriers = enemysNoBarriers
|
self.enemysNoBarriers = enemysNoBarriers
|
||||||
self.rewardsTimeouts = rewardsTimeouts
|
self.rewardsTimeouts = rewardsTimeouts
|
||||||
|
self.include_rand = include_rand
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def entities(self):
|
def entities(self):
|
||||||
@ -210,7 +211,7 @@ class StateObservable(Observable):
|
|||||||
|
|
||||||
def get_observation_space(self):
|
def get_observation_space(self):
|
||||||
return spaces.Box(low=0-1*self.coordsRelativeToAgent, high=1,
|
return spaces.Box(low=0-1*self.coordsRelativeToAgent, high=1,
|
||||||
shape=(len(self.entities)*2+len(self._timeoutEntities) + self.speedAgent,), dtype=np.float32)
|
shape=(len(self.entities)*2+len(self._timeoutEntities) + self.speedAgent + self.include_rand,), dtype=np.float32)
|
||||||
|
|
||||||
def get_observation(self):
|
def get_observation(self):
|
||||||
obs = []
|
obs = []
|
||||||
@ -232,6 +233,8 @@ class StateObservable(Observable):
|
|||||||
if self.speedAgent:
|
if self.speedAgent:
|
||||||
obs.append(self.env.speed[0])
|
obs.append(self.env.speed[0])
|
||||||
obs.append(self.env.speed[1])
|
obs.append(self.env.speed[1])
|
||||||
|
if self.include_rand:
|
||||||
|
obs.append(self.env.random())
|
||||||
self.obs = obs
|
self.obs = obs
|
||||||
return np.array(obs)
|
return np.array(obs)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user