Extensions

Random Network Distillation

class rsl_rl.extensions.rnd.RandomNetworkDistillation

Implementation of Random Network Distillation (RND).

References

  • Schwarke et al. “Curiosity-Driven Learning of Joint Locomotion and Manipulation Tasks.” CoRL (2023).

  • Burda et al. “Exploration by Random Network Distillation.” arXiv preprint arXiv:1810.12894 (2018).

__init__(num_states, obs_groups, num_outputs, predictor_hidden_dims, target_hidden_dims, activation='elu', state_normalization=False, reward_normalization=False, weight=0.0, weight_schedule=None, learning_rate=0.001, device='cpu')

Initialize the RND module.

  • If state_normalization is True, then the input state is normalized using an Empirical Normalization layer.

  • If reward_normalization is True, then the intrinsic reward is normalized using an Empirical Discounted Variation Normalization layer.

  • If a hidden dimension in the predictor or target network configuration is -1, the number of states is used for that dimension.
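The -1 convention for hidden dimensions can be illustrated with a minimal sketch. It uses a hypothetical helper, resolve_hidden_dims, which is not part of rsl_rl, and it assumes each -1 entry is replaced independently by num_states.

```python
def resolve_hidden_dims(hidden_dims, num_states):
    """Replace every -1 entry with the number of input states (hypothetical helper)."""
    return [num_states if dim == -1 else dim for dim in hidden_dims]


# Example: a 48-dimensional RND state with predictor_hidden_dims=[-1, 128].
print(resolve_hidden_dims([-1, 128], num_states=48))  # [48, 128]
```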

Parameters:
  • num_states (int) – Number of states/inputs to the predictor and target networks.

  • obs_groups (dict) – Dictionary of observation groups.

  • num_outputs (int) – Number of outputs (embedding size) of the predictor and target networks.

  • predictor_hidden_dims (tuple[int, ...] | list[int]) – List of hidden dimensions of the predictor network.

  • target_hidden_dims (tuple[int, ...] | list[int]) – List of hidden dimensions of the target network.

  • activation (str) – Name of the activation function.

  • state_normalization (bool) – Whether to normalize the input state.

  • reward_normalization (bool) – Whether to normalize the intrinsic reward.

  • weight (float) – Scaling factor of the intrinsic reward.

  • weight_schedule (dict | None) –

    Schedule for the RND weight parameter, given as a dictionary with the following keys:

    • “mode”: Type of schedule to use. One of:
      • “constant”: Constant weight schedule.

      • “step”: Step weight schedule.

      • “linear”: Linear weight schedule.

    For the “step” weight schedule, the following parameters are required:

    • “final_step”: Step at which the weight parameter is set to the final value.

    • “final_value”: Final value of the weight parameter.

    For the “linear” weight schedule, the following parameters are required:

    • “initial_step”: Step at which the weight parameter starts to change from its initial value.

    • “final_step”: Step at which the weight parameter reaches the final value.

    • “final_value”: Final value of the weight parameter.

  • learning_rate (float) – Learning rate for the RND optimizer.

  • device (str) – Device to use.

Return type:

None
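The three schedule modes can be sketched as a plain function of a step counter. This is a minimal sketch and not the library's implementation. The helper name scheduled_weight is hypothetical. The sketch assumes that the schedule starts from weight, and that the “linear” mode interpolates linearly between initial_step and final_step.

```python
def scheduled_weight(step, initial_weight, schedule=None):
    """Return the RND weight at ``step`` (hypothetical helper).

    Assumes the schedule starts from ``initial_weight`` (the ``weight`` argument).
    """
    if schedule is None or schedule["mode"] == "constant":
        return initial_weight
    if schedule["mode"] == "step":
        # Keep the initial weight until final_step, then jump to final_value.
        return initial_weight if step < schedule["final_step"] else schedule["final_value"]
    if schedule["mode"] == "linear":
        start, end = schedule["initial_step"], schedule["final_step"]
        final = schedule["final_value"]
        if step <= start:
            return initial_weight
        if step >= end:
            return final
        # Interpolate linearly between initial_step and final_step.
        fraction = (step - start) / (end - start)
        return initial_weight + fraction * (final - initial_weight)
    raise ValueError(f"Unknown schedule mode: {schedule['mode']}")


linear = {"mode": "linear", "initial_step": 100, "final_step": 200, "final_value": 0.0}
print(scheduled_weight(150, 1.0, linear))  # 0.5
```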

get_intrinsic_reward(obs)

Compute weighted intrinsic rewards from prediction error in embedding space.

Parameters:

obs (tensordict.TensorDict)

Return type:

torch.Tensor
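The idea behind the intrinsic reward can be sketched without torch. A frozen, randomly initialized target network and a trained predictor both embed the same state, and the reward is the weighted distance between the two embeddings.

This sketch makes several simplifying assumptions:

- Fixed random linear maps built by a hypothetical make_linear helper stand in for the MLPs.
- The prediction error is assumed to be the Euclidean norm of the embedding difference. Burda et al. use the squared error instead.
- Reward normalization is omitted.

```python
import math
import random


def make_linear(in_dim, out_dim, seed):
    """Fixed random linear map standing in for an MLP (hypothetical helper)."""
    rng = random.Random(seed)
    weights = [[rng.gauss(0.0, 1.0) for _ in range(in_dim)] for _ in range(out_dim)]
    return lambda x: [sum(w * xi for w, xi in zip(row, x)) for row in weights]


def intrinsic_reward(states, predictor, target, weight):
    """Weighted Euclidean distance between target and predicted embeddings."""
    rewards = []
    for state in states:
        diff = [t - p for t, p in zip(target(state), predictor(state))]
        rewards.append(weight * math.sqrt(sum(d * d for d in diff)))
    return rewards


target = make_linear(3, 2, seed=0)     # frozen, randomly initialized
predictor = make_linear(3, 2, seed=1)  # would be trained to imitate the target
states = [[0.1, 0.2, 0.3], [1.0, -1.0, 0.5]]
print(intrinsic_reward(states, predictor, target, weight=0.5))
```

The predictor learns to imitate the target on states it has seen often, so the error on those states shrinks. Rarely visited states therefore earn larger rewards.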

compute_loss(obs)

Compute the predictor loss (MSE between predicted and target embeddings).

Parameters:

obs (tensordict.TensorDict)

Return type:

torch.Tensor
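The predictor loss can be sketched in plain Python. The sketch assumes the mean is taken over all batch elements and embedding dimensions, which matches the default reduction="mean" of torch.nn.functional.mse_loss. The helper name predictor_loss is hypothetical.

```python
def predictor_loss(predicted, target):
    """MSE between predicted and target embeddings, averaged over every element.

    In a real implementation the target embeddings come from the frozen target
    network, so gradients flow only into the predictor.
    """
    squared = [
        (p - t) ** 2
        for p_row, t_row in zip(predicted, target)
        for p, t in zip(p_row, t_row)
    ]
    return sum(squared) / len(squared)


print(predictor_loss([[1.0, 2.0], [0.0, 0.0]], [[1.0, 0.0], [0.0, 2.0]]))  # 2.0
```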

forward(*args, **kwargs)

Disallow generic forward calls for this module.

Parameters:
  • args (Any)

  • kwargs (dict[str, Any])

Return type:

NoReturn

train(mode=True)

Set training mode for predictor and optional normalizers.

Parameters:

mode (bool)

Return type:

RandomNetworkDistillation

eval()

Set the module to evaluation mode.

Return type:

RandomNetworkDistillation
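The mode-switching behavior can be sketched with lightweight stand-in objects. The class names below are hypothetical. The sketch assumes the target network stays in evaluation mode, while train() toggles only the predictor and whichever normalizers are enabled. Like torch.nn.Module.train, it returns self.

```python
class _Stub:
    """Stand-in for a submodule with a ``training`` flag (hypothetical)."""

    def __init__(self, training=True):
        self.training = training

    def train(self, mode=True):
        self.training = mode
        return self


class RNDModeSketch:
    """Hypothetical sketch of train()/eval() for an RND-style module."""

    def __init__(self, state_normalization=False, reward_normalization=False):
        self.predictor = _Stub()
        self.target = _Stub(training=False)  # assumed to stay frozen in eval mode
        self.state_normalizer = _Stub() if state_normalization else None
        self.reward_normalizer = _Stub() if reward_normalization else None

    def train(self, mode=True):
        self.predictor.train(mode)
        for normalizer in (self.state_normalizer, self.reward_normalizer):
            if normalizer is not None:
                normalizer.train(mode)
        return self  # same convention as torch.nn.Module.train

    def eval(self):
        return self.train(False)


rnd = RNDModeSketch(state_normalization=True).eval()
print(rnd.predictor.training, rnd.state_normalizer.training, rnd.target.training)  # False False False
rnd.train()
print(rnd.predictor.training, rnd.target.training)  # True False
```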

get_rnd_state(obs)

Extract and concatenate the observation groups used as the RND state.

Parameters:

obs (tensordict.TensorDict)

Return type:

torch.Tensor
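The concatenation can be sketched with plain lists in place of a TensorDict. The sketch assumes that the groups making up the RND state are listed under an "rnd_state" key of obs_groups. Both that key name and the helper get_rnd_state_sketch are assumptions for illustration.

```python
def get_rnd_state_sketch(obs, obs_groups, key="rnd_state"):
    """Concatenate, per environment, the observation groups listed under ``key``."""
    groups = obs_groups[key]
    num_envs = len(obs[groups[0]])
    return [
        [value for group in groups for value in obs[group][env]]
        for env in range(num_envs)
    ]


obs = {"policy": [[0.1, 0.2], [0.3, 0.4]], "contacts": [[1.0], [0.0]]}
obs_groups = {"policy": ["policy"], "rnd_state": ["policy", "contacts"]}
print(get_rnd_state_sketch(obs, obs_groups))  # [[0.1, 0.2, 1.0], [0.3, 0.4, 0.0]]
```

With tensors, the same step is a single torch.cat([obs[g] for g in groups], dim=-1).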

update_normalization(obs)

Update state-normalization statistics from observations.

Parameters:

obs (tensordict.TensorDict)

Return type:

None
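State normalization relies on running per-feature statistics. The sketch below uses the standard parallel (batch-merging) mean and variance update. It illustrates the technique and is not necessarily the exact update performed by the Empirical Normalization layer. The class name RunningMeanVar is hypothetical.

```python
import math


class RunningMeanVar:
    """Per-feature running mean and (population) variance, merged batch by batch."""

    def __init__(self, num_features):
        self.count = 0
        self.mean = [0.0] * num_features
        self.var = [0.0] * num_features

    def update(self, batch):
        """Merge the statistics of ``batch`` (a list of feature rows) into the running ones."""
        n = len(batch)
        total = self.count + n
        for i in range(len(self.mean)):
            column = [row[i] for row in batch]
            batch_mean = sum(column) / n
            batch_var = sum((x - batch_mean) ** 2 for x in column) / n
            delta = batch_mean - self.mean[i]
            # Combine sums of squared deviations, correcting for the shift in means.
            m2 = self.var[i] * self.count + batch_var * n + delta ** 2 * self.count * n / total
            self.var[i] = m2 / total
            self.mean[i] += delta * n / total
        self.count = total

    def normalize(self, state, eps=1e-8):
        """Standardize one state with the current statistics."""
        return [(x - m) / math.sqrt(v + eps) for x, m, v in zip(state, self.mean, self.var)]


norm = RunningMeanVar(1)
norm.update([[1.0], [2.0]])
norm.update([[3.0], [4.0]])
print(norm.mean, norm.var)  # [2.5] [1.25]
```

Merging per batch gives the same mean and variance as computing them over all samples seen so far, without storing those samples.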

rsl_rl.extensions.rnd.resolve_rnd_config(alg_cfg, obs, obs_groups, env)

Resolve the RND configuration.

Parameters:
  • alg_cfg (dict) – Algorithm configuration dictionary.

  • obs (tensordict.TensorDict) – Observation dictionary.

  • obs_groups (dict[str, list[str]]) – Observation groups dictionary.

  • env (VecEnv) – Environment object.

Returns:

The resolved algorithm configuration dictionary.

Return type:

dict

Symmetry

class rsl_rl.extensions.symmetry.Symmetry

Symmetry data augmentation and mirror loss.

The extension supports two (optionally simultaneous) uses of a user-provided symmetry function:

  • use_data_augmentation appends mirrored observation/action pairs to every mini-batch, so that the policy and value losses are evaluated on both the original and the mirrored samples.

  • use_mirror_loss adds an auxiliary MSE term that penalizes the policy for disagreeing with itself when evaluated on mirrored observations.

If both flags are disabled, the symmetry loss is still computed for logging purposes but is detached from the computational graph.

References

  • Mittal et al. “Symmetry Considerations for Learning Task Symmetric Robot Policies.” ICRA (2024).

__init__(env, data_augmentation_func, use_data_augmentation=False, use_mirror_loss=False, mirror_loss_coeff=0.0)

Initialize the symmetry extension.

Parameters:
  • env (VecEnv) – Environment object. Passed to the data augmentation function for handling different observation terms.

  • data_augmentation_func (str | Callable) – Callable that generates mirrored observations and actions, or a string that is resolved to such a callable using resolve_callable().

  • use_data_augmentation (bool) – Whether to append mirrored samples to every mini-batch.

  • use_mirror_loss (bool) – Whether to add an auxiliary mirror loss term to the loss function.

  • mirror_loss_coeff (float) – Scaling factor applied to the mirror loss when use_mirror_loss is True.

Return type:

None

augment_batch(batch, original_batch_size)

Augment the mini-batch in place with mirrored observations and actions.

After the call, batch.observations and batch.actions have shape [original_batch_size * num_aug, ...]. Here num_aug is the total number of copies, including the original. The original samples occupy the first slice and the mirrored samples the remaining slices. The other rollout tensors (old log probabilities, values, advantages, returns) are repeated to match.

When use_data_augmentation is False, the batch is left unchanged.

Parameters:
  • batch (Batch)

  • original_batch_size (int)

Return type:

None
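The batch layout can be sketched with plain lists and a hypothetical left/right symmetry that negates the second component of both observations and actions. Several names here are assumptions: the dictionary keys, the helpers mirror_obs and mirror_action, and the function augment_batch_sketch. The sketch also assumes a single mirror (num_aug = 2). A symmetry function may return several mirrored copies, which would give num_aug > 2.

```python
def mirror_obs(obs):
    """Hypothetical left/right symmetry: negate the lateral (second) component."""
    return [obs[0], -obs[1]]


def mirror_action(action):
    """Hypothetical action mirror matching mirror_obs."""
    return [action[0], -action[1]]


def augment_batch_sketch(batch, original_batch_size):
    """Append mirrored samples and repeat the remaining rollout fields in place."""
    obs = batch["observations"][:original_batch_size]
    actions = batch["actions"][:original_batch_size]
    # Originals in the first slice, mirrored samples in the second slice.
    batch["observations"] = obs + [mirror_obs(o) for o in obs]
    batch["actions"] = actions + [mirror_action(a) for a in actions]
    num_aug = len(batch["observations"]) // original_batch_size  # 2 here
    for key in ("old_log_probs", "values", "advantages", "returns"):
        # Block repetition keeps entry i aligned with sample i in every slice.
        batch[key] = batch[key][:original_batch_size] * num_aug


batch = {
    "observations": [[1.0, 2.0], [3.0, -4.0]],
    "actions": [[0.5, 0.1], [0.2, -0.3]],
    "old_log_probs": [-1.0, -2.0],
    "values": [0.1, 0.2],
    "advantages": [1.0, 2.0],
    "returns": [1.1, 2.2],
}
augment_batch_sketch(batch, original_batch_size=2)
print(batch["observations"])  # [[1.0, 2.0], [3.0, -4.0], [1.0, -2.0], [3.0, 4.0]]
print(batch["values"])  # [0.1, 0.2, 0.1, 0.2]
```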

compute_loss(actor, batch, original_batch_size)

Compute the mirror loss between the actor’s action means on original and mirrored observations.

If augment_batch() has not been called for this batch (i.e. use_data_augmentation is False), the observations are augmented here first so that the actor is evaluated on both the original and the mirrored samples.

The returned loss is detached when use_mirror_loss is False so that it can be reported for logging without contributing to gradients.

Parameters:
  • actor

  • batch (Batch)

  • original_batch_size (int)

Return type:

torch.Tensor
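A common form of the mirror loss compares the actor's action mean on a mirrored observation with the mirrored action mean on the original observation. In this sketch, the actors, the shared mirror function and the helper mirror_loss are all hypothetical. During training, mirror_loss_coeff times this value would be added to the loss when use_mirror_loss is True.

```python
def mirror(vec):
    """Hypothetical symmetry shared by observations and actions: negate component 1."""
    return [vec[0], -vec[1]]


def mirror_loss(actor_mean, observations):
    """MSE between actor_mean(mirror(obs)) and mirror(actor_mean(obs)) over a batch."""
    errors = []
    for obs in observations:
        on_mirrored_obs = actor_mean(mirror(obs))
        mirrored_action = mirror(actor_mean(obs))
        errors.extend((a - b) ** 2 for a, b in zip(on_mirrored_obs, mirrored_action))
    return sum(errors) / len(errors)


def symmetric_actor(obs):
    # Equivariant under the mirror, so the loss is exactly zero.
    return [2.0 * obs[0], 0.5 * obs[1]]


def biased_actor(obs):
    # The constant offset in the lateral action breaks the symmetry.
    return [obs[0], obs[1] + 0.1]


batch = [[1.0, 0.0], [0.5, -2.0]]
print(mirror_loss(symmetric_actor, batch))  # 0.0
print(mirror_loss(biased_actor, batch))
```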

rsl_rl.extensions.symmetry.resolve_symmetry_config(alg_cfg, env)

Resolve the symmetry configuration.

Parameters:
  • alg_cfg (dict) – Algorithm configuration dictionary.

  • env (VecEnv) – Environment object.

Returns:

The resolved algorithm configuration dictionary.

Return type:

dict