Always return np.array not th.Tensor
This commit is contained in:
parent
a41f93beed
commit
0b71d2fe0c
@ -337,8 +337,6 @@ class CompositionalObservable(Observable):
|
|||||||
|
|
||||||
def get_observation_space(self):
|
def get_observation_space(self):
|
||||||
num = 0
|
num = 0
|
||||||
low = 99999
|
|
||||||
high = -99999
|
|
||||||
for i, obs in enumerate(self.observables):
|
for i, obs in enumerate(self.observables):
|
||||||
space = obs.get_observation_space()
|
space = obs.get_observation_space()
|
||||||
num += math.prod(space.shape)
|
num += math.prod(space.shape)
|
||||||
@ -352,9 +350,11 @@ class CompositionalObservable(Observable):
|
|||||||
shape=(num,), dtype=np.float32)
|
shape=(num,), dtype=np.float32)
|
||||||
|
|
||||||
def get_observation(self):
|
def get_observation(self):
|
||||||
o = [th.reshape(th.Tensor(obs.get_observation()), (-1,))
|
o = [obs.get_observation().reshape((-1))
|
||||||
for obs in self.observables]
|
for obs in self.observables]
|
||||||
o = th.hstack(o)
|
o = np.hstack(o)
|
||||||
|
import pdb
|
||||||
|
pdb.set_trace()
|
||||||
return o
|
return o
|
||||||
|
|
||||||
def draw(self):
|
def draw(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user