New Observable: Compass
This commit is contained in:
		
							parent
							
								
									f18310ed5b
								
							
						
					
					
						commit
						8706462358
					
				@ -10,6 +10,35 @@ from columbus import entities, observables
 | 
			
		||||
import torch as th
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def parseObs(obsConf):
    """Build an observable (or a composition of observables) from a config.

    Args:
        obsConf: Either a mapping with a ``'type'`` entry (``'State'``,
            ``'Compass'``, ``'RayCast'`` or ``'CNN'``) whose remaining
            entries are forwarded to the matching observable class, or a
            list of such configurations. Lists may be nested, since every
            element is parsed recursively.

    Returns:
        For a mapping, the constructed observable. For a list, the single
        parsed observable when the list has exactly one element, otherwise
        an ``observables.CompositionalObservable`` wrapping all of them
        (this includes the empty list).

    Raises:
        ValueError: If ``obsConf['type']`` is none of the known types.
        KeyError: If a mapping has no ``'type'`` entry.
    """
    # isinstance rather than an exact type comparison, so that list
    # subclasses (as produced by some config loaders) are treated as lists
    # instead of falling through to the mapping lookup below.
    if isinstance(obsConf, list):
        obs = [parseObs(c) for c in obsConf]
        if len(obs) == 1:
            return obs[0]
        return observables.CompositionalObservable(obs)

    obsType = obsConf['type']
    if obsType == 'State':
        conf = {k: v for k, v in obsConf.items() if k != 'type'}
        # NOTE(review): the settings dict is passed positionally here while
        # every other branch unpacks it as keyword arguments - confirm this
        # is what StateObservable's constructor expects.
        return observables.StateObservable(conf)
    if obsType == 'Compass':
        conf = {k: v for k, v in obsConf.items() if k != 'type'}
        return observables.CompassObservable(**conf)
    if obsType == 'RayCast':
        # Channel names are resolved to attributes of the entities module.
        chans = [getattr(entities, chan) for chan in obsConf.get('chans', [])]
        conf = {k: v for k, v in obsConf.items()
                if k not in ('type', 'chans')}
        return observables.RayObservable(chans=chans, **conf)
    if obsType == 'CNN':
        conf = {k: v for k, v in obsConf.items() if k != 'type'}
        return observables.CnnObservable(**conf)
    # ValueError is a subclass of Exception, so existing handlers still match.
    raise ValueError('Unknown Observable selected')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ColumbusEnv(gym.Env):
 | 
			
		||||
    metadata = {'render.modes': ['human']}
 | 
			
		||||
 | 
			
		||||
