diff --git a/env.py b/env.py index df11fc6..e8aec68 100644 --- a/env.py +++ b/env.py @@ -200,3 +200,8 @@ class ColumbusEnv(gym.Env): def close(self): pygame.display.quit() pygame.quit() + + +class ColumbusTest3_1(ColumbusEnv): + def __init__(self): + super(ColumbusEnv, self).__init__(observables.CnnObservable()) diff --git a/observables.py b/observables.py index 134dd37..4a871e7 100644 --- a/observables.py +++ b/observables.py @@ -8,6 +8,9 @@ class Observable(): self.obs = None pass + def _set_env(self, env): + self.env = env + def get_observation_space(): print("[!] Using dummyObservable. Env won't output anything") return spaces.Box(low=0, high=255, @@ -28,9 +31,6 @@ class CnnObservable(Observable): else: self.scaler = pygame.transform.scale - def _set_env(self, env): - self.env = env - def get_observation_space(self): return spaces.Box(low=0, high=255, shape=(self.out_width, self.out_height), dtype=np.uint8)