diff --git a/doc/source/rllib/doc_code/training.py b/doc/source/rllib/doc_code/training.py
index 5af76bfb48f2b..451bc664cbdf2 100644
--- a/doc/source/rllib/doc_code/training.py
+++ b/doc/source/rllib/doc_code/training.py
@@ -38,9 +38,14 @@
     .api_stack(
         enable_rl_module_and_learner=False, enable_env_runner_and_connector_v2=False
     )
-    .environment("CartPole-v1")
     .framework("torch")
+    .environment("CartPole-v1")
     .env_runners(num_env_runners=0)
+    .training(
+        replay_buffer_config={
+            "type": "MultiAgentPrioritizedReplayBuffer",
+        }
+    )
 ).build()

 #
@@ -108,8 +113,13 @@
     .api_stack(
         enable_rl_module_and_learner=False, enable_env_runner_and_connector_v2=False
     )
-    .environment("CartPole-v1")
     .framework("torch")
+    .environment("CartPole-v1")
+    .training(
+        replay_buffer_config={
+            "type": "MultiAgentPrioritizedReplayBuffer",
+        }
+    )
 ).build()
 model = algo.get_policy().model
 #
diff --git a/rllib/algorithms/cql/cql.py b/rllib/algorithms/cql/cql.py
index 838348d6e64cd..b16f67264234e 100644
--- a/rllib/algorithms/cql/cql.py
+++ b/rllib/algorithms/cql/cql.py
@@ -85,13 +85,26 @@ def __init__(self, algo_class=None):
         self.lagrangian_thresh = 5.0
         self.min_q_weight = 5.0
         self.deterministic_backup = True
+        self.lr = 3e-4
         # Note, the new stack defines learning rates for each component.
         # The base learning rate `lr` has to be set to `None`, if using
         # the new stack.
         self.actor_lr = 1e-4
         self.critic_lr = 1e-3
         self.alpha_lr = 1e-3
-        self.lr = None
+
+        self.replay_buffer_config = {
+            "_enable_replay_buffer_api": True,
+            "type": "MultiAgentPrioritizedReplayBuffer",
+            "capacity": int(1e6),
+            # If True prioritized replay buffer will be used.
+            "prioritized_replay": False,
+            "prioritized_replay_alpha": 0.6,
+            "prioritized_replay_beta": 0.4,
+            "prioritized_replay_eps": 1e-6,
+            # Whether to compute priorities already on the remote worker side.
+            "worker_side_prioritization": False,
+        }

         # Changes to Algorithm's/SACConfig's default:

@@ -103,6 +116,11 @@ def __init__(self, algo_class=None):
         # .reporting()
         self.min_sample_timesteps_per_iteration = 0
         self.min_train_timesteps_per_iteration = 100
+        # `.api_stack()`
+        self.api_stack(
+            enable_rl_module_and_learner=False,
+            enable_env_runner_and_connector_v2=False,
+        )
         # fmt: on
         # __sphinx_doc_end__

diff --git a/rllib/algorithms/dqn/dqn.py b/rllib/algorithms/dqn/dqn.py
index 81c395f7e0e01..dea2874752a84 100644
--- a/rllib/algorithms/dqn/dqn.py
+++ b/rllib/algorithms/dqn/dqn.py
@@ -211,7 +211,7 @@ def __init__(self, algo_class=None):
         # fmt: on
         # __sphinx_doc_end__

-        # Deprecated.
+        # Deprecated
         self.buffer_size = DEPRECATED_VALUE
         self.prioritized_replay = DEPRECATED_VALUE
         self.learning_starts = DEPRECATED_VALUE
@@ -424,16 +424,14 @@ def validate(self) -> None:
         # Call super's validation method.
         super().validate()

-        # Disallow hybrid API stack for DQN/SAC.
-        if (
-            self.enable_rl_module_and_learner
-            and not self.enable_env_runner_and_connector_v2
-        ):
-            raise ValueError(
-                "Hybrid API stack (`enable_rl_module_and_learner=True` and "
-                "`enable_env_runner_and_connector_v2=False`) no longer supported for "
-                "SAC! Set both to True (recommended new API stack) or both to False "
-                "(old API stack)."
+        # Warn about new API stack on by default.
+        if self.enable_rl_module_and_learner:
+            logger.warning(
+                "You are running DQN on the new API stack! This is the new default "
+                "behavior for this algorithm. If you don't want to use the new API "
+                "stack, set `config.api_stack(enable_rl_module_and_learner=False, "
+                "enable_env_runner_and_connector_v2=False)`. For a detailed "
+                "migration guide, see here: https://docs.ray.io/en/master/rllib/new-api-stack-migration-guide.html"  # noqa
             )

         if (
diff --git a/rllib/algorithms/sac/sac.py b/rllib/algorithms/sac/sac.py
index 1d12084fb1f93..35a9b9cece329 100644
--- a/rllib/algorithms/sac/sac.py
+++ b/rllib/algorithms/sac/sac.py
@@ -75,6 +75,8 @@ def __init__(self, algo_class=None):
         self.initial_alpha = 1.0
         self.target_entropy = "auto"
         self.n_step = 1
+
+        # Replay buffer configuration.
         self.replay_buffer_config = {
             "type": "PrioritizedEpisodeReplayBuffer",
             # Size of the replay buffer. Note that if async_updates is set,
@@ -84,6 +86,7 @@ def __init__(self, algo_class=None):
             # Beta parameter for sampling from prioritized replay buffer.
             "beta": 0.4,
         }
+
         self.store_buffer_in_checkpoints = False
         self.training_intensity = None
         self.optimization = {
@@ -458,7 +461,10 @@ def validate(self) -> None:
                 isinstance(self.replay_buffer_config["type"], str)
                 and "Episode" in self.replay_buffer_config["type"]
             )
-            or issubclass(self.replay_buffer_config["type"], EpisodeReplayBuffer)
+            or (
+                isinstance(self.replay_buffer_config["type"], type)
+                and issubclass(self.replay_buffer_config["type"], EpisodeReplayBuffer)
+            )
         ):
             raise ValueError(
                 "When using the old API stack the replay buffer must not be of type "
@@ -479,6 +485,14 @@ def validate(self) -> None:
                 "and `alpha_lr`, for the actor, critic, and the hyperparameter "
                 "`alpha`, respectively and set `config.lr` to None."
             )
+        # Warn about new API stack on by default.
+        logger.warning(
+            "You are running SAC on the new API stack! This is the new default "
+            "behavior for this algorithm. If you don't want to use the new API "
+            "stack, set `config.api_stack(enable_rl_module_and_learner=False, "
+            "enable_env_runner_and_connector_v2=False)`. For a detailed "
+            "migration guide, see here: https://docs.ray.io/en/master/rllib/new-api-stack-migration-guide.html"  # noqa
+        )

     @override(AlgorithmConfig)
     def get_rollout_fragment_length(self, worker_index: int = 0) -> int:
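
Note (not part of the patch itself): with this change DQN and SAC keep the new API stack as their default and only warn, while CQL pins the old stack in its constructor. The sketch below shows how a user would opt back into the old stack and select the MultiAgentPrioritizedReplayBuffer, mirroring the doc_code/training.py hunk above. It assumes the RLlib version touched by this PR; the config methods and buffer type come straight from the diff.

    # Sketch only: build DQN on the old API stack with the prioritized
    # multi-agent replay buffer, as in the updated doc_code example.
    from ray.rllib.algorithms.dqn import DQNConfig

    algo = (
        DQNConfig()
        # Explicitly disable the new API stack (now the default for DQN/SAC).
        .api_stack(
            enable_rl_module_and_learner=False,
            enable_env_runner_and_connector_v2=False,
        )
        .framework("torch")
        .environment("CartPole-v1")
        .env_runners(num_env_runners=0)
        .training(
            replay_buffer_config={
                "type": "MultiAgentPrioritizedReplayBuffer",
            }
        )
    ).build()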