Improved AMP/torch.compile compatibility of SimbaV2 (#21)
This commit is contained in:
parent
c354ead107
commit
83907422a3
@ -4,10 +4,6 @@ from mujoco_playground import wrapper_torch
|
||||
import jax
|
||||
import mujoco
|
||||
|
||||
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
|
||||
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
|
||||
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
|
||||
|
||||
|
||||
class PlaygroundEvalEnvWrapper:
|
||||
def __init__(self, eval_env, max_episode_steps, env_name, num_eval_envs, seed):
|
||||
|
@ -28,7 +28,7 @@ class Scaler(nn.Module):
|
||||
self.forward_scaler = init / scale
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.scaler * self.forward_scaler * x
|
||||
return self.scaler.to(x.dtype) * self.forward_scaler * x
|
||||
|
||||
|
||||
class HyperDense(nn.Module):
|
||||
@ -97,7 +97,9 @@ class HyperEmbedder(nn.Module):
|
||||
self.c_shift = c_shift
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
new_axis = torch.full((*x.shape[:-1], 1), self.c_shift, device=x.device)
|
||||
new_axis = torch.full(
|
||||
(*x.shape[:-1], 1), self.c_shift, device=x.device, dtype=x.dtype
|
||||
)
|
||||
x = torch.cat([x, new_axis], dim=-1)
|
||||
x = l2normalize(x, axis=-1)
|
||||
x = self.w(x)
|
||||
@ -170,7 +172,7 @@ class HyperTanhPolicy(nn.Module):
|
||||
# Mean path
|
||||
mean = self.mean_w1(x)
|
||||
mean = self.mean_scaler(mean)
|
||||
mean = self.mean_w2(mean) + self.mean_bias
|
||||
mean = self.mean_w2(mean) + self.mean_bias.to(mean.dtype)
|
||||
mean = torch.tanh(mean)
|
||||
return mean
|
||||
|
||||
@ -197,7 +199,7 @@ class HyperCategoricalValue(nn.Module):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
logits = self.w1(x)
|
||||
logits = self.scaler(logits)
|
||||
logits = self.w2(logits) + self.bias
|
||||
logits = self.w2(logits) + self.bias.to(logits.dtype)
|
||||
return logits
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user