import distrax
|
|
import flax
|
|
import jax
|
|
import jax.numpy as jnp
|
|
|
|
|
|
def describe(values: jnp.ndarray, axis: tuple | int = 0) -> dict[str, jnp.ndarray]:
|
|
"""Compute basic statistics for a batch of values."""
|
|
return {
|
|
"mean": jnp.mean(values, axis=axis),
|
|
"std": jnp.std(values, axis=axis),
|
|
"min": jnp.min(values, axis=axis),
|
|
"max": jnp.max(values, axis=axis),
|
|
}
|
|
|
|
|
|
def merge_dicts(*prefix_dicts: tuple[str, dict], sep: str = "/") -> dict:
    """Merge metric dictionaries, namespacing each dict's keys by its prefix.

    An empty prefix leaves the keys untouched (no separator is inserted).
    Later dictionaries overwrite earlier ones on key collisions.
    """
    merged: dict = {}
    for prefix, metrics in prefix_dicts:
        head = f"{prefix}{sep}" if prefix else ""
        for key, value in metrics.items():
            merged[head + key] = value
    return merged
|
|
|
|
|
|
def prefix_dict(prefix: str, metrics: dict, sep: str = "/") -> dict:
    """Return a copy of ``metrics`` with ``prefix`` + ``sep`` prepended to every key."""
    renamed = {}
    for key, value in metrics.items():
        renamed[sep.join((prefix, key))] = value
    return renamed
|
|
|
|
|
|
def postfix_dict(postfix: str, metrics: dict, sep: str = "/") -> dict:
    """Return a copy of ``metrics`` with ``sep`` + ``postfix`` appended to every key."""
    renamed = {}
    for key, value in metrics.items():
        renamed[sep.join((key, postfix))] = value
    return renamed
|
|
|
|
|
|
def filter_prefix(prefix: str, metrics: dict, sep: str = "/") -> dict:
    """Keep only the entries whose key starts with ``prefix`` + ``sep``."""
    needle = prefix + sep
    kept = {}
    for key, value in metrics.items():
        if key.startswith(needle):
            kept[key] = value
    return kept
|
|
|
|
|
|
def hl_gauss(inp, num_bins, vmin, vmax, epsilon=0.0):
    """Converts a batch of scalars to soft two-hot encoded targets for discrete regression.

    HL-Gauss encoding: each clipped scalar is smeared over ``num_bins`` fixed
    bins using a Gaussian whose width is proportional to the bin width, then
    mixed with a uniform distribution weighted by ``epsilon`` (label smoothing).

    Args:
        inp: Scalar targets; the trailing axis is squeezed away, so the
            expected shape is ``(..., 1)``.
            # NOTE(review): the broadcast ``support - x`` and the edge indexing
            # below are only valid when x is effectively scalar (e.g. applied
            # per element under vmap) — confirm against callers.
        num_bins: Number of output bins.
        vmin: Lower clipping bound / support edge.
        vmax: Upper clipping bound / support edge.
        epsilon: Uniform-mixing coefficient.

    Returns:
        Target probabilities of shape ``(*inp.shape[:-1], num_bins)``.
    """
    # NOTE(review): dividing by (1 - epsilon) rescales x beyond [vmin, vmax];
    # presumably intentional to compensate for the uniform mixture — confirm.
    x = jnp.clip(inp, vmin, max=vmax).squeeze() / (1 - epsilon)
    bin_width = (vmax - vmin) / (num_bins - 1)
    # Fixed ratio of the smoothing Gaussian's sigma to the bin width.
    sigma_to_final_sigma_ratio = 0.75
    # num_bins + 1 bin edges spanning [vmin - w/2, vmax + w/2].
    support = jnp.linspace(
        vmin - bin_width / 2, vmax + bin_width / 2, num_bins + 1, dtype=jnp.float32
    )
    sigma = bin_width * sigma_to_final_sigma_ratio
    # Gaussian CDF of N(x, sigma) at every bin edge (via erf, up to an affine
    # rescaling that cancels when differences are normalized below).
    cdf_evals = jax.scipy.special.erf((support - x) / (jnp.sqrt(2) * sigma))
    # Total mass inside the support; renormalizes mass clipped at the edges.
    z = cdf_evals[-1] - cdf_evals[0]
    # Probability mass falling inside each individual bin.
    target_probs = cdf_evals[1:] - cdf_evals[:-1]
    target_probs = (target_probs / z).reshape(*inp.shape[:-1], num_bins)

    uniform = jnp.ones_like(target_probs) / num_bins

    # Label smoothing: blend the Gaussian histogram with the uniform distribution.
    return (1 - epsilon) * target_probs + epsilon * uniform
