Compare commits
No commits in common. "2132deedcd835fa1351115619ba0ecd82a77e8ba" and "519c57bb6477def39413f5f814268002a8b9a79f" have entirely different histories.
2132deedcd
...
519c57bb64
@ -179,7 +179,6 @@ class ColumbusEnv(gym.Env):
|
|||||||
self.setup()
|
self.setup()
|
||||||
self.entities.append(self.agent) # add it last, will be drawn on top
|
self.entities.append(self.agent) # add it last, will be drawn on top
|
||||||
self._seed(self.env_seed)
|
self._seed(self.env_seed)
|
||||||
self.observable._entities = None
|
|
||||||
return self.observable.get_observation()
|
return self.observable.get_observation()
|
||||||
|
|
||||||
def _draw_entities(self):
|
def _draw_entities(self):
|
||||||
@ -339,9 +338,9 @@ class ColumbusEasierObstacles(ColumbusEnv):
|
|||||||
self.entities.append(enemy)
|
self.entities.append(enemy)
|
||||||
|
|
||||||
|
|
||||||
class ColumbusJustState(ColumbusEnv):
|
class ColumbusRewardEnemyPID(ColumbusEnv):
|
||||||
def __init__(self, observable=observables.StateObservable(), fps=30, env_seed=None):
|
def __init__(self, observable=observables.StateObservable(), fps=30, env_seed=None):
|
||||||
super(ColumbusJustState, self).__init__(
|
super(ColumbusRewardEnemyPID, self).__init__(
|
||||||
observable=observable, fps=fps)
|
observable=observable, fps=fps)
|
||||||
self.aux_reward_max = 0.1
|
self.aux_reward_max = 0.1
|
||||||
|
|
||||||
@ -360,10 +359,10 @@ class ColumbusJustState(ColumbusEnv):
|
|||||||
self.entities.append(reward)
|
self.entities.append(reward)
|
||||||
|
|
||||||
|
|
||||||
class ColumbusStateWithBarriers(ColumbusEnv):
|
class ColumbusRewardEnemyPIDWithBarriers(ColumbusEnv):
|
||||||
def __init__(self, observable=observables.StateObservable(coordsAgent=True, speedAgent=False, coordsRelativeToAgent=False, coordsRewards=True, rewardsWhitelist=None, coordsEnemys=True, enemysWhitelist=None, enemysNoBarriers=True, rewardsTimeouts=False, include_rand=True), fps=30, env_seed=3.1):
|
def __init__(self, observable=observables.StateObservable(), fps=30, env_seed=3.1):
|
||||||
super(ColumbusStateWithBarriers, self).__init__(
|
super(ColumbusRewardEnemyPIDWithBarriers, self).__init__(
|
||||||
observable=observable, fps=fps, env_seed=env_seed)
|
observable=observable, fps=fps)
|
||||||
self.aux_reward_max = 0.01
|
self.aux_reward_max = 0.01
|
||||||
self.start_pos = (0.5, 0.5)
|
self.start_pos = (0.5, 0.5)
|
||||||
|
|
||||||
@ -419,9 +418,3 @@ register(
|
|||||||
entry_point=ColumbusEasyObstacles,
|
entry_point=ColumbusEasyObstacles,
|
||||||
max_episode_steps=30*60*2,
|
max_episode_steps=30*60*2,
|
||||||
)
|
)
|
||||||
|
|
||||||
register(
|
|
||||||
id='ColumbusStateWithBarriers-v0',
|
|
||||||
entry_point=ColumbusStateWithBarriers,
|
|
||||||
max_episode_steps=30*60*2,
|
|
||||||
)
|
|
||||||
|
@ -8,7 +8,7 @@ from observables import Observable, CnnObservable
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
#env = ColumbusTest3_1(fps=30)
|
#env = ColumbusTest3_1(fps=30)
|
||||||
env = ColumbusStateWithBarriers(fps=30)
|
env = ColumbusEasierObstacles(fps=30)
|
||||||
env.start_pos = [0.6, 0.3]
|
env.start_pos = [0.6, 0.3]
|
||||||
playEnv(env)
|
playEnv(env)
|
||||||
env.close()
|
env.close()
|
||||||
@ -17,6 +17,7 @@ def main():
|
|||||||
def playEnv(env):
|
def playEnv(env):
|
||||||
env.reset()
|
env.reset()
|
||||||
done = False
|
done = False
|
||||||
|
to = 0
|
||||||
while not done:
|
while not done:
|
||||||
t1 = time()
|
t1 = time()
|
||||||
env.render()
|
env.render()
|
||||||
|
@ -186,34 +186,32 @@ class StateObservable(Observable):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def entities(self):
|
def entities(self):
|
||||||
if not self._entities == None:
|
if self._entities:
|
||||||
return self._entities
|
return self._entities
|
||||||
rewardsWhitelist = self.rewardsWhitelist or self.env.entities
|
self.env.setup()
|
||||||
enemysWhitelist = self.enemysWhitelist or self.env.entities
|
self.rewardsWhitelist = self.rewardsWhitelist or self.env.entities
|
||||||
|
self.enemysWhitelist = self.enemysWhitelist or self.env.entities
|
||||||
self._entities = []
|
self._entities = []
|
||||||
if self.coordsAgent:
|
if self.coordsAgent:
|
||||||
self._entities.append(self.env.agent)
|
self._entities.append(self.env.agent)
|
||||||
if self.coordRewards:
|
if self.coordRewards:
|
||||||
for entity in rewardsWhitelist:
|
for entity in self.rewardsWhitelist:
|
||||||
if isinstance(entity, entities.Reward):
|
if isinstance(entity, entities.Reward):
|
||||||
self._entities.append(entity)
|
self._entities.append(entity)
|
||||||
if self.coordsEnemys:
|
if self.coordsEnemys:
|
||||||
for entity in enemysWhitelist:
|
for entity in self.enemysWhitelist:
|
||||||
if isinstance(entity, entities.Enemy):
|
if isinstance(entity, entities.Enemy):
|
||||||
if not self.enemysNoBarriers or not isinstance(entity, entities.Barrier):
|
if not self.enemysNoBarriers or not isinstance(entity, entities.Barrier):
|
||||||
self._entities.append(entity)
|
self._entities.append(entity)
|
||||||
if self.rewardsTimeouts:
|
if self.rewardsTimeouts:
|
||||||
for entity in enemysWhitelist:
|
for entity in self.enemysWhitelist:
|
||||||
if isinstance(entity, entities.TimeoutReward):
|
if isinstance(entity, entities.TimeoutReward):
|
||||||
self._timeoutEntities.append(entity)
|
self._timeoutEntities.append(entity)
|
||||||
return self._entities
|
return self._entities
|
||||||
|
|
||||||
def get_observation_space(self):
|
def get_observation_space(self):
|
||||||
self.env.setup()
|
|
||||||
num = len(self.entities)*2+len(self._timeoutEntities) + \
|
|
||||||
self.speedAgent + self.include_rand
|
|
||||||
return spaces.Box(low=0-1*self.coordsRelativeToAgent, high=1,
|
return spaces.Box(low=0-1*self.coordsRelativeToAgent, high=1,
|
||||||
shape=(num,), dtype=np.float32)
|
shape=(len(self.entities)*2+len(self._timeoutEntities) + self.speedAgent + self.include_rand,), dtype=np.float32)
|
||||||
|
|
||||||
def get_observation(self):
|
def get_observation(self):
|
||||||
obs = []
|
obs = []
|
||||||
@ -241,10 +239,4 @@ class StateObservable(Observable):
|
|||||||
return np.array(obs)
|
return np.array(obs)
|
||||||
|
|
||||||
def draw(self):
|
def draw(self):
|
||||||
for i in range(int(len(self.obs)/2)):
|
pass
|
||||||
x, y = self.obs[i*2], self.obs[i*2+1]
|
|
||||||
col = self.entities[i].col
|
|
||||||
pygame.draw.circle(self.env.screen, col,
|
|
||||||
(0, y*self.env.height), 1, width=0)
|
|
||||||
pygame.draw.circle(self.env.screen, col,
|
|
||||||
(x*self.env.width, 0), 1, width=0)
|
|
||||||
|
Loading…
Reference in New Issue
Block a user