diff --git a/src/art/__init__.py b/src/art/__init__.py index 01ccb231..3a8e048b 100644 --- a/src/art/__init__.py +++ b/src/art/__init__.py @@ -54,7 +54,6 @@ def __init__(self, **kwargs): from .backend import Backend from .batches import trajectory_group_batches from .gather import gather_trajectories, gather_trajectory_groups -from .local import LocalBackend from .model import Model, TrainableModel from .serverless import ServerlessBackend from .trajectories import Trajectory, TrajectoryGroup diff --git a/src/art/utils/benchmarking/log_constant_metrics_wandb.py b/src/art/utils/benchmarking/log_constant_metrics_wandb.py index 3f774601..ada24810 100644 --- a/src/art/utils/benchmarking/log_constant_metrics_wandb.py +++ b/src/art/utils/benchmarking/log_constant_metrics_wandb.py @@ -9,7 +9,7 @@ async def log_constant_metrics_wandb( model: art.Model, num_steps: int, split_metrics: dict[str, dict[str, float]], - model_name_appendix: str | None = None, + logged_run_name: str | None = None, ) -> None: """ Log constant metrics to W&B as horizontal lines across all training steps. @@ -32,7 +32,7 @@ async def log_constant_metrics_wandb( """ run = wandb.init( project=model.project, - name=model.name + f" {model_name_appendix}" if model_name_appendix else "", + name=logged_run_name if logged_run_name else model.name, reinit="create_new", )