Compare commits
No commits in common. "519c57bb6477def39413f5f814268002a8b9a79f" and "b668cf5746a2c2ff17707e23b6c2821f7f54da7a" have entirely different histories.
519c57bb64
...
b668cf5746
@ -195,7 +195,6 @@ class TimeoutReward(OnceReward):
|
|||||||
self.env.timers.append((self.timeout, self.set_avaible, True))
|
self.env.timers.append((self.timeout, self.set_avaible, True))
|
||||||
|
|
||||||
|
|
||||||
# Not a real entity. Is used in the config of RayObserver to reference the outer boundary of the environment.
|
|
||||||
class Void():
|
class Void():
|
||||||
def __init__(self, env):
|
def __init__(self, env):
|
||||||
self.col = (50, 50, 50)
|
self.col = (50, 50, 50)
|
||||||
|
@ -109,7 +109,7 @@ class RayObservable(Observable):
|
|||||||
for entity in entities_l:
|
for entity in entities_l:
|
||||||
if isinstance(entity, entity_type) or (self.env.void_barrier and isinstance(entity, entities.Void) and entity_type == entities.Enemy):
|
if isinstance(entity, entity_type) or (self.env.void_barrier and isinstance(entity, entities.Void) and entity_type == entities.Enemy):
|
||||||
if isinstance(entity, entities.Void):
|
if isinstance(entity, entities.Void):
|
||||||
if 0 >= pos[0] or pos[0] >= self.env.width or 0 >= pos[1] or pos[1] >= self.env.width:
|
if 0 >= pos[0] or pos[0] >= self.env.width or 0 >= pos[1] or pos[0] >= self.env.height:
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
if entity.shape != 'circle':
|
if entity.shape != 'circle':
|
||||||
@ -169,7 +169,7 @@ class RayObservable(Observable):
|
|||||||
|
|
||||||
|
|
||||||
class StateObservable(Observable):
|
class StateObservable(Observable):
|
||||||
def __init__(self, coordsAgent=False, speedAgent=False, coordsRelativeToAgent=True, coordsRewards=True, rewardsWhitelist=None, coordsEnemys=True, enemysWhitelist=None, enemysNoBarriers=True, rewardsTimeouts=True, include_rand=True):
|
def __init__(self, coordsAgent=False, speedAgent=False, coordsRelativeToAgent=True, coordsRewards=True, rewardsWhitelist=None, coordsEnemys=True, enemysWhitelist=None, enemysNoBarriers=True, rewardsTimeouts=True):
|
||||||
super(StateObservable, self).__init__()
|
super(StateObservable, self).__init__()
|
||||||
self._entities = None
|
self._entities = None
|
||||||
self._timeoutEntities = []
|
self._timeoutEntities = []
|
||||||
@ -182,7 +182,6 @@ class StateObservable(Observable):
|
|||||||
self.enemysWhitelist = enemysWhitelist
|
self.enemysWhitelist = enemysWhitelist
|
||||||
self.enemysNoBarriers = enemysNoBarriers
|
self.enemysNoBarriers = enemysNoBarriers
|
||||||
self.rewardsTimeouts = rewardsTimeouts
|
self.rewardsTimeouts = rewardsTimeouts
|
||||||
self.include_rand = include_rand
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def entities(self):
|
def entities(self):
|
||||||
@ -211,7 +210,7 @@ class StateObservable(Observable):
|
|||||||
|
|
||||||
def get_observation_space(self):
|
def get_observation_space(self):
|
||||||
return spaces.Box(low=0-1*self.coordsRelativeToAgent, high=1,
|
return spaces.Box(low=0-1*self.coordsRelativeToAgent, high=1,
|
||||||
shape=(len(self.entities)*2+len(self._timeoutEntities) + self.speedAgent + self.include_rand,), dtype=np.float32)
|
shape=(len(self.entities)*2+len(self._timeoutEntities) + self.speedAgent,), dtype=np.float32)
|
||||||
|
|
||||||
def get_observation(self):
|
def get_observation(self):
|
||||||
obs = []
|
obs = []
|
||||||
@ -233,8 +232,6 @@ class StateObservable(Observable):
|
|||||||
if self.speedAgent:
|
if self.speedAgent:
|
||||||
obs.append(self.env.speed[0])
|
obs.append(self.env.speed[0])
|
||||||
obs.append(self.env.speed[1])
|
obs.append(self.env.speed[1])
|
||||||
if self.include_rand:
|
|
||||||
obs.append(self.env.random())
|
|
||||||
self.obs = obs
|
self.obs = obs
|
||||||
return np.array(obs)
|
return np.array(obs)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user