Compare commits
2 Commits
b668cf5746
...
519c57bb64
Author | SHA1 | Date | |
---|---|---|---|
519c57bb64 | |||
b0aeb94cd7 |
@ -195,6 +195,7 @@ class TimeoutReward(OnceReward):
|
||||
self.env.timers.append((self.timeout, self.set_avaible, True))
|
||||
|
||||
|
||||
# Not a real entity. Is used in the config of RayObserver to reference the outer boundary of the environment.
|
||||
class Void():
|
||||
def __init__(self, env):
|
||||
self.col = (50, 50, 50)
|
||||
|
@ -109,7 +109,7 @@ class RayObservable(Observable):
|
||||
for entity in entities_l:
|
||||
if isinstance(entity, entity_type) or (self.env.void_barrier and isinstance(entity, entities.Void) and entity_type == entities.Enemy):
|
||||
if isinstance(entity, entities.Void):
|
||||
if 0 >= pos[0] or pos[0] >= self.env.width or 0 >= pos[1] or pos[0] >= self.env.height:
|
||||
if 0 >= pos[0] or pos[0] >= self.env.width or 0 >= pos[1] or pos[1] >= self.env.width:
|
||||
return True
|
||||
else:
|
||||
if entity.shape != 'circle':
|
||||
@ -169,7 +169,7 @@ class RayObservable(Observable):
|
||||
|
||||
|
||||
class StateObservable(Observable):
|
||||
def __init__(self, coordsAgent=False, speedAgent=False, coordsRelativeToAgent=True, coordsRewards=True, rewardsWhitelist=None, coordsEnemys=True, enemysWhitelist=None, enemysNoBarriers=True, rewardsTimeouts=True):
|
||||
def __init__(self, coordsAgent=False, speedAgent=False, coordsRelativeToAgent=True, coordsRewards=True, rewardsWhitelist=None, coordsEnemys=True, enemysWhitelist=None, enemysNoBarriers=True, rewardsTimeouts=True, include_rand=True):
|
||||
super(StateObservable, self).__init__()
|
||||
self._entities = None
|
||||
self._timeoutEntities = []
|
||||
@ -182,6 +182,7 @@ class StateObservable(Observable):
|
||||
self.enemysWhitelist = enemysWhitelist
|
||||
self.enemysNoBarriers = enemysNoBarriers
|
||||
self.rewardsTimeouts = rewardsTimeouts
|
||||
self.include_rand = include_rand
|
||||
|
||||
@property
|
||||
def entities(self):
|
||||
@ -210,7 +211,7 @@ class StateObservable(Observable):
|
||||
|
||||
def get_observation_space(self):
|
||||
return spaces.Box(low=0-1*self.coordsRelativeToAgent, high=1,
|
||||
shape=(len(self.entities)*2+len(self._timeoutEntities) + self.speedAgent,), dtype=np.float32)
|
||||
shape=(len(self.entities)*2+len(self._timeoutEntities) + self.speedAgent + self.include_rand,), dtype=np.float32)
|
||||
|
||||
def get_observation(self):
|
||||
obs = []
|
||||
@ -232,6 +233,8 @@ class StateObservable(Observable):
|
||||
if self.speedAgent:
|
||||
obs.append(self.env.speed[0])
|
||||
obs.append(self.env.speed[1])
|
||||
if self.include_rand:
|
||||
obs.append(self.env.random())
|
||||
self.obs = obs
|
||||
return np.array(obs)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user