
Commit

Merge branch 'master' of https://github.com/ray-project/ray into deprecate_hybrid_api_stack

Signed-off-by: sven1977 <[email protected]>

# Conflicts:
#	doc/source/rllib/doc_code/training.py
#	rllib/algorithms/cql/cql.py
#	rllib/algorithms/dqn/dqn.py
#	rllib/algorithms/ppo/tests/test_ppo.py
#	rllib/algorithms/ppo/tests/test_ppo_old_api_stack.py
#	rllib/algorithms/sac/sac.py
sven1977 committed Sep 26, 2024
2 parents 3e44e91 + 63233ec commit a049042
Showing 4 changed files with 55 additions and 15 deletions.
14 changes: 12 additions & 2 deletions doc/source/rllib/doc_code/training.py
@@ -38,9 +38,14 @@
     .api_stack(
         enable_rl_module_and_learner=False, enable_env_runner_and_connector_v2=False
     )
-    .environment("CartPole-v1")
     .framework("torch")
+    .environment("CartPole-v1")
     .env_runners(num_env_runners=0)
+    .training(
+        replay_buffer_config={
+            "type": "MultiAgentPrioritizedReplayBuffer",
+        }
+    )
 ).build()
 # <ray.rllib.algorithms.ppo.PPO object at 0x7fd020186384>

@@ -108,8 +113,13 @@
     .api_stack(
         enable_rl_module_and_learner=False, enable_env_runner_and_connector_v2=False
     )
-    .environment("CartPole-v1")
     .framework("torch")
+    .environment("CartPole-v1")
+    .training(
+        replay_buffer_config={
+            "type": "MultiAgentPrioritizedReplayBuffer",
+        }
+    )
 ).build()
 model = algo.get_policy().model
 # <ray.rllib.models.catalog.FullyConnectedNetwork_as_DistributionalQModel ...>
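For context, the old-API-stack pattern in these doc snippets comes together as a self-contained script roughly like the sketch below. This is not part of the commit; it assumes `ray[rllib]` 2.x with PyTorch installed and uses DQN, since a replay buffer such as `MultiAgentPrioritizedReplayBuffer` only applies to off-policy algorithms.

from ray.rllib.algorithms.dqn import DQNConfig

algo = (
    DQNConfig()
    # Explicitly select the old API stack, as the doc snippet does.
    .api_stack(
        enable_rl_module_and_learner=False,
        enable_env_runner_and_connector_v2=False,
    )
    .framework("torch")
    .environment("CartPole-v1")
    .env_runners(num_env_runners=0)
    .training(
        replay_buffer_config={
            "type": "MultiAgentPrioritizedReplayBuffer",
        }
    )
).build()

# On the old stack, the policy's ModelV2 is reachable via `get_policy()`.
model = algo.get_policy().model
algo.stop()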
20 changes: 19 additions & 1 deletion rllib/algorithms/cql/cql.py
@@ -85,13 +85,26 @@ def __init__(self, algo_class=None):
         self.lagrangian_thresh = 5.0
         self.min_q_weight = 5.0
         self.deterministic_backup = True
-        self.lr = 3e-4
+        # Note, the new stack defines learning rates for each component.
+        # The base learning rate `lr` has to be set to `None`, if using
+        # the new stack.
+        self.actor_lr = 1e-4
+        self.critic_lr = 1e-3
+        self.alpha_lr = 1e-3
+        self.lr = None
 
         self.replay_buffer_config = {
             "_enable_replay_buffer_api": True,
             "type": "MultiAgentPrioritizedReplayBuffer",
             "capacity": int(1e6),
+            # If True prioritized replay buffer will be used.
+            "prioritized_replay": False,
+            "prioritized_replay_alpha": 0.6,
+            "prioritized_replay_beta": 0.4,
+            "prioritized_replay_eps": 1e-6,
+            # Whether to compute priorities already on the remote worker side.
+            "worker_side_prioritization": False,
         }
 
         # Changes to Algorithm's/SACConfig's default:

@@ -103,6 +116,11 @@ def __init__(self, algo_class=None):
         # .reporting()
         self.min_sample_timesteps_per_iteration = 0
         self.min_train_timesteps_per_iteration = 100
+        # `.api_stack()`
+        self.api_stack(
+            enable_rl_module_and_learner=False,
+            enable_env_runner_and_connector_v2=False,
+        )
         # fmt: on
         # __sphinx_doc_end__

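A sketch of how a user would set the per-component learning rates that these new defaults introduce (not part of the commit; assumes `ray[rllib]` 2.x, and the values are illustrative):

from ray.rllib.algorithms.cql import CQLConfig

config = CQLConfig().training(
    # One learning rate per optimized component on the new stack ...
    actor_lr=1e-4,
    critic_lr=1e-3,
    alpha_lr=1e-3,
    # ... while the base `lr` must then be None.
    lr=None,
)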
20 changes: 9 additions & 11 deletions rllib/algorithms/dqn/dqn.py
@@ -211,7 +211,7 @@ def __init__(self, algo_class=None):
         # fmt: on
         # __sphinx_doc_end__
 
-        # Deprecated.
+        # Deprecated
         self.buffer_size = DEPRECATED_VALUE
         self.prioritized_replay = DEPRECATED_VALUE
         self.learning_starts = DEPRECATED_VALUE
@@ -424,16 +424,14 @@ def validate(self) -> None:
         # Call super's validation method.
         super().validate()
 
-        # Disallow hybrid API stack for DQN/SAC.
-        if (
-            self.enable_rl_module_and_learner
-            and not self.enable_env_runner_and_connector_v2
-        ):
-            raise ValueError(
-                "Hybrid API stack (`enable_rl_module_and_learner=True` and "
-                "`enable_env_runner_and_connector_v2=False`) no longer supported for "
-                "SAC! Set both to True (recommended new API stack) or both to False "
-                "(old API stack)."
+        # Warn about new API stack on by default.
+        if self.enable_rl_module_and_learner:
+            logger.warning(
+                "You are running DQN on the new API stack! This is the new default "
+                "behavior for this algorithm. If you don't want to use the new API "
+                "stack, set `config.api_stack(enable_rl_module_and_learner=False, "
+                "enable_env_runner_and_connector_v2=False)`. For a detailed "
+                "migration guide, see here: https://docs.ray.io/en/master/rllib/new-api-stack-migration-guide.html"  # noqa
             )
 
         if (
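The warning added above points users at the opt-out; a minimal sketch of that opt-out (not part of the commit; assumes `ray[rllib]` 2.x):

from ray.rllib.algorithms.dqn import DQNConfig

config = DQNConfig().environment("CartPole-v1")
# Revert to the old API stack (and silence the warning):
config.api_stack(
    enable_rl_module_and_learner=False,
    enable_env_runner_and_connector_v2=False,
)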
16 changes: 15 additions & 1 deletion rllib/algorithms/sac/sac.py
@@ -75,6 +75,8 @@ def __init__(self, algo_class=None):
         self.initial_alpha = 1.0
         self.target_entropy = "auto"
         self.n_step = 1
+
+        # Replay buffer configuration.
         self.replay_buffer_config = {
             "type": "PrioritizedEpisodeReplayBuffer",
             # Size of the replay buffer. Note that if async_updates is set,
@@ -84,6 +86,7 @@
             # Beta parameter for sampling from prioritized replay buffer.
             "beta": 0.4,
         }
+
         self.store_buffer_in_checkpoints = False
         self.training_intensity = None
         self.optimization = {
@@ -458,7 +461,10 @@ def validate(self) -> None:
                 isinstance(self.replay_buffer_config["type"], str)
                 and "Episode" in self.replay_buffer_config["type"]
             )
-            or issubclass(self.replay_buffer_config["type"], EpisodeReplayBuffer)
+            or (
+                isinstance(self.replay_buffer_config["type"], type)
+                and issubclass(self.replay_buffer_config["type"], EpisodeReplayBuffer)
+            )
         ):
             raise ValueError(
                 "When using the old API stack the replay buffer must not be of type "
@@ -479,6 +485,14 @@ def validate(self) -> None:
                 "and `alpha_lr`, for the actor, critic, and the hyperparameter "
                 "`alpha`, respectively and set `config.lr` to None."
             )
+        # Warn about new API stack on by default.
+        logger.warning(
+            "You are running SAC on the new API stack! This is the new default "
+            "behavior for this algorithm. If you don't want to use the new API "
+            "stack, set `config.api_stack(enable_rl_module_and_learner=False, "
+            "enable_env_runner_and_connector_v2=False)`. For a detailed "
+            "migration guide, see here: https://docs.ray.io/en/master/rllib/new-api-stack-migration-guide.html"  # noqa
+        )
 
     @override(AlgorithmConfig)
     def get_rollout_fragment_length(self, worker_index: int = 0) -> int:
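The `isinstance(..., type)` guard added in `validate()` matters because `replay_buffer_config["type"]` may hold either a string or a class, and `issubclass()` raises a `TypeError` when handed a string. A standalone sketch of the pattern, using a stand-in class rather than the real RLlib one:

class EpisodeReplayBuffer:  # stand-in for RLlib's EpisodeReplayBuffer
    pass

def is_episode_buffer(buffer_type) -> bool:
    # Accept either a string name containing "Episode" or an
    # EpisodeReplayBuffer subclass; guard issubclass() against non-types.
    return (
        isinstance(buffer_type, str) and "Episode" in buffer_type
    ) or (
        isinstance(buffer_type, type)
        and issubclass(buffer_type, EpisodeReplayBuffer)
    )

print(is_episode_buffer("PrioritizedEpisodeReplayBuffer"))  # True
print(is_episode_buffer(EpisodeReplayBuffer))  # True
print(is_episode_buffer("MultiAgentReplayBuffer"))  # False, and no TypeError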
