Added a dummy Observable
This commit is contained in:
parent
f94eaa5dc0
commit
ff4e81d4f1
@ -35,6 +35,9 @@ def parseObs(obsConf):
|
|||||||
elif obsConf['type'] == 'CNN':
|
elif obsConf['type'] == 'CNN':
|
||||||
conf = {k: v for k, v in obsConf.items() if k not in ['type']}
|
conf = {k: v for k, v in obsConf.items() if k not in ['type']}
|
||||||
return observables.CnnObservable(**conf)
|
return observables.CnnObservable(**conf)
|
||||||
|
elif obsConf['type'] == 'Dummy':
|
||||||
|
conf = {k: v for k, v in obsConf.items() if k not in ['type']}
|
||||||
|
return observables.Observable(**conf)
|
||||||
else:
|
else:
|
||||||
raise Exception('Unknown Observable selected')
|
raise Exception('Unknown Observable selected')
|
||||||
|
|
||||||
@ -84,6 +87,7 @@ class ColumbusEnv(gym.Env):
|
|||||||
self.void_barrier = void_is_type_barrier
|
self.void_barrier = void_is_type_barrier
|
||||||
self.void_damage = void_damage
|
self.void_damage = void_damage
|
||||||
self.torus_topology = torus_topology
|
self.torus_topology = torus_topology
|
||||||
|
self.default_collision_elasticity = 1
|
||||||
|
|
||||||
self.paused = False
|
self.paused = False
|
||||||
self.keypress_timeout = 0
|
self.keypress_timeout = 0
|
||||||
|
@ -40,6 +40,7 @@ def chooseEnv():
|
|||||||
|
|
||||||
def playEnv(env):
|
def playEnv(env):
|
||||||
done = False
|
done = False
|
||||||
|
env.reset()
|
||||||
while not done:
|
while not done:
|
||||||
t1 = time()
|
t1 = time()
|
||||||
env.render()
|
env.render()
|
||||||
|
@ -13,7 +13,7 @@ class Observable():
|
|||||||
def _set_env(self, env):
|
def _set_env(self, env):
|
||||||
self.env = env
|
self.env = env
|
||||||
|
|
||||||
def get_observation_space():
|
def get_observation_space(self):
|
||||||
print("[!] Using dummyObservable. Env won't output anything")
|
print("[!] Using dummyObservable. Env won't output anything")
|
||||||
return spaces.Box(low=0, high=1,
|
return spaces.Box(low=0, high=1,
|
||||||
shape=(1,), dtype=np.float32)
|
shape=(1,), dtype=np.float32)
|
||||||
@ -223,7 +223,7 @@ class StateObservable(Observable):
|
|||||||
self._entities = None
|
self._entities = None
|
||||||
|
|
||||||
def get_observation_space(self):
|
def get_observation_space(self):
|
||||||
self.env.reset()
|
self.reset()
|
||||||
num = len(self.entities)*2+len(self._timeoutEntities) + \
|
num = len(self.entities)*2+len(self._timeoutEntities) + \
|
||||||
self.speedAgent*2 + self.include_rand
|
self.speedAgent*2 + self.include_rand
|
||||||
return spaces.Box(low=0-1*self.coordsRelativeToAgent, high=1,
|
return spaces.Box(low=0-1*self.coordsRelativeToAgent, high=1,
|
||||||
@ -366,7 +366,6 @@ class CompositionalObservable(Observable):
|
|||||||
obs.draw()
|
obs.draw()
|
||||||
|
|
||||||
def _set_env(self, env):
|
def _set_env(self, env):
|
||||||
# self.env = env
|
|
||||||
for obs in self.observables:
|
for obs in self.observables:
|
||||||
obs._set_env(env)
|
obs._set_env(env)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user