# Source code for rsl_rl.utils.wandb_utils
# Copyright (c) 2021-2026, ETH Zurich and NVIDIA CORPORATION
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations
import os
import pathlib
from dataclasses import asdict
from torch.utils.tensorboard import SummaryWriter
try:
import wandb
except ModuleNotFoundError:
raise ModuleNotFoundError("wandb package is required to log to Weights and Biases.") from None
class WandbSummaryWriter(SummaryWriter):
    """TensorBoard summary writer that mirrors its output to Weights & Biases (W&B).

    Scalars passed to :meth:`add_scalar` are written both to the local TensorBoard event files (via the
    parent :class:`SummaryWriter`) and to the active W&B run. The writer can also upload configurations,
    checkpoint files, arbitrary files and videos to that run.
    """

    def __init__(self, log_dir: str, flush_secs: int, cfg: dict) -> None:
        """Initialize the TensorBoard writer and start a W&B run.

        Args:
            log_dir: Directory for the TensorBoard event files. Its last path component is used as the
                W&B run name, and the full path is stored in the W&B run config under ``"log_dir"``.
            flush_secs: Forwarded unchanged to :class:`SummaryWriter` as ``flush_secs``.
            cfg: Runner configuration. Must contain the ``"wandb_project"`` key.

        Raises:
            KeyError: If ``cfg`` has no ``"wandb_project"`` entry.
        """
        super().__init__(log_dir, flush_secs=flush_secs)
        # Last path component as the run name. pathlib ignores trailing separators, so "logs/run_1/"
        # yields "run_1" rather than the empty string that os.path.split would return.
        run_name = pathlib.Path(log_dir).name
        try:
            project = cfg["wandb_project"]
        except KeyError:
            raise KeyError("Please specify wandb_project in the runner config, e.g. legged_gym.") from None
        # Optional entity: passed as None when WANDB_USERNAME is not set in the environment.
        entity = os.environ.get("WANDB_USERNAME")
        wandb.init(
            project=project,
            entity=entity,
            name=run_name,
            config={"log_dir": log_dir},
            settings=wandb.Settings(start_method="thread"),
        )
        # File names of videos already uploaded, so save_video logs each file only once.
        self.logged_videos: set[str] = set()

    def store_config(self, env_cfg: dict | object, train_cfg: dict) -> None:
        """Upload the environment and training configurations to the W&B run config.

        Args:
            env_cfg: Environment configuration. A plain ``dict`` is uploaded as-is; an object exposing a
                callable ``to_dict()`` is converted with it; anything else is converted with
                :func:`dataclasses.asdict` and must therefore be a dataclass instance.
            train_cfg: Training configuration, uploaded as-is under ``"train_cfg"``.

        Raises:
            TypeError: If ``env_cfg`` is not a dict, has no ``to_dict()`` method and is not a dataclass
                instance (raised by :func:`dataclasses.asdict`).
        """
        wandb.config.update({"train_cfg": train_cfg})
        if isinstance(env_cfg, dict):
            # The annotation allows plain dicts; they need no conversion.
            env_cfg_dict = env_cfg
        elif callable(getattr(env_cfg, "to_dict", None)):
            # Errors raised inside to_dict() propagate directly instead of being hidden behind a
            # misleading asdict() failure.
            env_cfg_dict = env_cfg.to_dict()  # type: ignore[union-attr]
        else:
            env_cfg_dict = asdict(env_cfg)  # type: ignore[arg-type]
        wandb.config.update({"env_cfg": env_cfg_dict})

    def add_scalar(
        self,
        tag: str,
        scalar_value: float,
        global_step: int | None = None,
        walltime: float | None = None,
        new_style: bool = False,
    ) -> None:
        """Log a scalar to TensorBoard and to W&B.

        Args:
            tag: Metric name, used as the TensorBoard tag and as the W&B log key.
            scalar_value: Value to record.
            global_step: Step forwarded to TensorBoard and passed to ``wandb.log`` as ``step``.
            walltime: Forwarded to TensorBoard only.
            new_style: Forwarded to TensorBoard only.
        """
        super().add_scalar(
            tag,
            scalar_value,
            global_step=global_step,
            walltime=walltime,
            new_style=new_style,
        )
        wandb.log({tag: scalar_value}, step=global_step)

    def stop(self) -> None:
        """Finish the active W&B run."""
        wandb.finish()

    def save_model(self, model_path: str, it: int) -> None:
        """Sync a model checkpoint file to the W&B run with ``wandb.save``.

        Args:
            model_path: Path of the checkpoint file to upload.
            it: Training iteration of the checkpoint (currently unused).
        """
        # base_path is the file's own directory, so the file is stored relative to it (by its base name).
        wandb.save(model_path, base_path=os.path.dirname(model_path))

    def save_file(self, path: str) -> None:
        """Sync an arbitrary file to the W&B run with ``wandb.save``.

        Args:
            path: Path of the file to upload.
        """
        wandb.save(path, base_path=os.path.dirname(path))

    def save_video(self, video: pathlib.Path, it: int) -> None:
        """Log an MP4 video to W&B under the ``"video"`` key, at most once per file name.

        Args:
            video: Path of the MP4 file to upload.
            it: Step passed to ``wandb.log``.
        """
        # Deduplicate by file name only: a different file with an already-seen name is skipped.
        if video.name in self.logged_videos:
            return
        wandb.log({"video": wandb.Video(str(video), format="mp4")}, step=it)
        self.logged_videos.add(video.name)