Compare commits

...

4 Commits

3 changed files with 31 additions and 17 deletions

View File

@ -179,6 +179,7 @@ class ColumbusEnv(gym.Env):
self.setup()
self.entities.append(self.agent) # add it last, will be drawn on top
self._seed(self.env_seed)
self.observable._entities = None
return self.observable.get_observation()
def _draw_entities(self):
@ -338,9 +339,9 @@ class ColumbusEasierObstacles(ColumbusEnv):
self.entities.append(enemy)
class ColumbusRewardEnemyPID(ColumbusEnv):
class ColumbusJustState(ColumbusEnv):
def __init__(self, observable=observables.StateObservable(), fps=30, env_seed=None):
super(ColumbusRewardEnemyPID, self).__init__(
super(ColumbusJustState, self).__init__(
observable=observable, fps=fps)
self.aux_reward_max = 0.1
@ -359,10 +360,10 @@ class ColumbusRewardEnemyPID(ColumbusEnv):
self.entities.append(reward)
class ColumbusRewardEnemyPIDWithBarriers(ColumbusEnv):
def __init__(self, observable=observables.StateObservable(), fps=30, env_seed=3.1):
super(ColumbusRewardEnemyPIDWithBarriers, self).__init__(
observable=observable, fps=fps)
class ColumbusStateWithBarriers(ColumbusEnv):
def __init__(self, observable=observables.StateObservable(coordsAgent=True, speedAgent=False, coordsRelativeToAgent=False, coordsRewards=True, rewardsWhitelist=None, coordsEnemys=True, enemysWhitelist=None, enemysNoBarriers=True, rewardsTimeouts=False, include_rand=True), fps=30, env_seed=3.1):
super(ColumbusStateWithBarriers, self).__init__(
observable=observable, fps=fps, env_seed=env_seed)
self.aux_reward_max = 0.01
self.start_pos = (0.5, 0.5)
@ -418,3 +419,9 @@ register(
entry_point=ColumbusEasyObstacles,
max_episode_steps=30*60*2,
)
register(
id='ColumbusStateWithBarriers-v0',
entry_point=ColumbusStateWithBarriers,
max_episode_steps=30*60*2,
)

View File

@ -8,7 +8,7 @@ from observables import Observable, CnnObservable
def main():
#env = ColumbusTest3_1(fps=30)
env = ColumbusEasierObstacles(fps=30)
env = ColumbusStateWithBarriers(fps=30)
env.start_pos = [0.6, 0.3]
playEnv(env)
env.close()
@ -17,7 +17,6 @@ def main():
def playEnv(env):
env.reset()
done = False
to = 0
while not done:
t1 = time()
env.render()

View File

@ -186,32 +186,34 @@ class StateObservable(Observable):
@property
def entities(self):
if self._entities:
if not self._entities == None:
return self._entities
self.env.setup()
self.rewardsWhitelist = self.rewardsWhitelist or self.env.entities
self.enemysWhitelist = self.enemysWhitelist or self.env.entities
rewardsWhitelist = self.rewardsWhitelist or self.env.entities
enemysWhitelist = self.enemysWhitelist or self.env.entities
self._entities = []
if self.coordsAgent:
self._entities.append(self.env.agent)
if self.coordRewards:
for entity in self.rewardsWhitelist:
for entity in rewardsWhitelist:
if isinstance(entity, entities.Reward):
self._entities.append(entity)
if self.coordsEnemys:
for entity in self.enemysWhitelist:
for entity in enemysWhitelist:
if isinstance(entity, entities.Enemy):
if not self.enemysNoBarriers or not isinstance(entity, entities.Barrier):
self._entities.append(entity)
if self.rewardsTimeouts:
for entity in self.enemysWhitelist:
for entity in enemysWhitelist:
if isinstance(entity, entities.TimeoutReward):
self._timeoutEntities.append(entity)
return self._entities
def get_observation_space(self):
self.env.setup()
num = len(self.entities)*2+len(self._timeoutEntities) + \
self.speedAgent + self.include_rand
return spaces.Box(low=0-1*self.coordsRelativeToAgent, high=1,
shape=(len(self.entities)*2+len(self._timeoutEntities) + self.speedAgent + self.include_rand,), dtype=np.float32)
shape=(num,), dtype=np.float32)
def get_observation(self):
obs = []
@ -239,4 +241,10 @@ class StateObservable(Observable):
return np.array(obs)
def draw(self):
pass
for i in range(int(len(self.obs)/2)):
x, y = self.obs[i*2], self.obs[i*2+1]
col = self.entities[i].col
pygame.draw.circle(self.env.screen, col,
(0, y*self.env.height), 1, width=0)
pygame.draw.circle(self.env.screen, col,
(x*self.env.width, 0), 1, width=0)