# Source code for rsl_rl.models.mlp_model

# Copyright (c) 2021-2026, ETH Zurich and NVIDIA CORPORATION
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause


from __future__ import annotations

import copy
import torch
import torch.nn as nn
from tensordict import TensorDict

from rsl_rl.modules import MLP, EmpiricalNormalization, HiddenState
from rsl_rl.modules.distribution import Distribution
from rsl_rl.utils import resolve_callable, unpad_trajectories


class MLPModel(nn.Module):
    """MLP-based neural model.

    This model uses a simple multi-layer perceptron (MLP) to process 1D observation groups. Observations can be
    normalized before being passed to the MLP. The output of the model can be either deterministic or stochastic, in
    which case a distribution module is used to sample the outputs.
    """

    is_recurrent: bool = False
    """Whether the model contains a recurrent module."""

    def __init__(
        self,
        obs: TensorDict,
        obs_groups: dict[str, list[str]],
        obs_set: str,
        output_dim: int,
        hidden_dims: tuple[int, ...] | list[int] = (256, 256, 256),
        activation: str = "elu",
        obs_normalization: bool = False,
        distribution_cfg: dict | None = None,
    ) -> None:
        """Initialize the MLP-based model.

        Args:
            obs: Observation Dictionary.
            obs_groups: Dictionary mapping observation sets to lists of observation groups.
            obs_set: Observation set to use for this model (e.g., "actor" or "critic").
            output_dim: Dimension of the output.
            hidden_dims: Hidden dimensions of the MLP.
            activation: Activation function of the MLP.
            obs_normalization: Whether to normalize the observations before feeding them to the MLP.
            distribution_cfg: Configuration dictionary for the output distribution. If provided, the model outputs
                stochastic values sampled from the distribution. It must contain a ``"class_name"`` entry that is
                resolved to the distribution class; all remaining entries are forwarded to its constructor. The
                caller's dictionary is not modified.

        Raises:
            ValueError: If any selected observation group is not 1D per sample (i.e. its tensor is not rank 2).
        """
        super().__init__()

        # Resolve observation groups and dimensions
        self.obs_groups, self.obs_dim = self._get_obs_dim(obs, obs_groups, obs_set)

        # Observation normalization
        self.obs_normalization = obs_normalization
        if obs_normalization:
            self.obs_normalizer = EmpiricalNormalization(self.obs_dim)
        else:
            self.obs_normalizer = nn.Identity()

        # Distribution
        if distribution_cfg is not None:
            # Pop from a shallow copy: popping from the caller's dict would strip "class_name" and make any reuse of
            # the same config (rebuilding the model, a second model, logging the config) fail or be incomplete.
            dist_kwargs = dict(distribution_cfg)
            dist_class: type[Distribution] = resolve_callable(dist_kwargs.pop("class_name"))  # type: ignore
            self.distribution: Distribution | None = dist_class(output_dim, **dist_kwargs)
            mlp_output_dim = self.distribution.input_dim
        else:
            self.distribution = None
            mlp_output_dim = output_dim

        # MLP
        self.mlp = MLP(self._get_latent_dim(), mlp_output_dim, hidden_dims, activation)

        # Initialize distribution-specific MLP weights
        if self.distribution is not None:
            self.distribution.init_mlp_weights(self.mlp)

    def forward(
        self,
        obs: TensorDict,
        masks: torch.Tensor | None = None,
        hidden_state: HiddenState = None,
        stochastic_output: bool = False,
    ) -> torch.Tensor:
        """Forward pass of the MLP model.

        Args:
            obs: Observation dictionary containing (at least) the groups in ``self.obs_groups``.
            masks: Optional trajectory masks. When given to this non-recurrent model, the observations are unpadded
                with ``unpad_trajectories`` before use.
            hidden_state: Forwarded to :meth:`get_latent`; unused by this implementation.
            stochastic_output: Whether to sample from the output distribution.

        Returns:
            The raw MLP output if the model has no distribution. Otherwise, a sample from the distribution when
            ``stochastic_output`` is ``True``, or the distribution's deterministic output when it is ``False``.

        .. note::
            The ``stochastic_output`` flag only has an effect if the model has a distribution (i.e.,
            ``distribution_cfg`` was provided) and defaults to ``False``, meaning that even stochastic models will
            return deterministic outputs by default.
        """
        # If observations are padded for recurrent training but the model is non-recurrent, unpad the observations
        obs = unpad_trajectories(obs, masks) if masks is not None and not self.is_recurrent else obs
        # Get MLP input latent
        latent = self.get_latent(obs, masks, hidden_state)
        # MLP forward pass
        mlp_output = self.mlp(latent)
        # If stochastic output is requested, update the distribution and sample from it, otherwise return MLP output
        if self.distribution is not None:
            if stochastic_output:
                self.distribution.update(mlp_output)
                return self.distribution.sample()
            return self.distribution.deterministic_output(mlp_output)
        return mlp_output

    def get_latent(
        self, obs: TensorDict, masks: torch.Tensor | None = None, hidden_state: HiddenState = None
    ) -> torch.Tensor:
        """Build the model latent by concatenating and normalizing selected observation groups.

        ``masks`` and ``hidden_state`` are unused by this implementation.
        """
        return self.obs_normalizer(self._concat_obs_groups(obs))

    def reset(self, dones: torch.Tensor | None = None, hidden_state: HiddenState = None) -> None:
        """Reset the internal state for recurrent models (no-op)."""
        pass

    def get_hidden_state(self) -> HiddenState:
        """Return the recurrent hidden state (``None`` for MLP)."""
        return None

    def detach_hidden_state(self, dones: torch.Tensor | None = None) -> None:
        """Detach the recurrent hidden state for truncated backpropagation (no-op)."""
        pass

    # The output_* accessors and the log-prob/KL helpers below delegate to ``self.distribution``; they require the
    # model to have been built with a ``distribution_cfg`` (otherwise ``self.distribution`` is ``None``).

    @property
    def output_mean(self) -> torch.Tensor:
        """Return the mean of the current output distribution."""
        return self.distribution.mean

    @property
    def output_std(self) -> torch.Tensor:
        """Return the standard deviation of the current output distribution."""
        return self.distribution.std

    @property
    def output_entropy(self) -> torch.Tensor:
        """Return the entropy of the current output distribution."""
        return self.distribution.entropy

    @property
    def output_distribution_params(self) -> tuple[torch.Tensor, ...]:
        """Return raw parameters of the current output distribution."""
        return self.distribution.params

    def get_output_log_prob(self, outputs: torch.Tensor) -> torch.Tensor:
        """Compute log-probabilities of outputs under the current distribution."""
        return self.distribution.log_prob(outputs)

    def get_kl_divergence(
        self, old_params: tuple[torch.Tensor, ...], new_params: tuple[torch.Tensor, ...]
    ) -> torch.Tensor:
        """Compute KL divergence between two parameterizations of the distribution."""
        return self.distribution.kl_divergence(old_params, new_params)

    def as_jit(self) -> nn.Module:
        """Return a version of the model compatible with Torch JIT export."""
        return _TorchMLPModel(self)

    def as_onnx(self, verbose: bool) -> nn.Module:
        """Return a version of the model compatible with ONNX export."""
        return _OnnxMLPModel(self, verbose)

    def update_normalization(self, obs: TensorDict) -> None:
        """Update observation-normalization statistics from a batch of observations (no-op if disabled)."""
        if self.obs_normalization:
            self.obs_normalizer.update(self._concat_obs_groups(obs))  # type: ignore

    def _concat_obs_groups(self, obs: TensorDict) -> torch.Tensor:
        """Select the active observation groups and concatenate them along the last dimension."""
        return torch.cat([obs[obs_group] for obs_group in self.obs_groups], dim=-1)

    def _get_obs_dim(self, obs: TensorDict, obs_groups: dict[str, list[str]], obs_set: str) -> tuple[list[str], int]:
        """Select active observation groups and compute observation dimension.

        Returns:
            The list of groups for ``obs_set`` and the sum of their last-dimension sizes.

        Raises:
            ValueError: If a selected group's tensor is not rank 2.
        """
        active_obs_groups = obs_groups[obs_set]
        obs_dim = 0
        for obs_group in active_obs_groups:
            group_shape = obs[obs_group].shape
            if len(group_shape) != 2:
                raise ValueError(
                    f"The MLP model only supports 1D observations, got shape {group_shape} for '{obs_group}'."
                )
            obs_dim += group_shape[-1]
        return active_obs_groups, obs_dim

    def _get_latent_dim(self) -> int:
        """Return the latent dimensionality consumed by the MLP head."""
        return self.obs_dim
