[RLlib; new API stack by default] Switch on new API stack by default for SAC and DQN. (#47217)
sven1977 committed Sep 26, 2024
1 parent eebfdc2 commit 63233ec
Showing 91 changed files with 1,014 additions and 5,401 deletions.
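
The recurring pattern in the documentation changes below: DQN and SAC now run on the new API stack by default, so examples that still demonstrate old-stack features have to opt out explicitly via AlgorithmConfig.api_stack(). A minimal sketch of that opt-out, mirroring the calls used throughout this diff (CartPole-v1 is simply the environment the examples use):

from ray.rllib.algorithms.dqn import DQNConfig

# After this commit, DQNConfig() targets the new API stack by default.
# Examples that still rely on old-stack features disable it explicitly:
config = (
    DQNConfig()
    .api_stack(
        enable_rl_module_and_learner=False,
        enable_env_runner_and_connector_v2=False,
    )
    .environment("CartPole-v1")
)
algo = config.build()
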
4 changes: 2 additions & 2 deletions doc/source/rllib/doc_code/checkpoints.py
@@ -4,12 +4,12 @@
 import tempfile
 
 from ray.rllib.algorithms.algorithm import Algorithm
-from ray.rllib.algorithms.dqn import DQNConfig
+from ray.rllib.algorithms.ppo import PPOConfig
 from ray.rllib.utils.checkpoints import convert_to_msgpack_checkpoint
 
 
 # Base config used for both pickle-based checkpoint and msgpack-based one.
-config = DQNConfig().environment("CartPole-v1")
+config = PPOConfig().environment("CartPole-v1").env_runners(num_env_runners=0)
 # Build algorithm object.
 algo1 = config.build()
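
The checkpoint example above now builds PPO (already on the new API stack) instead of DQN. A rough sketch of how the changed lines fit into the rest of the file, whose body is collapsed in this view; the temporary-directory handling and the (source, target) argument order of convert_to_msgpack_checkpoint are assumptions, not shown in the hunk:

import tempfile

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.utils.checkpoints import convert_to_msgpack_checkpoint

# Base config used for both the pickle-based and the msgpack-based checkpoint.
config = PPOConfig().environment("CartPole-v1").env_runners(num_env_runners=0)
algo1 = config.build()

# Write a regular (pickle-based) checkpoint, then convert it to msgpack.
# Directory handling and argument order are assumptions for illustration only.
pickle_dir = tempfile.mkdtemp()
msgpack_dir = tempfile.mkdtemp()
algo1.save(pickle_dir)
convert_to_msgpack_checkpoint(pickle_dir, msgpack_dir)
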
36 changes: 30 additions & 6 deletions doc/source/rllib/doc_code/replay_buffer_demo.py
@@ -15,13 +15,31 @@
 
 
 # __sphinx_doc_replay_buffer_type_specification__begin__
-config = DQNConfig().training(replay_buffer_config={"type": ReplayBuffer})
+config = (
+    DQNConfig()
+    .api_stack(
+        enable_env_runner_and_connector_v2=False, enable_rl_module_and_learner=False
+    )
+    .training(replay_buffer_config={"type": ReplayBuffer})
+)
 
-another_config = DQNConfig().training(replay_buffer_config={"type": "ReplayBuffer"})
+another_config = (
+    DQNConfig()
+    .api_stack(
+        enable_env_runner_and_connector_v2=False, enable_rl_module_and_learner=False
+    )
+    .training(replay_buffer_config={"type": "ReplayBuffer"})
+)
 
 
-yet_another_config = DQNConfig().training(
-    replay_buffer_config={"type": "ray.rllib.utils.replay_buffers.ReplayBuffer"}
+yet_another_config = (
+    DQNConfig()
+    .api_stack(
+        enable_env_runner_and_connector_v2=False, enable_rl_module_and_learner=False
+    )
+    .training(
+        replay_buffer_config={"type": "ray.rllib.utils.replay_buffers.ReplayBuffer"}
+    )
 )
 
 validate_buffer_config(config)
@@ -75,13 +93,16 @@ def sample(
 
 config = (
     DQNConfig()
-    .training(replay_buffer_config={"type": LessSampledReplayBuffer})
+    .api_stack(
+        enable_env_runner_and_connector_v2=False, enable_rl_module_and_learner=False
+    )
     .environment(env="CartPole-v1")
+    .training(replay_buffer_config={"type": LessSampledReplayBuffer})
 )
 
 tune.Tuner(
     "DQN",
-    param_space=config.to_dict(),
+    param_space=config,
     run_config=air.RunConfig(
         stop={"training_iteration": 1},
     ),
@@ -127,6 +148,9 @@ def sample(
 # __sphinx_doc_replay_buffer_advanced_usage_underlying_buffers__begin__
 config = (
     DQNConfig()
+    .api_stack(
+        enable_env_runner_and_connector_v2=False, enable_rl_module_and_learner=False
+    )
     .training(
         replay_buffer_config={
             "type": "MultiAgentReplayBuffer",
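
Taken together, the replay-buffer changes above boil down to two things: the explicit old-API-stack opt-out and passing the AlgorithmConfig object itself to Tune instead of config.to_dict(). A short sketch combining both, using the MultiAgentReplayBuffer type from the advanced-usage snippet; the tune.Tuner(...).fit() call is the usual Tune entry point and is assumed here rather than shown in the hunk:

from ray import air, tune
from ray.rllib.algorithms.dqn import DQNConfig

config = (
    DQNConfig()
    # Stay on the old API stack explicitly, since DQN now defaults to the new one.
    .api_stack(
        enable_env_runner_and_connector_v2=False,
        enable_rl_module_and_learner=False,
    )
    .environment("CartPole-v1")
    .training(replay_buffer_config={"type": "MultiAgentReplayBuffer"})
)

# Tune accepts the AlgorithmConfig directly; no config.to_dict() needed anymore.
tune.Tuner(
    "DQN",
    param_space=config,
    run_config=air.RunConfig(stop={"training_iteration": 1}),
).fit()
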
49 changes: 35 additions & 14 deletions doc/source/rllib/doc_code/training.py
@@ -29,23 +29,31 @@
 # __query_action_dist_start__
 # Get a reference to the policy
 import numpy as np
+import torch
 
 from ray.rllib.algorithms.dqn import DQNConfig
 
 algo = (
     DQNConfig()
+    .api_stack(
+        enable_rl_module_and_learner=False, enable_env_runner_and_connector_v2=False
+    )
     .environment("CartPole-v1")
     .framework("tf2")
     .env_runners(num_env_runners=0)
-    .build()
-)
+    .training(
+        replay_buffer_config={
+            "type": "MultiAgentPrioritizedReplayBuffer",
+        }
+    )
+).build()
 # <ray.rllib.algorithms.ppo.PPO object at 0x7fd020186384>
 
 policy = algo.get_policy()
 # <ray.rllib.policy.eager_tf_policy.PPOTFPolicy_eager object at 0x7fd020165470>
 
 # Run a forward pass to get model output logits. Note that complex observations
 # must be preprocessed as in the above code block.
-logits, _ = policy.model({"obs": np.array([[0.1, 0.2, 0.3, 0.4]])})
+logits, _ = policy.model({"obs": torch.from_numpy(np.array([[0.1, 0.2, 0.3, 0.4]]))})
 # (<tf.Tensor: id=1274, shape=(1, 2), dtype=float32, numpy=...>, [])
 
 # Compute action distribution given logits
@@ -57,14 +65,14 @@
 # Query the distribution for samples, sample logps
 dist.sample()
 # <tf.Tensor: id=661, shape=(1,), dtype=int64, numpy=..>
-dist.logp([1])
+dist.logp(torch.tensor([1]))
 # <tf.Tensor: id=1298, shape=(1,), dtype=float32, numpy=...>
 
 # Get the estimated values for the most recent forward pass
 policy.model.value_function()
 # <tf.Tensor: id=670, shape=(1,), dtype=float32, numpy=...>
 
-policy.model.base_model.summary()
+print(policy.model)
 """
 Model: "model"
 _____________________________________________________________________
@@ -95,23 +103,36 @@
 # __get_q_values_dqn_start__
 # Get a reference to the model through the policy
 import numpy as np
+import torch
 
 from ray.rllib.algorithms.dqn import DQNConfig
 
-algo = DQNConfig().environment("CartPole-v1").framework("tf2").build()
+algo = (
+    DQNConfig()
+    .api_stack(
+        enable_rl_module_and_learner=False, enable_env_runner_and_connector_v2=False
+    )
+    .environment("CartPole-v1")
+    .training(
+        replay_buffer_config={
+            "type": "MultiAgentPrioritizedReplayBuffer",
+        }
+    )
+).build()
 model = algo.get_policy().model
 # <ray.rllib.models.catalog.FullyConnectedNetwork_as_DistributionalQModel ...>
 
 # List of all model variables
-model.variables()
+list(model.parameters())
 
 # Run a forward pass to get base model output. Note that complex observations
 # must be preprocessed. An example of preprocessing is
 # examples/offline_rl/saving_experiences.py
-model_out = model({"obs": np.array([[0.1, 0.2, 0.3, 0.4]])})
+model_out = model({"obs": torch.from_numpy(np.array([[0.1, 0.2, 0.3, 0.4]]))})
 # (<tf.Tensor: id=832, shape=(1, 256), dtype=float32, numpy=...)
 
 # Access the base Keras models (all default models have a base)
-model.base_model.summary()
+print(model)
 """
 Model: "model"
 _______________________________________________________________________
@@ -132,16 +153,16 @@
 """
 
 # Access the Q value model (specific to DQN)
-print(model.get_q_value_distributions(model_out)[0])
+print(model.get_q_value_distributions(model_out[0])[0])
 # tf.Tensor([[ 0.13023682 -0.36805138]], shape=(1, 2), dtype=float32)
 # ^ exact numbers may differ due to randomness
 
-model.q_value_head.summary()
+print(model.advantage_module)
 
 # Access the state value model (specific to DQN)
-print(model.get_state_value(model_out))
+print(model.get_state_value(model_out[0]))
 # tf.Tensor([[0.09381643]], shape=(1, 1), dtype=float32)
 # ^ exact number may differ due to randomness
 
-model.state_value_head.summary()
+print(model.value_module)
 # __get_q_values_dqn_end__
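
Assembled from the added lines above, the updated Q-value inspection example reads roughly as follows. This is a sketch, not the verbatim file: the float32 dtype is added here for safety, and the printed outputs are omitted.

import numpy as np
import torch

from ray.rllib.algorithms.dqn import DQNConfig

algo = (
    DQNConfig()
    # Keep this example on the old API stack (ModelV2 / Policy).
    .api_stack(
        enable_rl_module_and_learner=False,
        enable_env_runner_and_connector_v2=False,
    )
    .environment("CartPole-v1")
    .training(replay_buffer_config={"type": "MultiAgentPrioritizedReplayBuffer"})
).build()

model = algo.get_policy().model

# List all model parameters (torch equivalent of the former Keras variables() call).
list(model.parameters())

# Forward pass through the base model; observations are passed as torch tensors.
obs = torch.from_numpy(np.array([[0.1, 0.2, 0.3, 0.4]], dtype=np.float32))
model_out = model({"obs": obs})

# DQN-specific heads: Q-value distributions and the state value.
print(model.get_q_value_distributions(model_out[0])[0])
print(model.get_state_value(model_out[0]))

# The torch submodules that replaced the Keras head summaries.
print(model.advantage_module)
print(model.value_module)
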
(Diffs for the remaining 88 changed files are not rendered here.)
