From 83907422a30fd1cc2c0a579130cc1ebd5cdbf327 Mon Sep 17 00:00:00 2001 From: Younggyo Seo Date: Mon, 7 Jul 2025 10:04:46 -0700 Subject: [PATCH] Improved AMP/torch.compile compatibility of SimbaV2 (#21) --- fast_td3/environments/mujoco_playground_env.py | 4 ---- fast_td3/fast_td3_simbav2.py | 10 ++++++---- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/fast_td3/environments/mujoco_playground_env.py b/fast_td3/environments/mujoco_playground_env.py index 4aa785b..2200389 100644 --- a/fast_td3/environments/mujoco_playground_env.py +++ b/fast_td3/environments/mujoco_playground_env.py @@ -4,10 +4,6 @@ from mujoco_playground import wrapper_torch import jax import mujoco -jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache") -jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1) -jax.config.update("jax_persistent_cache_min_compile_time_secs", 0) - class PlaygroundEvalEnvWrapper: def __init__(self, eval_env, max_episode_steps, env_name, num_eval_envs, seed): diff --git a/fast_td3/fast_td3_simbav2.py b/fast_td3/fast_td3_simbav2.py index 80ae4ff..0ca3820 100644 --- a/fast_td3/fast_td3_simbav2.py +++ b/fast_td3/fast_td3_simbav2.py @@ -28,7 +28,7 @@ class Scaler(nn.Module): self.forward_scaler = init / scale def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.scaler * self.forward_scaler * x + return self.scaler.to(x.dtype) * self.forward_scaler * x class HyperDense(nn.Module): @@ -97,7 +97,9 @@ class HyperEmbedder(nn.Module): self.c_shift = c_shift def forward(self, x: torch.Tensor) -> torch.Tensor: - new_axis = torch.full((*x.shape[:-1], 1), self.c_shift, device=x.device) + new_axis = torch.full( + (*x.shape[:-1], 1), self.c_shift, device=x.device, dtype=x.dtype + ) x = torch.cat([x, new_axis], dim=-1) x = l2normalize(x, axis=-1) x = self.w(x) @@ -170,7 +172,7 @@ class HyperTanhPolicy(nn.Module): # Mean path mean = self.mean_w1(x) mean = self.mean_scaler(mean) - mean = self.mean_w2(mean) + self.mean_bias + mean = self.mean_w2(mean) + self.mean_bias.to(mean.dtype) mean = torch.tanh(mean) return mean @@ -197,7 +199,7 @@ class HyperCategoricalValue(nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: logits = self.w1(x) logits = self.scaler(logits) - logits = self.w2(logits) + self.bias + logits = self.w2(logits) + self.bias.to(logits.dtype) return logits