diff --git a/metastable_baselines/misc/rollout_buffer.py b/metastable_baselines/misc/rollout_buffer.py
index b96e628..aae1d74 100644
--- a/metastable_baselines/misc/rollout_buffer.py
+++ b/metastable_baselines/misc/rollout_buffer.py
@@ -280,4 +280,5 @@ class GaussianRolloutCollectorAuxclass():
 
     def get_past_trajectories(self) -> th.Tensor:
         # TODO: Respect Episode Boundaries
-        return th.Tensor(self.rollout_buffer.actions)
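+        # Actions are stored as (n_steps, n_envs, action_dim); swap to (n_envs, n_steps, action_dim)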
+        return th.Tensor(self.rollout_buffer.actions[:self.rollout_buffer.pos]).swapaxes(0, 1)
diff --git a/metastable_baselines/ppo/policies.py b/metastable_baselines/ppo/policies.py
index 94d13d0..bba4f36 100644
--- a/metastable_baselines/ppo/policies.py
+++ b/metastable_baselines/ppo/policies.py
@@ -359,7 +359,7 @@ class ActorCriticPolicy(BasePolicy):
         else:
             raise ValueError("Invalid action distribution")
 
-    def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
+    def _predict(self, observation: th.Tensor, deterministic: bool = False, trajectory: Optional[th.Tensor] = None) -> th.Tensor:
         """
         Get the action according to the policy for a given observation.
 
@@ -367,7 +367,11 @@
         :param deterministic: Whether to use stochastic or deterministic actions
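+        :param trajectory: Past trajectory to condition on (only used when use_pca is enabled)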
         :return: Taken action according to the policy
         """
-        return self.get_distribution(observation).get_actions(deterministic=deterministic)
+        if self.use_pca:
+            return self.get_distribution(observation).get_actions(deterministic=deterministic, trajectory=trajectory)
+        else:
+            return self.get_distribution(observation).get_actions(deterministic=deterministic)
 
     def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
         """
@@ -409,6 +413,59 @@
         latent_vf = self.mlp_extractor.forward_critic(features)
         return self.value_net(latent_vf)
 
+    def predict(
+        self,
+        observation: Union[np.ndarray, Dict[str, np.ndarray]],
+        state: Optional[Tuple[np.ndarray, ...]] = None,
+        episode_start: Optional[np.ndarray] = None,
+        deterministic: bool = False,
+        trajectory: Optional[th.Tensor] = None,
+    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
+        """
+        Get the policy action from an observation (and optional hidden state).
+        Includes sugar-coating to handle different observations (e.g. normalizing images).
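+        Overridden to accept a past trajectory, which is forwarded to _predict when use_pca is enabled.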
+
+        :param observation: the input observation
+        :param state: The last hidden states (can be None, used in recurrent policies)
+        :param episode_start: The last masks (can be None, used in recurrent policies)
+            this corresponds to the beginning of episodes,
+            where the hidden states of the RNN must be reset.
+        :param deterministic: Whether or not to return deterministic actions.
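+        :param trajectory: Past trajectory to condition the action distribution on (only used when use_pca is enabled)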
+        :return: the model's action and the next hidden state
+            (used in recurrent policies)
+        """
+        # TODO (GH/1): add support for RNN policies
+        # if state is None:
+        #     state = self.initial_state
+        # if episode_start is None:
+        #     episode_start = [False for _ in range(self.n_envs)]
+        # Switch to eval mode (this affects batch norm / dropout)
+        self.set_training_mode(False)
+
+        observation, vectorized_env = self.obs_to_tensor(observation)
+
+        with th.no_grad():
+            actions = self._predict(observation, deterministic=deterministic, trajectory=trajectory)
+        # Convert to numpy, and reshape to the original action shape
+        actions = actions.cpu().numpy().reshape((-1,) + self.action_space.shape)
+
+        if isinstance(self.action_space, gym.spaces.Box):
+            if self.squash_output:
+                # Rescale to proper domain when using squashing
+                actions = self.unscale_action(actions)
+            else:
+                # Actions could be on arbitrary scale, so clip the actions to avoid
+                # out of bound error (e.g. if sampling from a Gaussian distribution)
+                actions = np.clip(actions, self.action_space.low, self.action_space.high)
+
+        # Remove batch dimension if needed
+        if not vectorized_env:
+            actions = actions.squeeze(axis=0)
+
+        return actions, state
+
 
 class ActorCriticCnnPolicy(ActorCriticPolicy):
     """
diff --git a/metastable_baselines/ppo/ppo.py b/metastable_baselines/ppo/ppo.py
index bfc2844..abd34cf 100644
--- a/metastable_baselines/ppo/ppo.py
+++ b/metastable_baselines/ppo/ppo.py
@@ -1,5 +1,5 @@
 import warnings
-from typing import Any, Dict, Optional, Type, Union, NamedTuple
+from typing import Any, Dict, Optional, Type, Union, NamedTuple, Tuple
 
 import numpy as np
 import torch as th
@@ -422,6 +422,34 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
         if self.clip_range_vf is not None:
             self.logger.record("train/clip_range_vf", clip_range_vf)
 
+    def predict(
+        self,
+        observation: Union[np.ndarray, Dict[str, np.ndarray]],
+        state: Optional[Tuple[np.ndarray, ...]] = None,
+        episode_start: Optional[np.ndarray] = None,
+        deterministic: bool = False,
+        trajectory: Optional[th.Tensor] = None,
+    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
+        """
+        Get the policy action from an observation (and optional hidden state).
+        Includes sugar-coating to handle different observations (e.g. normalizing images).
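+        Overridden to forward the past trajectory to the policy's predict when use_pca is enabled.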
+
+        :param observation: the input observation
+        :param state: The last hidden states (can be None, used in recurrent policies)
+        :param episode_start: The last masks (can be None, used in recurrent policies)
+            this corresponds to the beginning of episodes,
+            where the hidden states of the RNN must be reset.
+        :param deterministic: Whether or not to return deterministic actions.
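+        :param trajectory: Past trajectory to condition the action distribution on (only used when use_pca is enabled)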
+        :return: the model's action and the next hidden state
+            (used in recurrent policies)
+        """
+        if self.policy.use_pca:
+            return self.policy.predict(observation, state, episode_start, deterministic, trajectory=trajectory)
+        else:
+            return self.policy.predict(observation, state, episode_start, deterministic)
+
     def learn(
         self,
         total_timesteps: int,