@ -527,6 +556,30 @@ class ColumbusStateWithBarriers(ColumbusEnv):
 | 
			
		||||
            self.entities.append(reward)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ColumbusCompassWithBarriers(ColumbusEnv):
    """Environment with circle barriers, optional chasers and one reward.

    Uses a ``CompassObservable`` (reward directions enabled) unless another
    observable is supplied.
    """

    def __init__(self, observable=None, fps=30, env_seed=3.141, num_enemys=0, num_barriers=3, aux_reward_max=10, **kw):
        """Store the entity counts and forward the rest to ``ColumbusEnv``.

        Args:
            observable: Observable to use. ``None`` (the default) creates a
                fresh ``observables.CompassObservable(coordsRewards=True)``.
            fps: Forwarded to the base class.
            env_seed: Forwarded to the base class.
            num_enemys: Number of ``FlyingChaser`` entities added in setup.
            num_barriers: Number of ``CircleBarrier`` entities added in setup.
            aux_reward_max: Forwarded to the base class.
            **kw: Any further keyword arguments for the base class.
        """
        if observable is None:
            # Build the default per instance: a default created in the
            # signature is evaluated once and would be one shared, stateful
            # object for every environment constructed without an argument.
            observable = observables.CompassObservable(coordsRewards=True)
        super().__init__(
            observable=observable, fps=fps, env_seed=env_seed,
            aux_reward_max=aux_reward_max, **kw)
        self.start_pos = (0.5, 0.5)
        self.num_barriers = num_barriers
        self.num_enemys = num_enemys

    def setup(self):
        """Place the agent and populate ``self.entities``.

        Entities are appended in a fixed order: barriers, then chasers, then
        a single teleporting reward.
        """
        self.agent.pos = self.start_pos
        for _ in range(self.num_barriers):
            barrier = entities.CircleBarrier(self)
            # Radius is 75 plus up to 25, scaled by self.random()
            # (presumably uniform in [0, 1) - TODO confirm).
            barrier.radius = self.random()*25+75
            self.entities.append(barrier)
        for _ in range(self.num_enemys):
            chaser = entities.FlyingChaser(self)
            chaser.chase_acc = 0.55
            self.entities.append(chaser)
        reward = entities.TeleportingReward(self)
        reward.radius = 30
        self.entities.append(reward)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ColumbusTrivialRay(ColumbusStateWithBarriers):
 | 
			
		||||
    def __init__(self, observable=observables.RayObservable(num_rays=8, ray_len=512), hide_map=False, fps=30, **kw):
 | 
			
		||||
        super(ColumbusTrivialRay, self).__init__(
 | 
			
		||||
@ -558,32 +611,6 @@ class ColumbusFootball(ColumbusEnv):
 | 
			
		||||
            self.entities.append(entities.FlyingFootballPlayer(self, ball))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def parseObs(obsConf):
    """Turn an observable configuration into an observable object.

    Args:
        obsConf: A mapping with a ``'type'`` entry (``'State'``,
            ``'RayCast'`` or ``'CNN'``) plus settings for that observable,
            or a list of such configurations; list elements are parsed
            recursively.

    Returns:
        The observable built from a mapping. For a list: the only parsed
        observable if there is exactly one element, otherwise an
        ``observables.CompositionalObservable`` over all parsed elements.

    Raises:
        ValueError: If ``obsConf['type']`` is not one of the handled types.
        KeyError: If a mapping has no ``'type'`` entry.
    """
    # Accept list subclasses too; an exact type comparison would send them
    # to the mapping lookup below and fail with an unrelated TypeError.
    if isinstance(obsConf, list):
        parsed = [parseObs(entry) for entry in obsConf]
        if len(parsed) == 1:
            return parsed[0]
        return observables.CompositionalObservable(parsed)

    kind = obsConf['type']
    if kind == 'State':
        conf = {k: v for k, v in obsConf.items() if k != 'type'}
        # NOTE(review): passed as a single positional dict, unlike the
        # keyword unpacking used by the other branches - confirm intent.
        return observables.StateObservable(conf)
    if kind == 'RayCast':
        # Each channel name is looked up as an attribute of entities.
        chans = [getattr(entities, chan) for chan in obsConf.get('chans', [])]
        conf = {k: v for k, v in obsConf.items()
                if k not in ('type', 'chans')}
        return observables.RayObservable(chans=chans, **conf)
    if kind == 'CNN':
        conf = {k: v for k, v in obsConf.items() if k != 'type'}
        return observables.CnnObservable(**conf)
    # Still caught by handlers written for the previous bare Exception.
    raise ValueError('Unknown Observable selected')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ColumbusConfigDefined(ColumbusEnv):
 | 
			
		||||
    def __init__(self, observable={}, env_seed=None, entities=[], fps=30, **kw):
 | 
			
		||||
        super().__init__(
 | 
			
		||||
@ -673,6 +700,12 @@ register(
 | 
			
		||||
    max_episode_steps=30*60*2,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
# Make the compass-observation environment available under its own id.
# The step cap equals 30 * 60 * 2 = 3600 steps (presumably two minutes at
# the class's default of 30 fps - TODO confirm).
register(
    entry_point=ColumbusCompassWithBarriers,
    id='ColumbusCompassWithBarriers-v0',
    max_episode_steps=2 * 60 * 30,
)
 | 
			
		||||
 | 
			
		||||
register(
 | 
			
		||||
    id='ColumbusTrivialRay-v0',
 | 
			
		||||
    entry_point=ColumbusTrivialRay,
 | 
			
		||||
 | 
			
		||||
@ -265,6 +265,71 @@ class StateObservable(Observable):
 | 
			
		||||
                               (x*self.env.width+ofs[1], 0), 1, width=0)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class CompassObservable(Observable):
    """Observation holding one scaled direction pair per tracked entity.

    For every tracked entity the offset from the agent is divided by twice
    its Euclidean length and passed through ``tanh``, giving two values per
    entity that encode only the direction towards it, not the distance.
    """

    def __init__(self, coordsRewards=True, rewardsWhitelist=None, coordsEnemys=False, enemysWhitelist=None, enemysNoBarriers=True):
        """Configure which entities are tracked.

        Args:
            coordsRewards: Track entities that are ``entities.Reward``.
            rewardsWhitelist: Candidates for reward tracking; when falsy,
                all of ``self.env.entities`` are considered.
            coordsEnemys: Track entities that are ``entities.Enemy``.
            enemysWhitelist: Candidates for enemy tracking; when falsy,
                all of ``self.env.entities`` are considered.
            enemysNoBarriers: Skip enemies that are also
                ``entities.Barrier`` instances.
        """
        super().__init__()
        self._entities = None
        self._timeoutEntities = []
        # draw() reads the last observation; start empty so that drawing
        # before the first get_observation() call does not fail.
        self.obs = []
        self.coordRewards = coordsRewards
        self.rewardsWhitelist = rewardsWhitelist
        self.coordsEnemys = coordsEnemys
        self.enemysWhitelist = enemysWhitelist
        self.enemysNoBarriers = enemysNoBarriers

    @property
    def entities(self):
        """Return the tracked entities, computing and caching them once.

        Rewards (if enabled) come first, followed by enemies (if enabled).
        """
        # NOTE(review): the cache is never cleared in this class; if the
        # environment replaces its entity objects on reset, this list would
        # go stale - confirm whether something resets _entities.
        if self._entities is not None:
            return self._entities
        rewardsWhitelist = self.rewardsWhitelist or self.env.entities
        enemysWhitelist = self.enemysWhitelist or self.env.entities
        tracked = []
        if self.coordRewards:
            tracked.extend(
                entity for entity in rewardsWhitelist
                if isinstance(entity, entities.Reward))
        if self.coordsEnemys:
            for entity in enemysWhitelist:
                if not isinstance(entity, entities.Enemy):
                    continue
                if self.enemysNoBarriers and isinstance(entity, entities.Barrier):
                    continue
                tracked.append(entity)
        # Only cache once the list is complete.
        self._entities = tracked
        return tracked

    def get_observation_space(self):
        """Return a float32 ``Box`` in [-1, 1] with two slots per entity.

        Resets the environment first, before the tracked entities are
        counted.
        """
        self.env.reset()
        num = len(self.entities)*2
        return spaces.Box(low=-1, high=1,
                          shape=(num,), dtype=np.float32)

    def get_observation(self):
        """Compute the direction pairs, store them and return them.

        Returns:
            A float32 array ``[x0, y0, x1, y1, ...]`` in the order of
            ``self.entities``. The plain list of the same values is kept in
            ``self.obs`` for ``draw``.
        """
        obs = []
        for entity in self.entities:
            agent_pos = self.env.agent.pos
            dx = entity.pos[0] - agent_pos[0]
            dy = entity.pos[1] - agent_pos[1]
            scale = math.hypot(dx, dy)*2
            if scale == 0:
                # Entity exactly on the agent: there is no direction, and
                # dividing would raise ZeroDivisionError.
                obs.extend((0.0, 0.0))
            else:
                obs.extend((math.tanh(dx/scale), math.tanh(dy/scale)))

        self.obs = obs
        # Match the dtype declared by get_observation_space().
        return np.array(obs, dtype=np.float32)

    def draw(self):
        """Draw the last observation as dots on the top and left edges.

        Two radius-3 dots in the agent's colour mark the centre of each
        axis; every entity adds radius-1 dots in its own colour, offset
        from the centre by its observed x and y values.
        """
        screen = self.env.screen
        half_height = 0 + self.env.height/2
        half_width = 0 + self.env.width/2
        pygame.draw.circle(screen, self.env.agent.col,
                           (0, half_height), 3, width=0)
        pygame.draw.circle(screen, self.env.agent.col,
                           (half_width, 0), 3, width=0)
        for i in range(len(self.obs)//2):
            x, y = self.obs[i*2], self.obs[i*2+1]
            col = self.entities[i].col
            pygame.draw.circle(screen, col,
                               (0, y*self.env.height+half_height), 1, width=0)
            pygame.draw.circle(screen, col,
                               (x*self.env.width+half_width, 0), 1, width=0)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class CompositionalObservable(Observable):
 | 
			
		||||
    def __init__(self, observables):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user