Improved AMP/torch.compile compatibility of SimbaV2 (#21)

This commit is contained in:
Younggyo Seo 2025-07-07 10:04:46 -07:00 committed by GitHub
parent c354ead107
commit 83907422a3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 6 additions and 8 deletions

View File

@ -4,10 +4,6 @@ from mujoco_playground import wrapper_torch
import jax
import mujoco
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
class PlaygroundEvalEnvWrapper:
def __init__(self, eval_env, max_episode_steps, env_name, num_eval_envs, seed):

View File

@ -28,7 +28,7 @@ class Scaler(nn.Module):
self.forward_scaler = init / scale
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.scaler * self.forward_scaler * x
return self.scaler.to(x.dtype) * self.forward_scaler * x
class HyperDense(nn.Module):
@ -97,7 +97,9 @@ class HyperEmbedder(nn.Module):
self.c_shift = c_shift
def forward(self, x: torch.Tensor) -> torch.Tensor:
new_axis = torch.full((*x.shape[:-1], 1), self.c_shift, device=x.device)
new_axis = torch.full(
(*x.shape[:-1], 1), self.c_shift, device=x.device, dtype=x.dtype
)
x = torch.cat([x, new_axis], dim=-1)
x = l2normalize(x, axis=-1)
x = self.w(x)
@ -170,7 +172,7 @@ class HyperTanhPolicy(nn.Module):
# Mean path
mean = self.mean_w1(x)
mean = self.mean_scaler(mean)
mean = self.mean_w2(mean) + self.mean_bias
mean = self.mean_w2(mean) + self.mean_bias.to(mean.dtype)
mean = torch.tanh(mean)
return mean
@ -197,7 +199,7 @@ class HyperCategoricalValue(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
logits = self.w1(x)
logits = self.scaler(logits)
logits = self.w2(logits) + self.bias
logits = self.w2(logits) + self.bias.to(logits.dtype)
return logits