class _TorchMLPModel(nn.Module):
    """Inference-only snapshot of an :class:`MLPModel` intended for TorchScript export.

    The observation normalizer and the MLP are deep copies of the source model's modules; the output head is the
    module returned by the source distribution's ``as_deterministic_output_module()`` (or an identity if the source
    model has no distribution).
    """

    def __init__(self, model: MLPModel) -> None:
        """Capture the normalizer, MLP and deterministic output head of ``model``."""
        super().__init__()
        self.obs_normalizer = copy.deepcopy(model.obs_normalizer)
        self.mlp = copy.deepcopy(model.mlp)
        dist = model.distribution
        # Models without a distribution export the raw MLP output.
        self.deterministic_output = nn.Identity() if dist is None else dist.as_deterministic_output_module()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize the pre-concatenated observations ``x`` and map them to deterministic outputs."""
        return self.deterministic_output(self.mlp(self.obs_normalizer(x)))

    @torch.jit.export
    def reset(self) -> None:
        """Reset hook shared with recurrent exports; MLP exports carry no state, so this does nothing."""
        pass


class _OnnxMLPModel(nn.Module):
    """Inference-only snapshot of an :class:`MLPModel` intended for ONNX export."""

    is_recurrent: bool = False

    def __init__(self, model: MLPModel, verbose: bool) -> None:
        """Capture the modules of ``model`` and remember its flat observation size for dummy inputs."""
        super().__init__()
        self.verbose = verbose
        self.obs_normalizer = copy.deepcopy(model.obs_normalizer)
        self.mlp = copy.deepcopy(model.mlp)
        dist = model.distribution
        # Models without a distribution export the raw MLP output.
        self.deterministic_output = nn.Identity() if dist is None else dist.as_deterministic_output_module()
        self.input_size = model.obs_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize the pre-concatenated observations ``x`` and map them to deterministic outputs."""
        return self.deterministic_output(self.mlp(self.obs_normalizer(x)))

    def get_dummy_inputs(self) -> tuple[torch.Tensor]:
        """Return a one-element tuple holding an all-zero batch of shape ``(1, input_size)`` for tracing."""
        dummy_obs = torch.zeros(1, self.input_size)
        return (dummy_obs,)

    @property
    def input_names(self) -> list[str]:
        """Names assigned to the exported graph's inputs."""
        return ["obs"]

    @property
    def output_names(self) -> list[str]:
        """Names assigned to the exported graph's outputs."""
        return ["actions"]