reppo/reppo_alg/jaxrl/utils.py
2025-07-21 18:31:20 -04:00

137 lines
4.5 KiB
Python

import distrax
import flax
import jax
import jax.numpy as jnp
def describe(values: jnp.ndarray, axis: tuple | int = 0) -> dict[str, jnp.ndarray]:
"""Compute basic statistics for a batch of values."""
return {
"mean": jnp.mean(values, axis=axis),
"std": jnp.std(values, axis=axis),
"min": jnp.min(values, axis=axis),
"max": jnp.max(values, axis=axis),
}
def merge_dicts(*prefix_dicts: tuple[str, dict], sep: str = "/") -> dict:
    """Flatten several ``(prefix, metrics)`` pairs into a single dictionary.

    Every key of ``metrics`` is renamed to ``"<prefix><sep><key>"``. When the
    prefix is falsy (e.g. ``""`` or ``None``) the key is kept without any
    prefix or separator.

    Args:
        *prefix_dicts: ``(prefix, metrics)`` pairs to merge, in order.
        sep: Separator placed between a non-empty prefix and the key.

    Returns:
        The merged dictionary. If two pairs produce the same key, the value
        from the later pair wins.
    """
    merged = {}
    for prefix, metrics in prefix_dicts:
        # A falsy prefix contributes neither the prefix nor the separator.
        lead = f"{prefix}{sep}" if prefix else ""
        for key, value in metrics.items():
            merged[f"{lead}{key}"] = value
    return merged
def prefix_dict(prefix: str, metrics: dict, sep: str = "/") -> dict:
    """Return a copy of ``metrics`` whose keys are ``"<prefix><sep><key>"``.

    Args:
        prefix: Text placed in front of every key.
        metrics: Dictionary whose keys are renamed; values are kept as-is.
        sep: Separator inserted between the prefix and the original key.

    Returns:
        A new dictionary with the renamed keys.
    """
    renamed = {}
    for key, value in metrics.items():
        renamed[f"{prefix}{sep}{key}"] = value
    return renamed
def postfix_dict(postfix: str, metrics: dict, sep: str = "/") -> dict:
    """Return a copy of ``metrics`` whose keys are ``"<key><sep><postfix>"``.

    Args:
        postfix: Text appended after every key.
        metrics: Dictionary whose keys are renamed; values are kept as-is.
        sep: Separator inserted between the original key and the postfix.

    Returns:
        A new dictionary with the renamed keys.
    """
    tail = f"{sep}{postfix}"
    return {f"{key}{tail}": value for key, value in metrics.items()}
def filter_prefix(prefix: str, metrics: dict, sep: str = "/") -> dict:
    """Keep only the entries whose key starts with ``prefix + sep``.

    Keys are returned unchanged (the prefix is not stripped).

    Args:
        prefix: Prefix a key must start with.
        metrics: Dictionary to filter.
        sep: Separator that must directly follow the prefix in the key.

    Returns:
        A new dictionary containing the matching entries.
    """
    kept = {}
    for key, value in metrics.items():
        if not key.startswith(prefix + sep):
            continue
        kept[key] = value
    return kept
def hl_gauss(inp, num_bins, vmin, vmax, epsilon=0.0):
    """Encode a scalar as a Gaussian-smoothed categorical target over bins.

    The value is clipped to ``[vmin, vmax]``, squeezed and divided by
    ``1 - epsilon``. ``num_bins`` bins of equal width are laid out so that
    their edges run from ``vmin - width / 2`` to ``vmax + width / 2``. Each
    bin's probability is the difference of ``erf`` evaluated at its two
    edges (Gaussian centred on the value, std ``0.75 * width``), normalized
    by the total mass between the outermost edges. The result is finally
    mixed with a uniform distribution using weight ``epsilon``.

    Args:
        inp: Value to encode. After ``squeeze()`` it must broadcast against
            the ``(num_bins + 1,)`` edge vector, e.g. an array of shape
            ``(1,)``. NOTE(review): presumably callers batch this with
            ``jax.vmap`` -- confirm.
        num_bins: Number of bins in the output distribution.
        vmin: Lower clipping bound.
        vmax: Upper clipping bound.
        epsilon: Weight of the uniform component mixed into the target.

    Returns:
        Array of shape ``(*inp.shape[:-1], num_bins)`` holding the target
        probabilities.
    """
    smoothing_ratio = 0.75
    width = (vmax - vmin) / (num_bins - 1)
    std = width * smoothing_ratio

    # num_bins + 1 edges, extending half a bin beyond vmin and vmax.
    edges = jnp.linspace(
        vmin - width / 2, vmax + width / 2, num_bins + 1, dtype=jnp.float32
    )

    clipped = jnp.clip(inp, vmin, max=vmax)
    center = clipped.squeeze() / (1 - epsilon)

    # erf is used in place of the Gaussian CDF; the constant factor cancels
    # in the normalization by total_mass below.
    erf_at_edges = jax.scipy.special.erf((edges - center) / (jnp.sqrt(2) * std))
    total_mass = erf_at_edges[-1] - erf_at_edges[0]
    bin_mass = jnp.diff(erf_at_edges, axis=0)

    probs = (bin_mass / total_mass).reshape(*inp.shape[:-1], num_bins)
    flat_prior = jnp.ones_like(probs) / num_bins
    return (1 - epsilon) * probs + epsilon * flat_prior
