# RL Dataloader

API reference for ReplayBuffer and MultiTurnDataloader.
The MultiTurnDataloader handles parallel environment rollout, tokenization, and batching for RL training.
## ReplayBuffer
Stores experience tuples with episode-aware reward discounting.
```python
from cua_bench.workers import ReplayBuffer

buffer = ReplayBuffer(
    capacity=10000,           # Max experiences to store
    gamma=0.9,                # Discount factor
    only_keep_outcome=False,  # Keep all steps or just final
    balance_thres=0.5,        # Threshold for balance stats
)

# Add experiences (as tuple)
buffer.add((
    0,                                             # worker_id
    {"obs": "...", "reward": 0.0, "done": False},  # env_ret
    {"uid": "episode-123"},                        # meta_info
))

# Sample for training
samples = buffer.sample(batch_size=32)

# Get balance statistics
below, above = buffer.get_balance_stats()
```

## Reward Discounting
When an episode completes, rewards propagate backwards:
```
Episode: [step0, step1, step2 (done, reward=1.0)]

With gamma=0.9:
  step2.reward = 1.0
  step1.reward = 0.9
  step0.reward = 0.81
```
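The propagation is a standard discounted return computed from the end of the episode. The sketch below is illustrative only (not the ReplayBuffer's internal code) and reproduces the numbers above for an episode whose only reward is the final outcome.

```python
# Illustration: propagate rewards backwards through one episode with discount gamma.
def discount_episode(step_rewards, gamma=0.9):
    """step_rewards: per-step rewards in order; the outcome reward is the last entry."""
    discounted = list(step_rewards)
    running = 0.0
    for t in reversed(range(len(discounted))):
        running = discounted[t] + gamma * running
        discounted[t] = running
    return discounted

print(discount_episode([0.0, 0.0, 1.0]))  # [0.81, 0.9, 1.0] (up to float rounding)
```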
## MultiTurnDataloader

Manages parallel environments and provides tokenized batches.
```python
from cua_bench.workers import (
    MultiTurnDataloader,
    CBEnvWorkerClient,
    create_workers,
)

workers = await create_workers(n_workers=4, allowed_ips=["127.0.0.1"])

task_configs = [{"env_path": "./task", "task_index": 0, "split": "train"}]

env_configs = [{
    "server_url": w.api_url,
    "task_configs": task_configs,
    "max_step": 50,
    "max_hist": 10,
    "timeout": 300,
} for w in workers]

dataloader = MultiTurnDataloader(
    env_class=CBEnvWorkerClient,
    env_configs=env_configs,
    tokenizer=tokenizer,
    batch_size=4,
    replay_capacity=10000,
    replay_reward_discount=0.9,
    max_prompt_length=1024,
    max_response_length=256,
)

for batch in dataloader:
    # batch: input_ids, attention_mask, position_ids, worker_id, meta_info
    responses = model.generate(batch['input_ids'])
    dataloader.async_step({
        'prompts': batch['input_ids'],
        'responses': responses,
        'attention_mask': combined_mask,  # covers prompt + response tokens; see sketch below
        'worker_id': batch['worker_id'],
        'meta_info': batch['meta_info'],
    })
```
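The loop above passes a combined_mask covering both prompt and response tokens (matching the batch_return format below). How it is built depends on your generation setup; here is a minimal sketch, assuming responses are returned separately and right-padded with the tokenizer's pad token.

```python
import torch

def build_combined_mask(prompt_mask: torch.Tensor,
                        responses: torch.Tensor,
                        pad_token_id: int) -> torch.Tensor:
    """Sketch: attention mask over prompt tokens plus non-padding response tokens."""
    response_mask = (responses != pad_token_id).long()
    return torch.cat([prompt_mask, response_mask], dim=-1)  # (batch, total_len)

# Inside the loop:
# combined_mask = build_combined_mask(batch['attention_mask'], responses, tokenizer.pad_token_id)
```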
## Constructor Parameters

| Parameter | Type | Default | Description |
|---|---|---|---|
| env_class | Type | required | Environment client class |
| env_configs | List[Dict] | required | Worker configs |
| task_configs | List[Dict] | required | Task configurations |
| tokenizer | Any | required | HuggingFace tokenizer |
| processor | Any | None | HuggingFace processor for multimodal |
| is_multi_modal | bool | False | Enable image processing |
| batch_size | int | 8 | Must be less than or equal to num_envs |
| replay_capacity | int | 10000 | Replay buffer size |
| replay_reward_discount | float | 0.9 | Gamma for discounting |
| max_prompt_length | int | 1024 | Max prompt tokens |
| max_response_length | int | 1024 | Max response tokens |
| only_keep_outcome_in_replay | bool | False | Only keep final steps |
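For multimodal agents, the processor and is_multi_modal parameters from the table above enable image observations. The snippet below is a hedged sketch, not a verified configuration: the checkpoint name is a placeholder, and env_configs is reused from the setup earlier on this page.

```python
from transformers import AutoProcessor, AutoTokenizer
from cua_bench.workers import MultiTurnDataloader, CBEnvWorkerClient

model_name = "<your-vision-language-model>"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
processor = AutoProcessor.from_pretrained(model_name)

dataloader = MultiTurnDataloader(
    env_class=CBEnvWorkerClient,
    env_configs=env_configs,
    tokenizer=tokenizer,
    processor=processor,   # HuggingFace processor for multimodal inputs
    is_multi_modal=True,   # enable image processing
    batch_size=4,
)
```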
## Batch Format
From next(dataloader):
```python
{
    'input_ids': torch.Tensor,       # (batch, seq_len)
    'attention_mask': torch.Tensor,  # (batch, seq_len)
    'position_ids': torch.Tensor,    # (batch, seq_len)
    'worker_id': np.ndarray,
    'meta_info': np.ndarray,
}
```

## batch_return Format
For async_step():
```python
{
    'prompts': torch.Tensor,         # (batch, prompt_len)
    'responses': torch.Tensor,       # (batch, response_len)
    'attention_mask': torch.Tensor,  # (batch, total_len)
    'worker_id': np.ndarray,
    'meta_info': np.ndarray,
}
```

## Action Parsing
The dataloader parses action strings from responses:
```
<|action_start|>click(0.5, 0.5)<|action_end|>
<|action_start|>type("hello")<|action_end|>
<|action_start|>done()<|action_end|>
```
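The actual parsing is handled inside the dataloader; the snippet below is only an illustrative regex for extracting action strings in this format from a model response.

```python
import re

# Illustration only: extract the text between the action delimiter tokens.
ACTION_RE = re.compile(r"<\|action_start\|>(.*?)<\|action_end\|>", re.DOTALL)

def extract_actions(response_text: str) -> list:
    return ACTION_RE.findall(response_text)

print(extract_actions("Clicking now. <|action_start|>click(0.5, 0.5)<|action_end|>"))
# ['click(0.5, 0.5)']
```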
## Methods

```python
# Get running outcome reward (EMA)
reward = dataloader.running_outcome_reward()

# Get balance stats
below, above = dataloader.get_balance_stats()

# Sample from the replay buffer
batch = dataloader.sample_from_buffer(batch_size=32)

# Print replay buffer stats
dataloader.print_stats_in_replay_buffer()

# Clear the replay buffer
dataloader.clear_replay_buffer()

# Clean up workers
dataloader.close()
```

## Example
See the Train an Agent with GRPO tutorial for a complete working example.