Extensions

Random Network Distillation

class rsl_rl.extensions.rnd.RandomNetworkDistillation

Implementation of Random Network Distillation (RND).

References

  • Schwarke et al. “Curiosity-Driven Learning of Joint Locomotion and Manipulation Tasks.” CoRL (2023).

  • Burda et al. “Exploration by Random Network Distillation.” arXiv preprint arXiv:1810.12894 (2018).

__init__(num_states, obs_groups, num_outputs, predictor_hidden_dims, target_hidden_dims, activation='elu', weight=0.0, state_normalization=False, reward_normalization=False, device='cpu', weight_schedule=None)

Initialize the RND module.

  • If state_normalization is True, then the input state is normalized using an Empirical Normalization layer.

  • If reward_normalization is True, then the intrinsic reward is normalized using an Empirical Discounted Variation Normalization layer.

  • If a hidden dimension in the predictor or target network configuration is -1, the number of states is used for that hidden dimension.
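The -1 convention amounts to a per-entry substitution. A minimal sketch, assuming each -1 entry is replaced independently; resolve_hidden_dims is a hypothetical helper for illustration, not part of the rsl_rl API:

```python
def resolve_hidden_dims(hidden_dims, num_states):
    """Replace every -1 entry with num_states (hypothetical helper).

    Accepts a tuple or a list, matching the documented parameter types.
    """
    return [num_states if dim == -1 else dim for dim in hidden_dims]


# A predictor configured with two "use the input size" layers.
print(resolve_hidden_dims([-1, 128, -1], num_states=48))  # [48, 128, 48]
```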

Parameters:
  • num_states (int) – Number of states/inputs to the predictor and target networks.

  • obs_groups (dict) – Dictionary of observation groups.

  • num_outputs (int) – Number of outputs (embedding size) of the predictor and target networks.

  • predictor_hidden_dims (tuple[int, ...] | list[int]) – List of hidden dimensions of the predictor network.

  • target_hidden_dims (tuple[int, ...] | list[int]) – List of hidden dimensions of the target network.

  • activation (str) – Name of the activation function.

  • weight (float) – Scaling factor of the intrinsic reward.

  • state_normalization (bool) – Whether to normalize the input state.

  • reward_normalization (bool) – Whether to normalize the intrinsic reward.

  • device (str) – Device to use.

  • weight_schedule (dict | None) –

    Schedule for the RND weight parameter, given as a dictionary with the following keys:

    • “mode”: Type of schedule to use for the RND weight parameter.
      • “constant”: Constant weight schedule.

      • “step”: Step weight schedule.

      • “linear”: Linear weight schedule.

    For the “step” weight schedule, the following parameters are required:

    • “final_step”: Step at which the weight parameter is set to the final value.

    • “final_value”: Final value of the weight parameter.

    For the “linear” weight schedule, the following parameters are required:

    • “initial_step”: Step at which the weight parameter is set to the initial value.

    • “final_step”: Step at which the weight parameter is set to the final value.

    • “final_value”: Final value of the weight parameter.

Return type:

None
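The three schedule modes can be read as piecewise functions of the training step. A minimal sketch, assuming the schedule starts from the constructor's weight value and that the step is an integer update counter; rnd_weight is a hypothetical helper, not the library's implementation:

```python
def rnd_weight(step, initial_weight, schedule=None):
    """Return the RND weight at a given step (hypothetical sketch).

    initial_weight is assumed to be the constructor's ``weight`` argument;
    ``schedule`` follows the documented weight_schedule dictionary.
    """
    if schedule is None or schedule["mode"] == "constant":
        return initial_weight
    if schedule["mode"] == "step":
        # Hold the initial weight, then switch to final_value at final_step.
        if step < schedule["final_step"]:
            return initial_weight
        return schedule["final_value"]
    if schedule["mode"] == "linear":
        start, end = schedule["initial_step"], schedule["final_step"]
        final = schedule["final_value"]
        if step < start:
            return initial_weight
        if step >= end:
            return final
        # Interpolate linearly between (start, initial_weight) and (end, final).
        fraction = (step - start) / (end - start)
        return initial_weight + (final - initial_weight) * fraction
    raise ValueError(f"Unknown weight schedule mode: {schedule['mode']!r}")


linear = {"mode": "linear", "initial_step": 100, "final_step": 200, "final_value": 0.0}
print(rnd_weight(150, 1.0, linear))  # 0.5
```

Annealing the weight toward a small final value is a common way to let curiosity drive early exploration while the task reward dominates later in training.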

get_intrinsic_reward(obs)

Compute weighted intrinsic rewards from the prediction error between the predictor and target embeddings.

Parameters:

obs (tensordict.TensorDict)

Return type:

torch.Tensor
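The prediction-error idea behind RND (Burda et al., 2018) can be shown without a deep-learning framework. A minimal, framework-free sketch, assuming single-layer linear networks and an L2 distance between embeddings; the actual module uses the configured predictor and target MLPs, optional state and reward normalization are omitted, and all helper names are hypothetical:

```python
import math
import random


def make_linear(in_dim, out_dim, rng):
    """Random linear layer as a list of rows (out_dim x in_dim)."""
    return [[rng.gauss(0.0, 1.0) for _ in range(in_dim)] for _ in range(out_dim)]


def embed(weights, state):
    """Map one state vector to an embedding vector."""
    return [sum(w * s for w, s in zip(row, state)) for row in weights]


def intrinsic_reward(state, target, predictor, weight):
    """Weight times the L2 distance between target and predictor embeddings."""
    t = embed(target, state)
    p = embed(predictor, state)
    return weight * math.sqrt(sum((ti - pi) ** 2 for ti, pi in zip(t, p)))


def predictor_step(state, target, predictor, lr):
    """One SGD step on the squared embedding error; only the predictor changes."""
    error = [pi - ti for pi, ti in zip(embed(predictor, state), embed(target, state))]
    for row, e in zip(predictor, error):
        for j, s in enumerate(state):
            row[j] -= lr * 2.0 * e * s  # gradient of (p_i - t_i)^2 w.r.t. row[j]


rng = random.Random(0)
target = make_linear(4, 8, rng)     # fixed random network, never trained
predictor = make_linear(4, 8, rng)  # trained to imitate the target
state = [0.1, -0.2, 0.3, 0.0]

before = intrinsic_reward(state, target, predictor, weight=0.5)
predictor_step(state, target, predictor, lr=0.1)
after = intrinsic_reward(state, target, predictor, weight=0.5)
print(before, after)  # after < before: a revisited state looks less novel
```

Training pulls the predictor toward the target on states it has seen, so the reward shrinks there, while unfamiliar states keep a large error and stay rewarding.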

forward(*args, **kwargs)

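Disallow generic forward calls for this module. As the NoReturn return type indicates, this method always raises; compute rewards with get_intrinsic_reward() instead.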

Parameters:
  • args (Any)

  • kwargs (dict[str, Any])

Return type:

NoReturn

train(mode=True)

Set the training mode of the predictor network and of the optional normalizers.

Parameters:

mode (bool)

Return type:

RandomNetworkDistillation

eval()

Set the module to evaluation mode.

Return type:

RandomNetworkDistillation

get_rnd_state(obs)

Extract and concatenate the observation groups used as the RND state.

Parameters:

obs (tensordict.TensorDict)

Return type:

torch.Tensor
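The RND state is built by concatenating the selected observation groups. A framework-free sketch, assuming concatenation along the feature dimension in the order the groups are listed; plain nested lists (num_envs x features) stand in for the TensorDict, and concat_groups, its group_names argument, and the group names are hypothetical:

```python
def concat_groups(obs, group_names):
    """Concatenate the selected groups feature-wise, per environment.

    obs maps group name -> list of per-environment feature lists.
    """
    num_envs = len(obs[group_names[0]])
    return [
        [value for name in group_names for value in obs[name][env]]
        for env in range(num_envs)
    ]


obs = {
    "policy": [[0.1, 0.2], [0.3, 0.4]],            # 2 envs, 2 features each
    "joint_pos": [[1.0], [2.0]],                   # 2 envs, 1 feature each
    "camera": [[9.0, 9.0, 9.0], [8.0, 8.0, 8.0]],  # not part of the RND state
}
state = concat_groups(obs, ["policy", "joint_pos"])
print(state)  # [[0.1, 0.2, 1.0], [0.3, 0.4, 2.0]]
```

Each row has width 3 here, which is the value num_states would need to take for these groups.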

update_normalization(obs)

Update state-normalization statistics from observations.

Parameters:

obs (tensordict.TensorDict)

Return type:

None
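Empirical state normalization keeps running per-feature statistics of the states seen so far. A minimal sketch of one standard approach, merging each batch into the running mean and variance with the parallel (Chan et al.) variance update; RunningNormalizer is hypothetical, and the library's Empirical Normalization layer may differ in detail:

```python
import math


class RunningNormalizer:
    """Per-feature running mean/variance normalizer (hypothetical sketch)."""

    def __init__(self, dim, eps=1e-8):
        self.count = 0
        self.mean = [0.0] * dim
        self.var = [0.0] * dim  # population variance of all samples seen
        self.eps = eps

    def update(self, batch):
        """Merge the statistics of a batch (list of feature vectors)."""
        n = len(batch)
        if n == 0:
            return
        dim = len(self.mean)
        b_mean = [sum(x[i] for x in batch) / n for i in range(dim)]
        b_var = [sum((x[i] - b_mean[i]) ** 2 for x in batch) / n for i in range(dim)]
        total = self.count + n
        for i in range(dim):
            delta = b_mean[i] - self.mean[i]
            # Combine sums of squared deviations of the old data and the batch.
            m2 = (self.var[i] * self.count + b_var[i] * n
                  + delta ** 2 * self.count * n / total)
            self.mean[i] += delta * n / total
            self.var[i] = m2 / total
        self.count = total

    def normalize(self, x):
        """Standardize one feature vector with the running statistics."""
        return [(xi - m) / math.sqrt(v + self.eps)
                for xi, m, v in zip(x, self.mean, self.var)]


norm = RunningNormalizer(dim=1)
norm.update([[1.0], [3.0]])
norm.update([[5.0]])
print(norm.mean, norm.var)  # mean [3.0]; variance approximately [8/3]
```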

rsl_rl.extensions.rnd.resolve_rnd_config(alg_cfg, obs, obs_groups, env)

Resolve the RND configuration.

Parameters:
  • alg_cfg (dict) – Algorithm configuration dictionary.

  • obs (tensordict.TensorDict) – Observation dictionary.

  • obs_groups (dict[str, list[str]]) – Observation groups dictionary.

  • env (VecEnv) – Environment object.

Returns:

The resolved algorithm configuration dictionary.

Return type:

dict
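One resolution step follows from the constructor parameters: num_states must equal the combined width of the observation groups that form the RND state. A minimal sketch of that step only; the "rnd_cfg" and "rnd_state" keys, the nested-list observations, and the function name are assumptions, and the env argument (whose use is not described here) is omitted:

```python
def resolve_rnd_config_sketch(alg_cfg, obs, obs_groups):
    """Fill in RND settings derived from the observations (hypothetical)."""
    rnd_cfg = alg_cfg.get("rnd_cfg")  # assumed key holding the RND kwargs
    if rnd_cfg is None:
        return alg_cfg  # RND not configured: nothing to resolve
    # RND state width = sum of its groups' widths (first env's feature vector).
    num_states = sum(len(obs[name][0]) for name in obs_groups["rnd_state"])
    resolved = dict(alg_cfg)  # shallow copy; the input dict is left unchanged
    resolved["rnd_cfg"] = {**rnd_cfg, "num_states": num_states, "obs_groups": obs_groups}
    return resolved


alg_cfg = {"learning_rate": 1e-3, "rnd_cfg": {"num_outputs": 16, "weight": 0.5}}
obs = {"policy": [[0.0] * 3, [0.0] * 3], "extra": [[0.0] * 2, [0.0] * 2]}
obs_groups = {"policy": ["policy"], "rnd_state": ["policy", "extra"]}
resolved = resolve_rnd_config_sketch(alg_cfg, obs, obs_groups)
print(resolved["rnd_cfg"]["num_states"])  # 5
```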

Symmetry

rsl_rl.extensions.symmetry.resolve_symmetry_config(alg_cfg, env)

Resolve the symmetry configuration.

Parameters:
  • alg_cfg (dict) – Algorithm configuration dictionary.

  • env (VecEnv) – Environment object.

Returns:

The resolved algorithm configuration dictionary.

Return type:

dict