A Gentle Introduction to verl — Part 1
Wrangle and implement RL algorithms with confidence. A deep dive into verl's architecture — from master-worker design to the PPO training loop — so you can go beyond config files.

verl (Volcano Engine Reinforcement Learning) is a reinforcement learning library used to post-train LLMs. It supports a range of RL techniques, including PPO, GRPO, and REMAX.
If you're doing any kind of post-training on LLMs with RL, verl is one of the most capable open-source tools available.
RL for LLMs has become increasingly important for three reasons:
Data scarcity: Recent work shows that just a few seed data points plus RL can be enough for post-training. You don't need massive labeled datasets anymore.
Better generalization: Prolonged RL training helps create models that generalize better across tasks — not just memorize training distributions.
Verifiable domains: In areas like math, code, and puzzles, RL with verifiable rewards (RLVR) is incredibly useful because you can automatically check whether the model's output is correct.
This article assumes you already know how RL applies to LLMs at a conceptual level. If you're new to that, I recommend watching this RL for LLMs Tutorial before diving in.
When people pick up verl, I've noticed they fall into two camps: those who treat it as a black box, editing config files and hoping for the best, and those who want to understand the internals well enough to modify and extend them.
This tutorial is for the second group. The goal is to make you feel comfortable navigating and extending a library that is incredibly useful but can be intimidating at first glance.
verl follows a master-worker architecture: a lightweight master (driver) process orchestrates the training loop, while remote workers handle the heavy computation.
There's an important design tradeoff here. You could co-locate data and processes on the same machine (low data movement, but non-reusable code). verl chose the opposite: remote processing with higher data movement cost but highly reusable code. This makes it easy to swap RL algorithms, change model architectures, or scale to different cluster sizes without rewriting infrastructure.
(Figure: Design of verl. It's recommended to read the section below before studying this diagram.)
Most users start with configuration files in YAML format. The config files define everything from model paths to training hyperparameters — you can see the default PPO trainer config for reference.
# Format checks enforced on CI:
# 1. Comments must appear above each field.
# 2. There must be a blank line between each field.
# 3. Inline comments (after a field on the same line) are not allowed.
# 4. Indentation level is respected for nested fields.

# dataset config
data:

  # Tokenizer class or path. If null, it will be inferred from the model.
  tokenizer: null

  # Whether to use shared memory for data loading.
  use_shm: False

  # Training set parquet. Can be a list or a single file.
  # The program will read all files into memory, so it can't be too large (< 100GB).
  # The path can be either a local path or an HDFS path.
  train_files: ~/data/rlhf/gsm8k/train.parquet

  # Validation parquet. Can be a list or a single file.
  val_files: ~/data/rlhf/gsm8k/test.parquet

  # The field in the dataset where the prompt is located. Default is 'prompt'.
  prompt_key: prompt
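verl consumes this config through Hydra/OmegaConf, so any field can be overridden from the command line with a dotted key like `data.train_files=...`. As a rough pure-Python illustration of what a dotted override does to a nested config (the `apply_override` helper is hypothetical, not part of verl or Hydra):

```python
# Sketch of Hydra-style dotted overrides applied to a nested config dict.
# `apply_override` is a hypothetical helper for illustration only.
from typing import Any

def apply_override(config: dict, override: str) -> None:
    """Apply a single 'a.b.c=value' override in place."""
    dotted_key, value = override.split("=", 1)
    keys = dotted_key.split(".")
    node: Any = config
    for key in keys[:-1]:
        node = node.setdefault(key, {})  # create missing intermediate nodes
    node[keys[-1]] = value

config = {"data": {"train_files": "~/data/rlhf/gsm8k/train.parquet", "prompt_key": "prompt"}}
apply_override(config, "data.train_files=/tmp/my_train.parquet")
apply_override(config, "trainer.total_epochs=15")
print(config["data"]["train_files"])      # /tmp/my_train.parquet
print(config["trainer"]["total_epochs"])  # '15' (a string here; Hydra coerces types)
```

The real machinery also resolves interpolations and validates types, but the mental model is the same: dotted keys index into the nested YAML tree.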
But the real entry point into the codebase is main_ppo.py.
# main_ppo.py — top-level structure
import hydra
import ray

from verl.trainer.ppo.ray_trainer import RayPPOTrainer
from verl.trainer.ppo.reward import load_reward_manager


@hydra.main(config_path="config", config_name="ppo_trainer", version_base=None)
def main(config):
    ...


# Define a function to run the PPO-like training process
def run_ppo(config) -> None:
    ...


@ray.remote(num_cpus=1)  # please make sure main_task is not scheduled on head
class TaskRunner:
    ...


def create_rl_dataset(data_paths, data_config, tokenizer, processor):
    ...
The TaskRunner is where everything gets set up. Let's walk through its key components.
The tokenizer is a standard HuggingFace tokenizer loaded from the model path. For multimodal models, a processor additionally handles image/audio inputs alongside text.
@ray.remote(num_cpus=1)  # please make sure main_task is not scheduled on head
class TaskRunner:
    def run(self, config):
        # Print the initial configuration. `resolve=True` will evaluate symbolic values.
        from pprint import pprint

        from omegaconf import OmegaConf

        from verl.utils.fs import copy_to_local

        pprint(OmegaConf.to_container(config, resolve=True))
        OmegaConf.resolve(config)

        # Download the checkpoint from HDFS to the local machine.
        # `use_shm` determines whether to use shared memory
        local_path = copy_to_local(
            config.actor_rollout_ref.model.path,
            use_shm=config.actor_rollout_ref.model.get("use_shm", False),
        )

        # Instantiate the tokenizer and processor.
        ...
verl supports three distributed training backends: FSDP, FSDP2, and Megatron.
Workers can be async (dormant unless actively used) or persistent (like ActorRolloutRefWorker, which stays alive throughout training).
# Define worker classes based on the actor strategy.
if config.actor_rollout_ref.actor.strategy in ["fsdp", "fsdp2"]:
    assert config.critic.strategy in ["fsdp", "fsdp2"]
    from verl.single_controller.ray import RayWorkerGroup
    from verl.workers.fsdp_workers import (
        ActorRolloutRefWorker,
        AsyncActorRolloutRefWorker,
        CriticWorker,
    )

    actor_rollout_cls = (
        AsyncActorRolloutRefWorker
        if config.actor_rollout_ref.rollout.mode == "async"
        else ActorRolloutRefWorker
    )
    ray_worker_group_cls = RayWorkerGroup
elif config.actor_rollout_ref.actor.strategy == "megatron":
    ...
else:
    ...
This is where verl gets flexible. Say you have 16 GPUs — you can reserve 8 for the Actor and 8 for the Critic, or split them any way you want. A resource pool manager handles the allocation, letting you tune the compute balance between generation and training.
# Resource pool setup — allocates GPU groups for each role
resource_pool_manager = ResourcePoolManager(
    resource_pool_spec=config.trainer.resource_pool_spec,
    mapping=config.trainer.mapping,
)

# Set worker-to-role mapping based on available resource pools
role_worker_mapping = {
    Role.ActorRollout: actor_rollout_cls,
    Role.Critic: CriticWorker,
    Role.RefPolicy: ActorRolloutRefWorker,
    Role.RewardModel: RewardModelWorker,
}
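The resource-pool idea is easy to see in miniature. Below is a hypothetical, stripped-down sketch of the role-to-pool bookkeeping (not verl's actual `ResourcePoolManager`, which also talks to Ray placement groups):

```python
# Hypothetical mini-version of the role -> GPU-pool bookkeeping, for intuition.
class MiniResourcePoolManager:
    def __init__(self, resource_pool_spec: dict, mapping: dict):
        # e.g. {"actor_pool": [8]} means one node contributing 8 GPUs to that pool
        self.resource_pool_spec = resource_pool_spec
        self.mapping = mapping  # role name -> pool name

    def get_resource_pool(self, role: str) -> str:
        return self.mapping[role]

    def gpus_for(self, role: str) -> int:
        # total GPUs across all nodes in the pool backing this role
        return sum(self.resource_pool_spec[self.get_resource_pool(role)])

# 16 GPUs split evenly between generation/training (actor) and the critic
manager = MiniResourcePoolManager(
    resource_pool_spec={"actor_pool": [8], "critic_pool": [8]},
    mapping={"ActorRollout": "actor_pool", "Critic": "critic_pool"},
)
print(manager.gpus_for("ActorRollout"))  # 8
print(manager.gpus_for("Critic"))        # 8
```

Changing the compute balance is then just a config change: shrink one pool's spec and grow the other, and the workers land on different GPU groups without touching training code.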
After resource specs are configured, the trainer is created with all these components wired together:
from verl.utils.dataset.rl_dataset import collate_fn

# Create training and validation datasets.
train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor)
val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor)
train_sampler = create_rl_sampler(config.data, train_dataset)

# Initialize the PPO trainer.
trainer = RayPPOTrainer(...)
# Initialize the workers of the trainer.
trainer.init_workers()
# Start the training process.
trainer.fit()
The core training logic lives in RayPPOTrainer. Two methods matter most: init_workers() and fit().
class RayPPOTrainer:
    """
    Note that this trainer runs on the driver process on a single CPU/GPU node.
    """

    # TODO: support each role have individual ray_worker_group_cls,
    # i.e., support different backend of different role
    def __init__(self, ...): ...
    def _validate_config(self): ...
    def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler): ...
    def _dump_generations(self, inputs, outputs, scores, reward_extra_infos_dict, dump_path): ...
    def _maybe_log_val_generations(self, inputs, outputs, scores): ...
    def _validate(self): ...
    def init_workers(self): ...
    def _save_checkpoint(self): ...
    def _load_checkpoint(self): ...
    def _balance_batch(self, batch: DataProto, metrics, logging_prefix): ...
    def fit(self): ...
init_workers() sets up the distributed training environment:
It uses RayClassWithInitArgs to convert regular Python classes into Ray-schedulable distributed classes:

def init_workers(self):
    """Initialize distributed training workers using Ray backend.

    Creates:
    1. Ray resource pools from configuration
    2. Worker groups for each role (actor, critic, etc.)
    """
    self.resource_pool_manager.create_resource_pool()
    self.resource_pool_to_cls = {
        pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()
    }

    # create actor and rollout
    if self.hybrid_engine:
        resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout)
        actor_rollout_cls = RayClassWithInitArgs(
            cls=self.role_worker_mapping[Role.ActorRollout],
            config=self.config.actor_rollout_ref,
            role="actor_rollout",
        )
        self.resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls
Then the critic, reference policy, and reward model are set up the same way:
# create critic
if self.use_critic:
    resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic)
    critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=self.config.critic)
    self.resource_pool_to_cls[resource_pool]["critic"] = critic_cls

# create reference policy if needed
if self.use_reference_policy:
    resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy)
    ref_policy_cls = RayClassWithInitArgs(
        self.role_worker_mapping[Role.RefPolicy], config=self.config.actor_rollout_ref, role="ref"
    )
    self.resource_pool_to_cls[resource_pool]["ref"] = ref_policy_cls

# create a reward model if reward_fn is None
if self.use_rm:
    ...
Worker groups are then spawned on the Ray cluster, and each is assigned to its role:
# Initialize WorkerGroups — spawn actual Ray actors
for resource_pool, class_dict in self.resource_pool_to_cls.items():
    worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
    wg_dict = self.ray_worker_group_cls(
        resource_pool, ray_cls_with_init=worker_dict_cls, device_name=self.device_name, **wg_kwargs
    )
    spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
    all_wg.update(spawn_wg)

if self.use_critic:
    self.critic_wg = all_wg["critic"]
    self.critic_wg.init_model()

if self.use_reference_policy and not self.ref_in_actor:
    self.ref_policy_wg = all_wg["ref"]
    self.ref_policy_wg.init_model()
The RayWorkerGroup class manages the lifecycle of these distributed workers:
class RayWorkerGroup(WorkerGroup):
    """A group of Ray workers that can be managed collectively.

    This class extends WorkerGroup to provide Ray-specific functionality for
    creating and managing groups of Ray actors with specific resource requirements
    and scheduling strategies.
    """

    def __init__(self, ...) -> None: ...
    def _is_worker_alive(self, worker: ray.actor.ActorHandle): ...
    def _init_with_detached_workers(self, worker_names, worker_handles): ...
    def _init_with_resource_pool(self, resource_pool, ray_cls_with_init, bin_pack, detached): ...

    @property
    def worker_names(self): ...

    @classmethod
    def from_detached(cls, ...): ...

    def spawn(self, prefix_set): ...
    def spawn_fused(self, prefix_set): ...
The dataloader is a StatefulDataLoader — it follows PyTorch's standard data loading pattern, but its position can be checkpointed and restored. If you don't specify a custom dataset class, the dataset defaults to RLHFDataset.
self.train_dataloader = StatefulDataLoader(
    dataset=self.train_dataset,
    batch_size=self.config.data.get("gen_batch_size", self.config.data.train_batch_size),
    num_workers=self.config.data.get("dataloader_num_workers", 8),
    drop_last=True,
    collate_fn=collate_fn,
    sampler=train_sampler,
)
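Why a *stateful* dataloader matters: when training resumes from a checkpoint, the loader must continue from the same position in the epoch. A minimal pure-Python sketch of that idea (illustrative only; the real StatefulDataLoader comes from torchdata and handles workers, shuffling, and prefetching):

```python
# Minimal sketch of a resumable dataloader: its position is part of its state,
# so it can be saved with a checkpoint and restored on resume.
class ResumableLoader:
    def __init__(self, data):
        self.data = list(data)
        self.position = 0

    def __iter__(self):
        while self.position < len(self.data):
            item = self.data[self.position]
            self.position += 1
            yield item

    def state_dict(self):
        return {"position": self.position}

    def load_state_dict(self, state):
        self.position = state["position"]

loader = ResumableLoader(range(5))
it = iter(loader)
first_two = [next(it), next(it)]          # consume two items
checkpoint = loader.state_dict()          # saved alongside the model checkpoint

fresh = ResumableLoader(range(5))         # "restart" of the training job
fresh.load_state_dict(checkpoint)         # resume mid-epoch
print(first_two, list(fresh))             # [0, 1] [2, 3, 4]
```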
The dataset creation wires the tokenizer and configuration together:
def create_rl_dataset(data_paths, data_config, tokenizer, processor):
    """Create a dataset.

    Arguments:
        data_paths: List of paths to data files.
        data_config: The data config.
        tokenizer (Tokenizer): The tokenizer.
        processor (Processor): The processor.

    Returns:
        dataset (Dataset): The dataset.
    """
    from torch.utils.data import Dataset

    from verl.utils.dataset.rl_dataset import RLHFDataset

    # Check if a custom dataset class is specified in the data configuration
    if "custom_cls" in data_config and data_config.custom_cls.get("path", None) is not None:
        from verl.utils.import_utils import load_extern_type

        # Dynamically load the custom dataset class
        dataset_cls = load_extern_type(data_config.custom_cls.path, data_config.custom_cls.name)
        # The custom class must still subclass torch.utils.data.Dataset
        if not issubclass(dataset_cls, Dataset):
            raise TypeError(f"The custom dataset class '{data_config.custom_cls.name}' must inherit from torch.utils.data.Dataset")
    else:
        dataset_cls = RLHFDataset
    ...
RLHFDataset subclasses torch.utils.data.Dataset (standard PyTorch pattern). It downloads and tokenizes data, then applies chat templates to format conversations.
class RLHFDataset(Dataset):
    """
    Load and preprocess RLHF data from Parquet files.

    - Caches files locally.
    - Reads into a HuggingFace Dataset and tokenizes prompts.
    - Optionally handles images/videos via a ProcessorMixin.
    - Filters prompts over a max length.
    - Supports resuming from checkpoints.
    """

    def __init__(self, ...): ...
    def _download(self, use_origin_parquet=False): ...
    def _read_files_and_tokenize(self): ...
    def resume_dataset_state(self): ...
    def __len__(self): ...
    def _build_messages(self, example: dict): ...
    def __getitem__(self, item): ...
    def __getstate__(self): ...
In _read_files_and_tokenize, the dataset loads parquet files, concatenates them, and filters out prompts that are too long:
def _read_files_and_tokenize(self):
    dataframes = []
    for parquet_file in self.data_files:
        # read parquet files and cache
        dataframe = datasets.load_dataset("parquet", data_files=parquet_file)["train"]
        dataframes.append(dataframe)
    self.dataframe: datasets.Dataset = datasets.concatenate_datasets(dataframes)
    print(f"dataset len: {len(self.dataframe)}")

    # filter out too long prompts
    if self.filter_overlong_prompts:
        tokenizer = self.tokenizer
        prompt_key = self.prompt_key
        self.dataframe = self.dataframe.filter(
            lambda doc: len(tokenizer.apply_chat_template(doc[prompt_key], add_generation_prompt=True))
            <= self.max_prompt_length,
            ...
        )
__getitem__ returns tokenized prompts with attention masks. The key step is postprocess_data which handles left-padding and truncation:
input_ids, attention_mask = verl_F.postprocess_data(
    input_ids=input_ids,
    attention_mask=attention_mask,
    max_length=self.max_prompt_length,
    pad_token_id=self.tokenizer.pad_token_id,
    left_pad=True,
    truncation=self.truncation,
)
Nothing surprising here — it's a clean, standard implementation.
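If left-padding is unfamiliar: prompts are padded on the left so that every sequence in the batch ends at the same position, which is where generation begins. A pure-Python sketch of the behavior (this is not verl's `postprocess_data`, just the concept):

```python
# Conceptual sketch of left-padding + truncation for a single prompt.
def left_pad_or_truncate(input_ids, max_length, pad_token_id, truncation="left"):
    if len(input_ids) > max_length:
        # left truncation keeps the most recent tokens of an over-long prompt
        input_ids = input_ids[-max_length:] if truncation == "left" else input_ids[:max_length]
    pad = max_length - len(input_ids)
    padded = [pad_token_id] * pad + list(input_ids)   # pad tokens go on the LEFT
    attention_mask = [0] * pad + [1] * len(input_ids)  # mask out the padding
    return padded, attention_mask

ids, mask = left_pad_or_truncate([101, 7592, 102], max_length=5, pad_token_id=0)
print(ids)   # [0, 0, 101, 7592, 102]
print(mask)  # [0, 0, 1, 1, 1]
```

Right-aligning the prompt this way means the model's next generated token sits at the same batch position for every row, which keeps decoding simple and efficient.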
This is the heart of the training loop. Here's what happens at each step:
def fit(self):
    """
    The training loop of PPO.

    The driver process only needs to call the compute functions of the worker group
    through RPC to construct the PPO dataflow.
    The light-weight advantage computation is done on the driver process.
    """
    from omegaconf import OmegaConf

    from verl.utils.tracking import Tracking

    logger = Tracking(...)

    self.global_steps = 0

    # load checkpoint before doing anything
    self._load_checkpoint()

    # perform validation before training
    if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
        ...

    # add tqdm
    progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress")
Each batch is a DataProto — verl's enhanced dictionary type that carries metadata alongside tensors:

# pop those keys for generation
batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"]
non_tensor_batch_keys_to_pop = ["raw_prompt_ids"]
if "multi_modal_data" in batch.non_tensor_batch:
    non_tensor_batch_keys_to_pop.append("multi_modal_data")
if "raw_prompt" in batch.non_tensor_batch:
    non_tensor_batch_keys_to_pop.append("raw_prompt")
if "tools_kwargs" in batch.non_tensor_batch:
    non_tensor_batch_keys_to_pop.append("tools_kwargs")
if "interaction_kwargs" in batch.non_tensor_batch:
    non_tensor_batch_keys_to_pop.append("interaction_kwargs")

gen_batch = batch.pop(
    batch_keys=batch_keys_to_pop,
    non_tensor_batch_keys=non_tensor_batch_keys_to_pop,
)
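To build intuition for what `pop` is doing here, consider a stripped-down, hypothetical mini-version of DataProto (the real class wraps a TensorDict and numpy arrays; this sketch uses plain dicts):

```python
# Hypothetical mini-DataProto: tensor fields, non-tensor fields, and meta info
# in one object, with a pop() that splits out only the keys generation needs.
class MiniDataProto:
    def __init__(self, batch=None, non_tensor_batch=None, meta_info=None):
        self.batch = batch or {}                    # tensor-like fields
        self.non_tensor_batch = non_tensor_batch or {}  # arbitrary Python objects
        self.meta_info = meta_info or {}

    def pop(self, batch_keys=(), non_tensor_batch_keys=()):
        popped = MiniDataProto()
        for key in batch_keys:
            popped.batch[key] = self.batch.pop(key)
        for key in non_tensor_batch_keys:
            popped.non_tensor_batch[key] = self.non_tensor_batch.pop(key)
        return popped

proto = MiniDataProto(
    batch={"input_ids": [[1, 2]], "attention_mask": [[1, 1]], "rewards": [[0.0]]},
    non_tensor_batch={"raw_prompt_ids": [[1, 2]], "uid": ["a"]},
)
gen_batch = proto.pop(batch_keys=["input_ids", "attention_mask"],
                      non_tensor_batch_keys=["raw_prompt_ids"])
print(sorted(gen_batch.batch))  # ['attention_mask', 'input_ids']
print(sorted(proto.batch))      # ['rewards'] — training-only fields stay behind
```

Popping (rather than copying) keeps the generation payload small: only what the rollout engine needs crosses the network, and the rest stays on the driver until the results are unioned back in.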
with marked_timer("step", timing_raw):
    # generate a batch
    with marked_timer("gen", timing_raw, color="red"):
        if not self.async_rollout_mode:
            gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
        else:
            self.async_rollout_manager.wake_up()
            gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch)
            self.async_rollout_manager.sleep()
        timing_raw.update(gen_batch_output.meta_info.pop("timing", None))

    if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
        with marked_timer("gen_max", timing_raw, color="purple"):
            gen_baseline_batch = deepcopy(gen_batch)
            ...
REMAX baseline (optional): For REMAX algorithm, generate additional baseline responses with greedy decoding.
UID assignment: Assign unique IDs to each prompt for advantage calculation. This ensures correct grouping when computing baselines.
batch.non_tensor_batch["uid"] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object)
# repeat to align with repeated responses in rollout
batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
batch = batch.union(gen_batch_output)
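The interleaved repeat is what keeps each prompt's n samples contiguous, so that uid-based grouping later is trivial. A tiny pure-Python sketch of the semantics (not verl code):

```python
# repeat(..., interleave=True) semantics: each element is repeated n times
# in place, so all samples for one prompt (one uid) sit next to each other.
def repeat_interleave(items, n):
    return [item for item in items for _ in range(n)]

uids = ["uid-a", "uid-b"]
print(repeat_interleave(uids, 3))
# ['uid-a', 'uid-a', 'uid-a', 'uid-b', 'uid-b', 'uid-b']
```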
For algorithms like GRPO, the advantage is computed per group — the UID ensures prompts are correctly grouped:
def compute_grpo_outcome_advantage(
    ...
):
    """
    Compute advantage for GRPO, operating only on outcome (response-level) rewards.

    norm_adv_by_std_in_grpo:
        If True, the advantage is scaled by the std, as in the original GRPO.
        If False, the advantage is not scaled, as in Dr.GRPO (https://arxiv.org/abs/2503.20783).

    Returns:
        advantages: `torch.Tensor`
            shape is (bs, response_length)
        scores: `torch.Tensor`
            shape is (bs, response_length)
    """
    scores = token_level_rewards.sum(dim=-1)

    id2score = defaultdict(list)
    id2mean = {}
    id2std = {}

    with torch.no_grad():
        bsz = scores.shape[0]
        for i in range(bsz):
            id2score[index[i]].append(scores[i])
        for idx in id2score:
            ...
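Stripped of tensors, the grouping logic is simple. Here's an illustrative pure-Python sketch (the function name and details are mine, not verl's; the real implementation works on torch tensors):

```python
# GRPO outcome advantage, sketched in pure Python: normalize each sample's
# score against the mean (and optionally std) of its uid group.
from collections import defaultdict
from statistics import mean, stdev

def grpo_outcome_advantage(scores, uids, eps=1e-6, norm_by_std=True):
    groups = defaultdict(list)
    for score, uid in zip(scores, uids):
        groups[uid].append(score)
    advantages = []
    for score, uid in zip(scores, uids):
        group = groups[uid]
        adv = score - mean(group)                  # group mean is the baseline
        if norm_by_std and len(group) > 1:
            adv /= stdev(group) + eps              # original GRPO; Dr.GRPO skips this
        advantages.append(adv)
    return advantages

scores = [1.0, 0.0, 1.0, 1.0, 1.0, 1.0]   # rewards for 2 prompts x 3 samples each
uids   = ["a", "a", "a", "b", "b", "b"]
print([round(a, 3) for a in grpo_outcome_advantage(scores, uids)])
```

Note what happens to group "b": all three samples got the same reward, so every advantage in that group is zero. With no within-group signal, those samples contribute no gradient, which is why reward variety within a group matters for GRPO.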
Reward computation: Send generated sequences to Reward Model workers. Get back scalar reward scores.
Old log probabilities: Compute log probabilities under the current policy (before any updates). These become the "old" log probs for PPO's importance ratio.
with marked_timer("old_log_prob", timing_raw, color="blue"):
    old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
    batch = batch.union(old_log_prob)

# Extract the old log probs for later PPO ratio computation
if "old_log_prob" in batch.batch.keys():
    ...
    batch.batch["old_log_prob"] = batch.batch["old_log_prob"].detach()
Rollout engine inference: Use vLLM or SGLang to compute current policy log probabilities efficiently.
Value estimation: Critic network estimates the value of each state for advantage computation.
Advantage calculation: Compute token-level advantage scores using the rewards and value estimates.
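For PPO-style estimators, that advantage step combines rewards and critic values via the GAE recurrence. Here's an illustrative pure-Python sketch of GAE (verl's actual estimator is tensorized and selected by `algorithm.adv_estimator`; this just shows the math):

```python
# GAE sketch: delta_t = r_t + gamma * V_{t+1} - V_t,
#             A_t = delta_t + gamma * lam * A_{t+1}, computed right to left.
def gae(rewards, values, gamma=1.0, lam=1.0, last_value=0.0):
    advantages = [0.0] * len(rewards)
    next_adv, next_value = 0.0, last_value
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * next_value - values[t]
        next_adv = delta + gamma * lam * next_adv
        advantages[t] = next_adv
        next_value = values[t]
    return advantages

# Sparse outcome reward at the final token; per-token critic values.
advs = gae(rewards=[0.0, 0.0, 1.0], values=[0.2, 0.5, 0.8])
print([round(a, 3) for a in advs])  # [0.8, 0.5, 0.2]
```

With gamma = lam = 1 (common for LLM RL), the sparse final reward propagates back through every token of the response, discounted only by how well the critic already predicted it.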
if self.use_reference_policy:
    # compute reference log_prob
    with marked_timer("ref", timing_raw, color="olive"):
        if not self.ref_in_actor:
            ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
        else:
            ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch)
        batch = batch.union(ref_log_prob)

# compute values
if self.use_critic:
    with marked_timer("values", timing_raw, color="cyan"):
        values = self.critic_wg.compute_values(batch)
        batch = batch.union(values)

with marked_timer("adv", timing_raw, color="brown"):
    ...

# implement critic warmup
if self.config.trainer.critic_warmup <= self.global_steps:
    # update actor
    with marked_timer("update_actor", timing_raw, color="red"):
        batch.meta_info["multi_turn"] = self.config.actor_rollout_ref.rollout.multi_turn.enable
        actor_output = self.actor_rollout_wg.update_actor(batch)
        actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
        metrics.update(actor_output_metrics)

# Log rollout generations if enabled
rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
if rollout_data_dir:
    with marked_timer("dump_rollout_generations", timing_raw, color="green"):
        print(batch.batch.keys())
        inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True)
        ...
The inner update loop:
The actor uses old_log_prob (computed once, under the policy that generated the rollout) and log_prob (recomputed on every mini-batch as the policy updates) to compute the PPO loss. The ratio current / old is what PPO clips to prevent too-large policy updates:

if self.config.policy_loss.loss_mode == "vanilla":
    pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss(
        old_log_prob=old_log_prob,
        log_prob=log_prob,
        advantages=advantages,
        response_mask=response_mask,
        cliprange=clip_ratio,
        cliprange_low=clip_ratio_low,
        cliprange_high=clip_ratio_high,
        clip_ratio_c=clip_ratio_c,
        loss_agg_mode=loss_agg_mode,
    )
This separation — old log probs computed once, current log probs updated per batch — is fundamental to how PPO maintains stability during training.
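To make the clipping concrete, here's the clipped surrogate for a single token in plain Python (an illustrative sketch, not verl's `compute_policy_loss`, which operates on masked tensors and also supports dual-clip via `clip_ratio_c`):

```python
# PPO clipped surrogate for one token: ratio = exp(log_prob - old_log_prob),
# and the objective takes the pessimistic (min) of clipped vs. unclipped terms.
import math

def ppo_token_loss(old_log_prob, log_prob, advantage, clip_ratio=0.2):
    ratio = math.exp(log_prob - old_log_prob)  # pi_new(a|s) / pi_old(a|s)
    unclipped = ratio * advantage
    clipped = max(min(ratio, 1 + clip_ratio), 1 - clip_ratio) * advantage
    # we maximize the surrogate, so the loss is its negation
    return -min(unclipped, clipped)

# A large ratio with positive advantage is clipped at 1 + clip_ratio = 1.2,
# capping how much credit a single update can take for that token:
print(round(ppo_token_loss(old_log_prob=-2.0, log_prob=-1.0, advantage=1.0), 3))  # -1.2
```

Once the ratio leaves the [1 - clip, 1 + clip] band in the favorable direction, the gradient through that token vanishes, which is exactly the stability mechanism the text above describes.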
verl's architecture makes more sense once you understand the flow: main_ppo.py parses the config and launches a TaskRunner on the Ray cluster; the TaskRunner builds the tokenizer, datasets, worker classes, and resource pools; init_workers() spawns the worker groups; and fit() runs the actual RL steps: generate, score, compute advantage, update.

The master process stays lightweight. All the expensive operations — model inference, gradient computation, reward scoring — happen on remote workers managed by Ray.
In Part 2, we'll go hands-on: implementing the latest RL algorithms and getting them merged into verl. The goal is to move you from understanding the architecture to actively contributing to it.
If you're interested in LLM pretraining, post-training, and ML interviews — this is the kind of content I write regularly.
Originally published on ML Research Engineer Substack.