|
|
|
|
|
|
@flax.struct.dataclass
class MultiSampleLogProb:
    """Bundle of tanh-squashed Gaussian samples returned by ``fast_multi_log_prob``."""

    # tanh-squashed sample drawn with the base scale.
    policy_action: jax.Array
    # Log-probability of ``policy_action`` under the tanh-transformed Gaussian.
    policy_action_log_prob: jax.Array
    # tanh-squashed sample drawn with the offset-scaled noise (no log-prob).
    action: jax.Array
|
|
|
|
|
|
def fast_multi_log_prob(
    key: jax.Array,
    loc: jax.Array,
    scale: jax.Array,
    offset_scale: jax.Array,
) -> MultiSampleLogProb:
    """Computes 3 samples from a tanh squashed function

    - transformed loc and log_prob
    - sample with base scale
    - sample with scaled scale

    Args:
        key: JAX PRNG key.
        loc: Location of the distribution.
        scale: Scale parameter of the distribution.
        offset_scale: Offset scale for the distribution.
    """
    # Draw zero-mean Gaussian noise together with its per-dimension log-prob,
    # then reduce the log-prob over the action dimension.
    noise, log_prob = distrax.Normal(
        jnp.zeros_like(loc), scale
    ).sample_and_log_prob(seed=key)
    log_prob = jnp.sum(log_prob, axis=-1)

    # Sample with the base scale, then squash through tanh.
    pre_tanh = loc + noise
    squashed = jnp.tanh(pre_tanh)

    # Subtract the tanh log-det-jacobian, written in the numerically stable
    # softplus form — see tensorflow/probability commit ef6bb17
    # (https://github.com/tensorflow/probability/commit/ef6bb176e0ebd1cf6e25c6b5cecdd2428c22963f#diff-e120f70e92e6741bca649f04fcd907b7)
    log_det = 2.0 * (jnp.log(2.0) - pre_tanh - jax.nn.softplus(-2.0 * pre_tanh))
    log_prob = log_prob - jnp.sum(log_det, axis=-1)

    return MultiSampleLogProb(
        policy_action=squashed,
        policy_action_log_prob=log_prob,
        action=jnp.tanh(loc + offset_scale * noise),
    )
|
|
|
|
|
|
def multi_softmax(x, dim=8, get_logits=False):
    """Apply softmax (or log-softmax when ``get_logits``) over groups of ``dim`` logits.

    The last axis is split into contiguous chunks of size ``dim``, each chunk
    is normalized independently, and the original shape is restored. With
    ``dim=None`` a single softmax is taken over the whole last axis.
    """
    original_shape = x.shape
    if dim is not None:
        # Split the last axis into independent groups of `dim` logits.
        x = x.reshape(*x.shape[:-1], -1, dim)
    activation = jax.nn.log_softmax if get_logits else jax.nn.softmax
    x = activation(x, axis=-1)
    # Restore the caller's original shape.
    return x.reshape(*original_shape)
|
|
|
|
|
|
def multi_log_softmax(x, dim=8):
    """Apply log-softmax over contiguous groups of ``dim`` logits along the last axis.

    Mirrors ``multi_softmax(..., get_logits=True)``: the last axis is split
    into chunks of size ``dim``, each chunk is log-normalized independently,
    and the input shape is restored. With ``dim=None`` a single log-softmax is
    taken over the whole last axis.
    """
    inp_shape = x.shape
    if dim is not None:
        x = x.reshape(*x.shape[:-1], -1, dim)
    # BUG FIX: the previous version returned `.reshape(x.shape)` where `x` was
    # already the grouped (-1, dim) array, leaking the split shape to callers.
    # Reshape back to the caller's original shape, as `multi_softmax` does.
    return jax.nn.log_softmax(x, axis=-1).reshape(*inp_shape)
|
|
|
|
|
|
def simplical_softmax_cross_entropy(pred, target, dim=8):
    """Computes the cross-entropy loss for simplical softmax.

    The last axis of ``pred`` and ``target`` is split into groups of ``dim``
    logits; cross-entropy is computed per group and averaged, then divided by
    the number of groups so the loss magnitude is independent of group count.
    With ``dim=None`` a single softmax spans the whole last axis.
    """
    width = pred.shape[-1]
    if dim is not None:
        pred = pred.reshape(*pred.shape[:-1], -1, dim)
        target = target.reshape(*target.shape[:-1], -1, dim)
    # BUG FIX: when dim is None this previously computed `width / None` and
    # raised a TypeError; a single ungrouped softmax means one group.
    num_groups = 1.0 if dim is None else width / dim
    return jnp.sum(-target * jax.nn.log_softmax(pred, axis=-1), axis=-1).mean() / num_groups
|