Implemented CompositionalObservable
This commit is contained in:
parent
cb403737f8
commit
6c4c9e0fdd
@ -431,6 +431,26 @@ class ColumbusEasierObstacles(ColumbusEnv):
|
||||
self.entities.append(enemy)
|
||||
|
||||
|
||||
class ColumbusComp(ColumbusEnv):
    """Columbus environment observed through a composite observable.

    When no observable is supplied, observations are produced by a
    ``CompositionalObservable`` combining a six-ray ``RayObservable``
    (restricted to the ``entities.Enemy`` channel) with a
    ``StateObservable``. ``setup`` places five ``CircleBarrier`` entities of
    random radius and three ``TeleportingReward`` entities with doubled
    reward.
    """

    def __init__(self, observable=None, hide_map=False, fps=30, env_seed=None):
        """Create the environment.

        Args:
            observable: Observable used by the environment. ``None`` (the
                default) builds a fresh default composite observable for
                this instance.
            hide_map: When true, ``draw_entities`` is set to ``False``.
            fps: Forwarded unchanged to ``ColumbusEnv.__init__``.
            env_seed: Forwarded unchanged to ``ColumbusEnv.__init__``.
        """
        # The default is built per instance on purpose: a default created in
        # the signature is evaluated once and would be shared (and re-bound
        # via ``_set_env``) by every environment constructed without an
        # explicit observable.
        if observable is None:
            observable = self._default_observable()
        super().__init__(
            observable=observable, fps=fps, env_seed=env_seed)
        self.draw_entities = not hide_map
        self.aux_reward_max = 10

    @staticmethod
    def _default_observable():
        """Build the default ray + state composite observable."""
        return observables.CompositionalObservable([
            observables.RayObservable(num_rays=6, chans=[entities.Enemy]),
            observables.StateObservable(
                coordsAgent=True,
                speedAgent=False,
                coordsRelativeToAgent=False,
                coordsRewards=True,
                rewardsWhitelist=None,
                coordsEnemys=False,
                enemysWhitelist=None,
                enemysNoBarriers=True,
                rewardsTimeouts=False,
                include_rand=True,
            ),
        ])

    def setup(self):
        """Reset the agent position and populate barriers and rewards."""
        self.agent.pos = self.start_pos
        for _ in range(5):
            barrier = entities.CircleBarrier(self)
            # Radius is 30 plus up to 70 more, scaled by self.random()
            # (presumably a float in [0, 1) -- confirm against ColumbusEnv).
            barrier.radius = 30 + self.random()*70
            self.entities.append(barrier)
        for _ in range(3):
            reward = entities.TeleportingReward(self)
            reward.radius = 30
            # Double whatever reward value the entity was created with.
            reward.reward *= 2
            self.entities.append(reward)
|
||||
|
||||
|
||||
class ColumbusJustState(ColumbusEnv):
|
||||
def __init__(self, observable=observables.StateObservable(), fps=30, num_enemies=0, num_rewards=1, env_seed=None):
|
||||
super(ColumbusJustState, self).__init__(
|
||||
@ -573,3 +593,9 @@ register(
|
||||
entry_point=ColumbusFootball,
|
||||
max_episode_steps=30*60*2,
|
||||
)
|
||||
|
||||
# Register the ColumbusComp environment under a string id.
# NOTE(review): the id is spelled "Comb" while the class is ColumbusComp --
# confirm this is intended; it is left unchanged because callers look the
# environment up by this exact string.
register(
    max_episode_steps=30*60*2,
    entry_point=ColumbusComp,
    id='ColumbusComb-v0',
)
|
||||
|
@ -259,3 +259,43 @@ class StateObservable(Observable):
|
||||
(0, y*self.env.height+ofs[0]), 1, width=0)
|
||||
pygame.draw.circle(self.env.screen, col,
|
||||
(x*self.env.width+ofs[1], 0), 1, width=0)
|
||||
|
||||
|
||||
class CompositionalObservable(Observable):
    """Observable that concatenates the outputs of several child observables.

    The combined observation is a flat vector made of each child's
    observation, in the order the children were given.
    """

    def __init__(self, observables):
        """Store the child observables.

        Args:
            observables: Sequence of observables whose observations are
                concatenated, in order.
        """
        super().__init__()
        self.observables = observables

    def get_observation_space(self):
        """Return a flat ``spaces.Box`` covering all child spaces.

        The size is the sum of the element counts of the child spaces. The
        bounds are scalars: the smallest lower bound and the largest upper
        bound found in any child space, so every child value fits. With no
        children (or only empty child spaces) both bounds fall back to 0.0.
        """
        sub_spaces = [obs.get_observation_space() for obs in self.observables]
        num = sum(math.prod(space.shape) for space in sub_spaces)
        # Scan the whole bound arrays, not just their first entry, so that
        # children with non-uniform or multi-dimensional bounds are covered.
        lows = [float(np.min(space.low))
                for space in sub_spaces if np.size(space.low)]
        highs = [float(np.max(space.high))
                 for space in sub_spaces if np.size(space.high)]
        low = min(lows, default=0.0)
        high = max(highs, default=0.0)
        return spaces.Box(low=low, high=high,
                          shape=(num,), dtype=np.float32)

    def get_observation(self):
        """Return the children's observations as one flat float32 array.

        Each child observation is flattened before concatenation, matching
        the flattened size and dtype declared by ``get_observation_space``.
        """
        parts = [
            np.asarray(obs.get_observation(), dtype=np.float32).ravel()
            for obs in self.observables
        ]
        if not parts:
            # np.concatenate rejects an empty sequence.
            return np.zeros(0, dtype=np.float32)
        return np.concatenate(parts)

    def draw(self):
        """Call ``draw`` on every child observable, in order."""
        for obs in self.observables:
            obs.draw()

    def _set_env(self, env):
        """Bind every child observable to ``env``.

        The composite does not store ``env`` on itself; none of its own
        methods read it.
        """
        for obs in self.observables:
            obs._set_env(env)
|
||||
|
Loading…
Reference in New Issue
Block a user