Source code for rsl_rl.utils.wandb_utils

# Copyright (c) 2021-2026, ETH Zurich and NVIDIA CORPORATION
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause


from __future__ import annotations

import os
import pathlib
from dataclasses import asdict
from torch.utils.tensorboard import SummaryWriter

try:
    import wandb
except ModuleNotFoundError:
    raise ModuleNotFoundError("wandb package is required to log to Weights and Biases.") from None


class WandbSummaryWriter(SummaryWriter):
    """Summary writer that mirrors TensorBoard logging to Weights & Biases.

    Scalars are written through the parent ``SummaryWriter`` and additionally
    sent to the active W&B run. Configuration, files, model checkpoints and
    videos are forwarded to W&B only.
    """

    def __init__(self, log_dir: str, flush_secs: int, cfg: dict) -> None:
        """Initialize the TensorBoard writer and start a W&B run.

        Args:
            log_dir: Directory used by the TensorBoard writer. Its last path
                component is used as the W&B run name.
            flush_secs: Flush interval forwarded to ``SummaryWriter``.
            cfg: Runner configuration; must contain the key ``"wandb_project"``.

        Raises:
            KeyError: If ``cfg`` does not define ``"wandb_project"``.
        """
        super().__init__(log_dir, flush_secs=flush_secs)

        # The run is named after the last component of the log directory.
        run_name = os.path.basename(log_dir)

        # The project is mandatory; the entity is optional and taken from the
        # environment (None is passed to wandb when the variable is unset).
        try:
            project = cfg["wandb_project"]
        except KeyError:
            raise KeyError("Please specify wandb_project in the runner config, e.g. legged_gym.") from None
        entity = os.environ.get("WANDB_USERNAME")

        wandb.init(
            project=project,
            entity=entity,
            name=run_name,
            config={"log_dir": log_dir},
            settings=wandb.Settings(start_method="thread"),
        )

        # File names of videos that were already uploaded (see save_video).
        self.logged_videos: set[str] = set()

    def store_config(self, env_cfg: dict | object, train_cfg: dict) -> None:
        """Upload environment and training configuration to W&B.

        The environment configuration is converted to a dictionary before the
        upload, trying in order: its ``to_dict()`` method, the object itself
        if it already is a ``dict``, and finally ``dataclasses.asdict``.

        Args:
            env_cfg: Environment configuration as a ``dict``, an object that
                provides ``to_dict()``, or a dataclass instance.
            train_cfg: Training configuration dictionary.

        Raises:
            TypeError: If ``env_cfg`` cannot be converted by any of the
                strategies above (raised by ``dataclasses.asdict``).
        """
        wandb.config.update({"train_cfg": train_cfg})

        to_dict = getattr(env_cfg, "to_dict", None)
        if callable(to_dict):
            try:
                env_cfg_dict = to_dict()
            except Exception:
                # Best effort: a failing to_dict() falls back to the dataclass
                # conversion instead of aborting the upload.
                env_cfg_dict = asdict(env_cfg)  # type: ignore
        elif isinstance(env_cfg, dict):
            # A plain dict has no to_dict() and is not a dataclass instance,
            # so it is uploaded as it is.
            env_cfg_dict = env_cfg
        else:
            env_cfg_dict = asdict(env_cfg)  # type: ignore

        # The upload happens outside of any try block so that W&B errors are
        # reported as such rather than triggering a second conversion attempt.
        wandb.config.update({"env_cfg": env_cfg_dict})

    def add_scalar(
        self,
        tag: str,
        scalar_value: float,
        global_step: int | None = None,
        walltime: float | None = None,
        new_style: bool = False,
    ) -> None:
        """Log a scalar to both TensorBoard and W&B.

        Args:
            tag: Name of the scalar; also used as the W&B metric key.
            scalar_value: Value to record.
            global_step: Step associated with the value; forwarded to W&B as
                ``step``.
            walltime: Forwarded to ``SummaryWriter.add_scalar`` only.
            new_style: Forwarded to ``SummaryWriter.add_scalar`` only.
        """
        super().add_scalar(
            tag,
            scalar_value,
            global_step=global_step,
            walltime=walltime,
            new_style=new_style,
        )
        wandb.log({tag: scalar_value}, step=global_step)

    def stop(self) -> None:
        """Finish the active W&B run."""
        wandb.finish()

    def save_model(self, model_path: str, it: int) -> None:
        """Upload a model checkpoint file to W&B.

        Args:
            model_path: Path of the checkpoint file. Its parent directory is
                passed to ``wandb.save`` as ``base_path``.
            it: Training iteration of the checkpoint (currently unused).
        """
        wandb.save(model_path, base_path=os.path.dirname(model_path))

    def save_file(self, path: str) -> None:
        """Upload an arbitrary file to W&B.

        Args:
            path: Path of the file. Its parent directory is passed to
                ``wandb.save`` as ``base_path``.
        """
        wandb.save(path, base_path=os.path.dirname(path))

    def save_video(self, video: pathlib.Path, it: int) -> None:
        """Upload a video to W&B unless a video of the same name was logged before.

        Args:
            video: Path of the video file; logged in ``mp4`` format under the
                key ``"video"``.
            it: Training iteration used as the W&B step.
        """
        if video.name in self.logged_videos:
            return
        wandb.log({"video": wandb.Video(str(video), format="mp4")}, step=it)
        self.logged_videos.add(video.name)