Added a dummy Observable
This commit is contained in:
		
							parent
							
								
									f94eaa5dc0
								
							
						
					
					
						commit
						ff4e81d4f1
					
				@ -35,6 +35,9 @@ def parseObs(obsConf):
 | 
			
		||||
    elif obsConf['type'] == 'CNN':
 | 
			
		||||
        conf = {k: v for k, v in obsConf.items() if k not in ['type']}
 | 
			
		||||
        return observables.CnnObservable(**conf)
 | 
			
		||||
    elif obsConf['type'] == 'Dummy':
 | 
			
		||||
        conf = {k: v for k, v in obsConf.items() if k not in ['type']}
 | 
			
		||||
        return observables.Observable(**conf)
 | 
			
		||||
    else:
 | 
			
		||||
        raise Exception('Unknown Observable selected')
 | 
			
		||||
 | 
			
		||||
@ -84,6 +87,7 @@ class ColumbusEnv(gym.Env):
 | 
			
		||||
        self.void_barrier = void_is_type_barrier
 | 
			
		||||
        self.void_damage = void_damage
 | 
			
		||||
        self.torus_topology = torus_topology
 | 
			
		||||
        self.default_collision_elasticity = 1
 | 
			
		||||
 | 
			
		||||
        self.paused = False
 | 
			
		||||
        self.keypress_timeout = 0
 | 
			
		||||
 | 
			
		||||
@ -40,6 +40,7 @@ def chooseEnv():
 | 
			
		||||
 | 
			
		||||
def playEnv(env):
 | 
			
		||||
    done = False
 | 
			
		||||
    env.reset()
 | 
			
		||||
    while not done:
 | 
			
		||||
        t1 = time()
 | 
			
		||||
        env.render()
 | 
			
		||||
 | 
			
		||||
@ -13,7 +13,7 @@ class Observable():
 | 
			
		||||
    def _set_env(self, env):
 | 
			
		||||
        self.env = env
 | 
			
		||||
 | 
			
		||||
    def get_observation_space():
 | 
			
		||||
    def get_observation_space(self):
 | 
			
		||||
        print("[!] Using dummyObservable. Env won't output anything")
 | 
			
		||||
        return spaces.Box(low=0, high=1,
 | 
			
		||||
                          shape=(1,), dtype=np.float32)
 | 
			
		||||
@ -223,7 +223,7 @@ class StateObservable(Observable):
 | 
			
		||||
        self._entities = None
 | 
			
		||||
 | 
			
		||||
    def get_observation_space(self):
 | 
			
		||||
        self.env.reset()
 | 
			
		||||
        self.reset()
 | 
			
		||||
        num = len(self.entities)*2+len(self._timeoutEntities) + \
 | 
			
		||||
            self.speedAgent*2 + self.include_rand
 | 
			
		||||
        return spaces.Box(low=0-1*self.coordsRelativeToAgent, high=1,
 | 
			
		||||
@ -366,7 +366,6 @@ class CompositionalObservable(Observable):
 | 
			
		||||
            obs.draw()
 | 
			
		||||
 | 
			
		||||
    def _set_env(self, env):
 | 
			
		||||
        # self.env = env
 | 
			
		||||
        for obs in self.observables:
 | 
			
		||||
            obs._set_env(env)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user