Extensions
Random Network Distillation
- class rsl_rl.extensions.rnd.RandomNetworkDistillation
Implementation of Random Network Distillation (RND).
References
Schwarke et al. “Curiosity-Driven Learning of Joint Locomotion and Manipulation Tasks.” CoRL (2023).
Burda et al. “Exploration by Random Network Distillation.” arXiv preprint arXiv:1810.12894 (2018).
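The core idea of RND fits in a deliberately tiny sketch. A frozen, randomly initialized target network embeds each state, a predictor network is trained to reproduce that embedding, and the remaining prediction error serves as a novelty bonus. The sketch assumes linear maps instead of MLPs, the squared Euclidean distance as the prediction error, and plain SGD instead of the module's optimizer. `TinyRND` and its helpers are hypothetical and not part of `rsl_rl`.

```python
import math
import random


def make_linear(in_dim, out_dim, rng):
    """Random weight matrix with shape (out_dim, in_dim)."""
    return [[rng.gauss(0.0, 1.0) for _ in range(in_dim)] for _ in range(out_dim)]


def apply(weights, x):
    """Matrix-vector product weights @ x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weights]


class TinyRND:
    """Hypothetical linear RND: frozen random target, trainable predictor."""

    def __init__(self, num_states, num_outputs, lr=0.05, seed=0):
        rng = random.Random(seed)
        self.target = make_linear(num_states, num_outputs, rng)     # never updated
        self.predictor = make_linear(num_states, num_outputs, rng)  # trained to imitate the target
        self.lr = lr

    def intrinsic_reward(self, state):
        # Squared prediction error in embedding space: large for states the predictor has not fit.
        t, p = apply(self.target, state), apply(self.predictor, state)
        return sum((ti - pi) ** 2 for ti, pi in zip(t, p))

    def update(self, state):
        # One SGD step on 0.5 * ||p - t||^2 with respect to the predictor weights only.
        t, p = apply(self.target, state), apply(self.predictor, state)
        for row, ti, pi in zip(self.predictor, t, p):
            err = pi - ti
            for j, xj in enumerate(state):
                row[j] -= self.lr * err * xj


rnd = TinyRND(num_states=3, num_outputs=4)
seen = [1.0, 0.5, -0.5]
novel = [-1.0, 2.0, 0.0]  # orthogonal to `seen`, so these linear updates leave its error unchanged
seen_before, novel_before = rnd.intrinsic_reward(seen), rnd.intrinsic_reward(novel)
for _ in range(100):
    rnd.update(seen)
seen_after, novel_after = rnd.intrinsic_reward(seen), rnd.intrinsic_reward(novel)
print(seen_after < 1e-3 * seen_before)  # True: the visited state is no longer novel
print(math.isclose(novel_after, novel_before, rel_tol=1e-9, abs_tol=1e-12))  # True
```

With nonlinear networks, training on one state also shifts predictions elsewhere. The mechanism is the same: the bonus decays for frequently visited states and stays high for rarely visited ones.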
- __init__(num_states, obs_groups, num_outputs, predictor_hidden_dims, target_hidden_dims, activation='elu', state_normalization=False, reward_normalization=False, weight=0.0, weight_schedule=None, learning_rate=0.001, device='cpu')
Initialize the RND module.
If state_normalization is True, the input state is normalized using an Empirical Normalization layer. If reward_normalization is True, the intrinsic reward is normalized using an Empirical Discounted Variation Normalization layer. If an entry in the predictor or target network's hidden dimensions is -1, the number of states is used as that layer's size.
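The two preprocessing rules above can be sketched in plain Python. `resolve_hidden_dims` applies the -1 substitution entry by entry. `RunningNormalizer` standardizes states with a running per-feature mean and variance (Welford's algorithm). Both names are hypothetical. The Empirical Normalization layer in `rsl_rl` operates on tensors and may differ in details, and the discounted reward normalization is not shown.

```python
import math


def resolve_hidden_dims(hidden_dims, num_states):
    """Replace each -1 entry with num_states."""
    return [num_states if d == -1 else d for d in hidden_dims]


class RunningNormalizer:
    """Per-feature running mean/variance (Welford's algorithm) used to standardize inputs."""

    def __init__(self, dim, eps=1e-8):
        self.count = 0
        self.mean = [0.0] * dim
        self.m2 = [0.0] * dim  # running sum of squared deviations from the mean
        self.eps = eps

    def update(self, x):
        self.count += 1
        for i, xi in enumerate(x):
            delta = xi - self.mean[i]
            self.mean[i] += delta / self.count
            self.m2[i] += delta * (xi - self.mean[i])

    def variance(self):
        # Population variance; fall back to 1.0 before any data has been seen.
        return [m2 / self.count if self.count else 1.0 for m2 in self.m2]

    def normalize(self, x):
        return [(xi - mu) / math.sqrt(var + self.eps)
                for xi, mu, var in zip(x, self.mean, self.variance())]


print(resolve_hidden_dims([-1, 64, -1], num_states=12))  # [12, 64, 12]
norm = RunningNormalizer(dim=2)
for state in ([1.0, 10.0], [3.0, 30.0]):
    norm.update(state)
print(norm.mean)        # [2.0, 20.0]
print(norm.variance())  # [1.0, 100.0]
```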
- Parameters:
num_states (int) – Number of states/inputs to the predictor and target networks.
obs_groups (dict) – Dictionary of observation groups.
num_outputs (int) – Number of outputs (embedding size) of the predictor and target networks.
predictor_hidden_dims (tuple[int, ...] | list[int]) – List of hidden dimensions of the predictor network.
target_hidden_dims (tuple[int, ...] | list[int]) – List of hidden dimensions of the target network.
activation (str) – Activation function.
state_normalization (bool) – Whether to normalize the input state.
reward_normalization (bool) – Whether to normalize the intrinsic reward.
weight (float) – Scaling factor of the intrinsic reward.
weight_schedule (dict | None) –
Type of schedule to use for the RND weight parameter. It is a dictionary with the following keys:
- "mode": Type of schedule to use for the RND weight parameter. One of:
  "constant": Constant weight schedule.
  "step": Step weight schedule.
  "linear": Linear weight schedule.
For the "step" weight schedule, the following parameters are required:
- "final_step": Step at which the weight parameter is set to the final value.
- "final_value": Final value of the weight parameter.
For the "linear" weight schedule, the following parameters are required:
- "initial_step": Step at which the weight parameter is set to the initial value.
- "final_step": Step at which the weight parameter is set to the final value.
- "final_value": Final value of the weight parameter.
learning_rate (float) – Learning rate for the RND optimizer.
device (str) – Device to use.
- Return type:
None
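The `weight_schedule` dictionary can be read as a function of the training step. The sketch makes three assumptions. The schedule starts from `weight`. The "step" mode switches at `final_step` inclusive. The "linear" mode interpolates between `initial_step` and `final_step` and holds the end values outside that range. `scheduled_weight` is a hypothetical helper, not a method of the module.

```python
def scheduled_weight(step, initial_weight, schedule=None):
    """Hypothetical helper: RND weight at a given training step."""
    if schedule is None or schedule["mode"] == "constant":
        return initial_weight
    if schedule["mode"] == "step":
        # Hold the initial weight, then jump to final_value once final_step is reached.
        return initial_weight if step < schedule["final_step"] else schedule["final_value"]
    if schedule["mode"] == "linear":
        start, end, final = schedule["initial_step"], schedule["final_step"], schedule["final_value"]
        if step <= start:
            return initial_weight
        if step >= end:
            return final
        # Linear interpolation between (start, initial_weight) and (end, final).
        frac = (step - start) / (end - start)
        return initial_weight + frac * (final - initial_weight)
    raise ValueError(f"unknown schedule mode: {schedule['mode']!r}")


linear = {"mode": "linear", "initial_step": 100, "final_step": 300, "final_value": 0.0}
print(scheduled_weight(50, 1.0, linear))   # 1.0
print(scheduled_weight(200, 1.0, linear))  # 0.5
print(scheduled_weight(400, 1.0, linear))  # 0.0
```

A decaying schedule like this one lets the curiosity bonus drive early exploration without biasing the final policy.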
- get_intrinsic_reward(obs)
Compute weighted intrinsic rewards from prediction error in embedding space.
- Parameters:
obs (tensordict.TensorDict)
- Return type:
torch.Tensor
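As one concrete reading of "weighted intrinsic reward", the sketch scores each sample by the distance between its target and predictor embeddings and scales the result by `weight`. It assumes the Euclidean distance; the squared distance used by Burda et al. is another common choice. Reward normalization is omitted, and `weighted_intrinsic_reward` is a hypothetical name.

```python
import math


def weighted_intrinsic_reward(target_emb, predictor_emb, weight):
    """Per-sample reward: weight * Euclidean distance between target and predictor embeddings."""
    return [
        weight * math.sqrt(sum((t - p) ** 2 for t, p in zip(t_row, p_row)))
        for t_row, p_row in zip(target_emb, predictor_emb)
    ]


target = [[0.0, 0.0], [3.0, 4.0]]  # embeddings from the frozen target network
pred = [[0.0, 0.0], [0.0, 0.0]]    # embeddings from the predictor network
print(weighted_intrinsic_reward(target, pred, weight=0.5))  # [0.0, 2.5]
```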
- compute_loss(obs)
Compute the predictor loss (MSE between predicted and target embeddings).
- Parameters:
obs (tensordict.TensorDict)
- Return type:
torch.Tensor
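The predictor loss reduces the whole batch to a single scalar, unlike the per-sample reward. The sketch assumes a mean over every sample and embedding dimension, the default reduction of an MSE loss. It treats the target embeddings as constants, since the target network is never trained. `predictor_mse_loss` is a hypothetical name.

```python
def predictor_mse_loss(predictor_emb, target_emb):
    """Mean of squared differences over every sample and embedding dimension."""
    squared = [
        (p - t) ** 2
        for p_row, t_row in zip(predictor_emb, target_emb)
        for p, t in zip(p_row, t_row)
    ]
    return sum(squared) / len(squared)


pred = [[1.0, 2.0], [0.0, 0.0]]
target = [[1.0, 0.0], [0.0, 2.0]]
print(predictor_mse_loss(pred, target))  # 2.0
```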
- forward(*args, **kwargs)
Disallow generic forward calls for this module.
- Parameters:
args (Any)
kwargs (dict[str, Any])
- Return type:
NoReturn
- train(mode=True)
Set training mode for predictor and optional normalizers.
- Parameters:
mode (bool)
- Return type:
- rsl_rl.extensions.rnd.resolve_rnd_config(alg_cfg, obs, obs_groups, env)
Resolve the RND configuration.
- Parameters:
alg_cfg (dict) – Algorithm configuration dictionary.
obs (tensordict.TensorDict) – Observation dictionary.
obs_groups (dict[str, list[str]]) – Observation groups dictionary.
env (VecEnv) – Environment object.
- Returns:
The resolved algorithm configuration dictionary.
- Return type:
dict
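The RND module cannot infer `num_states` on its own. One plausible way to derive it is to sum the feature dimensions of the observation groups routed to RND. Several parts of the sketch are hypothetical: the group key `"rnd_state"`, the helper name, and the use of plain shape tuples instead of a `TensorDict`. It also works on the RND sub-configuration only, whereas the documented function takes the full algorithm configuration and an `env`.

```python
def resolve_rnd_num_states(rnd_cfg, obs_shapes, obs_groups, group_key="rnd_state"):
    """Hypothetical sketch: fill num_states from the observation groups routed to RND."""
    num_states = 0
    for name in obs_groups[group_key]:
        shape = obs_shapes[name]  # assumed (num_envs, feature_dim)
        if len(shape) != 2:
            raise ValueError(f"expected flat observations for {name!r}, got shape {shape}")
        num_states += shape[-1]
    resolved = dict(rnd_cfg)  # copy so the caller's dictionary is not mutated
    resolved["num_states"] = num_states
    resolved["obs_groups"] = obs_groups
    return resolved


cfg = resolve_rnd_num_states(
    {"num_outputs": 16, "weight": 0.5},
    obs_shapes={"policy": (4096, 48), "contacts": (4096, 4)},
    obs_groups={"policy": ["policy"], "rnd_state": ["policy", "contacts"]},
)
print(cfg["num_states"])  # 52
```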
Symmetry
- class rsl_rl.extensions.symmetry.Symmetry
Symmetry data augmentation and mirror loss.
The extension supports two (optionally simultaneous) uses of a user-provided symmetry function:
- use_data_augmentation appends mirrored observation/action pairs to every mini-batch, so that the policy and value losses are evaluated on both the original and the mirrored samples.
- use_mirror_loss adds an auxiliary MSE term that penalizes the policy when its action on a mirrored observation differs from the mirror image of its action on the original observation.
If both flags are disabled, the symmetry loss is still computed for logging purposes, but it is detached from the computation graph.
References
Mittal et al. “Symmetry Considerations for Learning Task Symmetric Robot Policies.” ICRA (2024).
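A user-provided symmetry function is the core ingredient of this extension. The sketch uses a hypothetical robot whose left/right reflection negates the lateral and yaw components. Its observation is `[forward_vel, lateral_vel, yaw_rate]` and its action is `[forward_cmd, lateral_cmd]`. Several details are assumptions modeled on the descriptions in this section, not a documented signature: the keyword names `obs`, `actions`, and `env`, and the convention of returning the originals first.

```python
# Hypothetical left/right reflection: sign flips per component.
OBS_SIGNS = [1.0, -1.0, -1.0]  # [forward_vel, lateral_vel, yaw_rate]
ACT_SIGNS = [1.0, -1.0]        # [forward_cmd, lateral_cmd]


def flip(rows, signs):
    """Apply the reflection to every row."""
    return [[s * v for s, v in zip(signs, row)] for row in rows]


def toy_symmetry_func(obs=None, actions=None, env=None):
    """Return (obs_aug, actions_aug): originals first, then mirrored copies. env is unused here."""
    obs_aug = None if obs is None else obs + flip(obs, OBS_SIGNS)
    act_aug = None if actions is None else actions + flip(actions, ACT_SIGNS)
    return obs_aug, act_aug


obs = [[1.0, 0.2, -0.1]]
actions = [[0.5, 0.3]]
obs_aug, act_aug = toy_symmetry_func(obs=obs, actions=actions)
print(obs_aug)  # [[1.0, 0.2, -0.1], [1.0, -0.2, 0.1]]
print(act_aug)  # [[0.5, 0.3], [0.5, -0.3]]
```

A reflection is its own inverse, so mirroring a mirrored sample should recover the original. This makes a cheap sanity check for any hand-written symmetry function.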
- __init__(env, data_augmentation_func, use_data_augmentation=False, use_mirror_loss=False, mirror_loss_coeff=0.0)
Initialize the symmetry extension.
- Parameters:
env (VecEnv) – Environment object. Passed to the data augmentation function for handling different observation terms.
data_augmentation_func (str | Callable) – Callable that generates mirrored observations / actions. Resolved using resolve_callable().
use_data_augmentation (bool) – Whether to append mirrored samples to every mini-batch.
use_mirror_loss (bool) – Whether to add an auxiliary mirror loss term to the loss function.
mirror_loss_coeff (float) – Scaling factor applied to the mirror loss when use_mirror_loss is True.
- Return type:
None
- augment_batch(batch, original_batch_size)
Augment the mini-batch in place with mirrored observations and actions.
After the call, batch.observations and batch.actions have shape [original_batch_size * num_aug, ...], with the original samples in the first slice and the mirrored samples in the remaining slices. The other rollout tensors (old log probabilities, values, advantages, returns) are repeated to match.
When use_data_augmentation is False, the batch is left unchanged.
- Parameters:
batch (Batch)
original_batch_size (int)
- Return type:
None
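The batch bookkeeping described above can be sketched with plain lists. In this sketch, list repetition plays the role that a tensor repeat along the batch dimension would play. The tiling order (all originals, then all mirrored copies) follows the description above. `ToyBatch`, `negate_mirror`, and `augment_batch_in_place` are hypothetical stand-ins.

```python
from dataclasses import dataclass, field


@dataclass
class ToyBatch:
    """Hypothetical stand-in for a PPO mini-batch, with lists in place of tensors."""
    observations: list
    actions: list
    old_log_probs: list = field(default_factory=list)
    values: list = field(default_factory=list)
    advantages: list = field(default_factory=list)
    returns: list = field(default_factory=list)


def negate_mirror(obs, actions):
    """Toy symmetry: the mirror image of every vector is its negation (originals first)."""
    return (obs + [[-x for x in o] for o in obs],
            actions + [[-a for a in act] for act in actions])


def augment_batch_in_place(batch, original_batch_size, symmetry_func, use_data_augmentation=True):
    if not use_data_augmentation:
        return  # batch left unchanged
    obs_aug, act_aug = symmetry_func(batch.observations, batch.actions)
    num_aug = len(obs_aug) // original_batch_size
    batch.observations, batch.actions = obs_aug, act_aug
    # Tile per-sample rollout data so entry k of every slice still belongs to sample k.
    for name in ("old_log_probs", "values", "advantages", "returns"):
        setattr(batch, name, getattr(batch, name) * num_aug)


batch = ToyBatch(observations=[[1.0], [2.0]], actions=[[0.1], [0.2]],
                 old_log_probs=[-0.5, -0.7], values=[1.0, 2.0],
                 advantages=[0.3, -0.3], returns=[1.3, 1.7])
augment_batch_in_place(batch, original_batch_size=2, symmetry_func=negate_mirror)
print(batch.observations)  # [[1.0], [2.0], [-1.0], [-2.0]]
print(batch.advantages)    # [0.3, -0.3, 0.3, -0.3]
```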
- compute_loss(actor, batch, original_batch_size)
Compute the mirror loss between the actor’s action means on original and mirrored observations.
If augment_batch() has not been called for this batch (i.e., use_data_augmentation is False), the observations are augmented here first, so that the actor is evaluated on both the original and the mirrored samples.
The returned loss is detached when use_mirror_loss is False, so that it can be reported for logging without contributing to gradients.
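The mirror loss compares two quantities: the actor's mean action on mirrored observations, and the mirror image of its mean action on the original observations. The sketch makes four assumptions. The symmetry function returns the originals first (num_aug = 2). The loss is a plain MSE. The mirrored actions act as a fixed target. Scaling by `mirror_loss_coeff` is left to the caller. The names `symmetry_func` and `mirror_loss` and the toy actors are hypothetical.

```python
# Hypothetical left/right reflection of [fwd_vel, lat_vel, yaw_rate] and [fwd_cmd, lat_cmd].
OBS_SIGNS = [1.0, -1.0, -1.0]
ACT_SIGNS = [1.0, -1.0]


def flip(rows, signs):
    return [[s * v for s, v in zip(signs, row)] for row in rows]


def symmetry_func(obs=None, actions=None):
    """Originals first, mirrored copies second."""
    obs_aug = None if obs is None else obs + flip(obs, OBS_SIGNS)
    act_aug = None if actions is None else actions + flip(actions, ACT_SIGNS)
    return obs_aug, act_aug


def mirror_loss(actor, obs, original_batch_size):
    b = original_batch_size
    obs_aug, _ = symmetry_func(obs=obs)
    means = [actor(o) for o in obs_aug]  # actor evaluated on original + mirrored observations
    # Mirror the mean actions of the original observations; in an autograd framework this
    # target would be detached so gradients flow only through the mirrored-observation branch.
    _, mirrored_means = symmetry_func(actions=means[:b])
    pred, target = means[b:], mirrored_means[b:]
    sq = [(p - t) ** 2 for p_row, t_row in zip(pred, target) for p, t in zip(p_row, t_row)]
    return sum(sq) / len(sq)


def symmetric_actor(o):
    return [o[0], o[1]]        # commutes with the reflection


def biased_actor(o):
    return [o[0], o[1] + 0.5]  # constant lateral bias breaks the symmetry


obs = [[1.0, 0.25, -0.5]]
print(mirror_loss(symmetric_actor, obs, 1))  # 0.0
print(mirror_loss(biased_actor, obs, 1))     # 0.5
```

A policy that already commutes with the reflection incurs zero loss, while the constant lateral bias of `biased_actor` is penalized.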