@flax.struct.dataclass
class MultiSampleLogProb:
    """Result container returned by ``fast_multi_log_prob``."""

    # tanh-squashed sample drawn with the base scale: tanh(loc + noise).
    policy_action: jax.Array
    # Log-probability of ``policy_action``: Gaussian log-density summed over
    # the last axis, minus the tanh log-det-Jacobian term.
    policy_action_log_prob: jax.Array
    # tanh-squashed sample that reuses the same noise multiplied by
    # ``offset_scale``: tanh(loc + offset_scale * noise).
    action: jax.Array
def fast_multi_log_prob(
    key: jax.Array,
    loc: jax.Array,
    scale: jax.Array,
    offset_scale: jax.Array,
) -> MultiSampleLogProb:
    """Draw one Gaussian noise sample and reuse it for two tanh-squashed actions.

    A single zero-mean noise sample with standard deviation ``scale`` is
    drawn and used twice:

    - ``policy_action = tanh(loc + noise)``, returned together with its
      log-probability (Gaussian log-density summed over the last axis,
      corrected by the log-determinant of the tanh Jacobian);
    - ``action = tanh(loc + offset_scale * noise)``, i.e. the same noise
      rescaled by ``offset_scale``. No log-probability is computed for it.

    Args:
        key: JAX PRNG key used to draw the noise.
        loc: Location of the pre-tanh Gaussian.
        scale: Standard deviation of the sampled noise.
        offset_scale: Multiplier applied to the noise for ``action``.

    Returns:
        A ``MultiSampleLogProb`` holding ``policy_action``,
        ``policy_action_log_prob`` and ``action``.
    """
    # Sample zero-mean noise together with its per-dimension log-density.
    noise_dist = distrax.Normal(jnp.zeros_like(loc), scale)
    noise, noise_log_prob = noise_dist.sample_and_log_prob(seed=key)
    gaussian_log_prob = jnp.sum(noise_log_prob, axis=-1)

    pre_tanh = loc + noise

    # Numerically stable log|d tanh(u)/du| = 2 * (log 2 - u - softplus(-2u)), see
    # https://github.com/tensorflow/probability/commit/ef6bb176e0ebd1cf6e25c6b5cecdd2428c22963f#diff-e120f70e92e6741bca649f04fcd907b7
    log_det_jacobian = jnp.sum(
        2.0 * (jnp.log(2.0) - pre_tanh - jax.nn.softplus(-2.0 * pre_tanh)),
        axis=-1,
    )

    return MultiSampleLogProb(
        policy_action=jnp.tanh(pre_tanh),
        policy_action_log_prob=gaussian_log_prob - log_det_jacobian,
        action=jnp.tanh(loc + offset_scale * noise),
    )
def multi_softmax(x, dim=8, get_logits=False):
    """Apply a softmax independently to consecutive groups of the last axis.

    Args:
        x: Input array. When ``dim`` is not ``None`` the last axis is split
            into groups of ``dim`` entries before normalizing.
        dim: Group size. ``None`` normalizes over the whole last axis.
        get_logits: If true, ``log_softmax`` is applied instead of
            ``softmax``.

    Returns:
        Array with the same shape as ``x``.
    """
    original_shape = x.shape
    grouped = x if dim is None else x.reshape(*original_shape[:-1], -1, dim)
    normalize = jax.nn.log_softmax if get_logits else jax.nn.softmax
    return normalize(grouped, axis=-1).reshape(*original_shape)
def multi_log_softmax(x, dim=8):
    """Apply a log-softmax independently to consecutive groups of the last axis.

    Args:
        x: Input array. When ``dim`` is not ``None`` the last axis is split
            into groups of ``dim`` entries before normalizing.
        dim: Group size. ``None`` normalizes over the whole last axis.

    Returns:
        Array of log-probabilities with the same shape as ``x``.
    """
    if dim is None:
        return jax.nn.log_softmax(x, axis=-1)
    # Remember the caller's shape *before* grouping. Reshaping back to the
    # shape of the already-grouped array would be a no-op and would leave the
    # result as (..., groups, dim) instead of matching the input (and
    # ``multi_softmax``).
    input_shape = x.shape
    grouped = x.reshape(*input_shape[:-1], -1, dim)
    return jax.nn.log_softmax(grouped, axis=-1).reshape(input_shape)
def simplical_softmax_cross_entropy(pred, target, dim=8):
    """Compute the cross-entropy loss for a simplicial (grouped) softmax.

    The last axis of ``pred`` and ``target`` is split into groups of ``dim``
    entries. For every group the cross-entropy between ``target`` and
    ``log_softmax(pred)`` is computed; the per-group values are averaged over
    all remaining axes and the mean is then divided by the number of groups
    (``pred.shape[-1] / dim``).

    Args:
        pred: Unnormalized predictions (logits).
        target: Target probabilities, with a last axis laid out like
            ``pred``'s.
        dim: Group size. ``None`` treats the whole last axis as one group.

    Returns:
        Scalar loss.
    """
    last_axis_size = pred.shape[-1]
    if dim is None:
        # One simplex spanning the whole last axis: a single group, so the
        # divisor is 1 (previously ``shape / None`` raised a TypeError).
        num_groups = 1.0
    else:
        pred = pred.reshape(*pred.shape[:-1], -1, dim)
        target = target.reshape(*target.shape[:-1], -1, dim)
        num_groups = last_axis_size / dim
    per_group = jnp.sum(-target * jax.nn.log_softmax(pred, axis=-1), axis=-1)
    return per_group.mean() / num_groups