RLHF (Beta)

Reinforcement Learning from Human Feedback (RLHF) is a method for optimizing a language model using human preference feedback.

Overview

Axolotl supports several preference- and reward-based training methods, including DPO, IPO, ORPO, KTO, SimPO, and GRPO, each covered below.

RLHF using Axolotl

Important

This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality.

We rely on the TRL library for implementations of the various RL training methods, which we wrap to expose in axolotl. Each method has its own supported dataset-loading options and prompt formats.

Tip

You can find what each method supports by looking in src/axolotl/prompt_strategies/{method}, where {method} is one of our supported methods. The type: value can be derived from {method}.{function_name}.

DPO

Example config:

rl: dpo
datasets:
  - path: Intel/orca_dpo_pairs
    split: train
    type: chatml.intel
  - path: argilla/ultrafeedback-binarized-preferences
    split: train
    type: chatml

DPO supports the following types and dataset formats:

chatml.argilla

{
    "system": "...", // optional
    "instruction": "...",
    "chosen_response": "...",
    "rejected_response": "..."
}

chatml.argilla_chat

{
    "chosen": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ],
    "rejected": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ]
}

chatml.icr

{
    "system": "...", // optional
    "input": "...",
    "chosen": "...",
    "rejected": "..."
}

chatml.intel

{
    "system": "...", // optional
    "question": "...",
    "chosen": "...",
    "rejected": "..."
}

chatml.prompt_pairs

{
    "system": "...", // optional
    "prompt": "...",
    "chosen": "...",
    "rejected": "..."
}

chatml.ultra

{
    "system": "...", // optional
    "prompt": "...",
    "chosen": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ],
    "rejected": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ]
}

llama3.argilla

{
    "system": "...", // optional
    "instruction": "...",
    "chosen_response": "...",
    "rejected_response": "..."
}

llama3.argilla_chat

{
    "chosen": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ],
    "rejected": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ]
}

llama3.icr

{
    "system": "...", // optional
    "input": "...",
    "chosen": "...",
    "rejected": "..."
}

llama3.intel

{
    "system": "...", // optional
    "question": "...",
    "chosen": "...",
    "rejected": "..."
}

llama3.prompt_pairs

{
    "system": "...", // optional
    "prompt": "...",
    "chosen": "...",
    "rejected": "..."
}

llama3.ultra

{
    "system": "...", // optional
    "prompt": "...",
    "chosen": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ],
    "rejected": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ]
}

zephyr.nectar

{
    "prompt": "...",
    "answers": [
        {
            "answer": "...",
            "rank": 1
        },
        {
            "answer": "...",
            "rank": 2
        }
        // ... more answers with ranks
    ]
}
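For illustration, a ranked record like this can be collapsed to a single preference pair by taking the best- and worst-ranked answers. This is a sketch of the idea, not necessarily how the zephyr.nectar loader consumes ranks:

```python
def nectar_to_pair(example):
    """Reduce a ranked-answers record to one chosen/rejected pair."""
    ranked = sorted(example["answers"], key=lambda a: a["rank"])
    return {
        "prompt": example["prompt"],
        "chosen": ranked[0]["answer"],    # rank 1 = best answer
        "rejected": ranked[-1]["answer"], # highest rank number = worst
    }

example = {
    "prompt": "What is 2+2?",
    "answers": [
        {"answer": "5", "rank": 2},
        {"answer": "4", "rank": 1},
    ],
}
pair = nectar_to_pair(example)
```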

chat_template.argilla_chat

{
    "chosen": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ],
    "rejected": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ]
}

chat_template.default

rl: dpo
datasets:
  - path: ...
    split: train
    type: chat_template.default
    field_messages: "messages"
    field_chosen: "chosen"
    field_rejected: "rejected"
    message_property_mappings:
      role: role
      content: content
    roles:
      user: ["user"]
      assistant: ["assistant"]
      system: ["system"]

Sample input format:

{
    "messages": [
        {
            "role": "system",
            "content": "..."
        },
        {
            "role": "user",
            "content": "..."
        },
        // ... more messages
    ],
    "chosen": {
        "role": "assistant",
        "content": "..."
    },
    "rejected": {
        "role": "assistant",
        "content": "..."
    }
}

user_defined.default

For custom behaviors, use a config like the following:

rl: dpo
datasets:
  - path: ...
    split: train
    type:
      field_prompt: "prompt"
      field_system: "system"
      field_chosen: "chosen"
      field_rejected: "rejected"
      prompt_format: "{prompt}"
      chosen_format: "{chosen}"
      rejected_format: "{rejected}"

The input format is a simple JSON input with customizable fields based on the above config.

{
    "system": "...",  // optional
    "prompt": "...",
    "chosen": "...",
    "rejected": "..."
}
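The *_format strings are Python str.format templates filled from the named fields. A sketch of how one row is rendered under such a config (the render helper is illustrative, not axolotl's internal code):

```python
def render(example, prompt_format="{prompt}", chosen_format="{chosen}",
           rejected_format="{rejected}"):
    """Apply user-defined format strings to one dataset row."""
    return {
        "prompt": prompt_format.format(**example),
        "chosen": chosen_format.format(**example),
        "rejected": rejected_format.format(**example),
    }

row = {
    "system": "Be brief.",
    "prompt": "Capital of France?",
    "chosen": "Paris.",
    "rejected": "London.",
}
# e.g. wrap the prompt in instruction markers via prompt_format
rendered = render(row, prompt_format="[INST] {prompt} [/INST]")
```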

IPO

As IPO is just DPO with a different loss function, all dataset formats supported for DPO are also supported for IPO.

rl: ipo
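The two losses differ only in how the preference margin is penalized. A minimal sketch of both losses on a scalar log-ratio margin, following the DPO and IPO papers (illustrative, not axolotl internals):

```python
import math

def dpo_loss(margin, beta=0.1):
    # DPO: -log sigmoid(beta * margin), where
    # margin = (logp_chosen - ref_chosen) - (logp_rejected - ref_rejected)
    return -math.log(1 / (1 + math.exp(-beta * margin)))

def ipo_loss(margin, beta=0.1):
    # IPO: squared distance from the target margin 1/(2*beta),
    # which avoids DPO's tendency to push the margin toward infinity
    return (margin - 1 / (2 * beta)) ** 2
```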

ORPO

Paper: https://arxiv.org/abs/2403.07691

rl: orpo
orpo_alpha: 0.1
remove_unused_columns: false

chat_template: chatml
datasets:
  - path: argilla/ultrafeedback-binarized-preferences-cleaned
    type: chat_template.argilla

ORPO supports the following types and dataset formats:

chat_template.argilla

{
    "system": "...",  // optional
    "prompt": "...",  // if available, will be taken as user message for single-turn instead of from list below

    // chosen/rejected should be identical up to the final assistant message,
    // and contain an even number of alternating user/assistant turns
    "chosen": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ],
    "rejected": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ]
}

KTO

rl: kto
rl_beta: 0.1  # default
kto_desirable_weight: 1.0  # default
kto_undesirable_weight: 1.0  # default

remove_unused_columns: false

datasets:
  - path: argilla/ultrafeedback-binarized-preferences-cleaned-kto
    type: llama3.ultra
    split: train

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: true

KTO supports the following types and dataset formats:

chatml.argilla

{
    "system": "...", // optional
    "instruction": "...",
    "completion": "..."
}

chatml.argilla_chat

{
    "chosen": [
        {"role": "user", "content": "..."}
    ],
    "completion": [
        {"role": "assistant", "content": "..."}
    ]
}

chatml.intel

{
    "system": "...", // optional
    "question": "...",
    "completion": "..."
}

chatml.prompt_pairs

{
    "system": "...", // optional
    "prompt": "...",
    "completion": "..."
}

chatml.ultra

{
    "system": "...", // optional
    "prompt": "...",
    "completion": "..."
}

llama3.argilla

{
    "system": "...", // optional
    "instruction": "...",
    "completion": "..."
}

llama3.argilla_chat

{
    "completion": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ]
}

llama3.intel

{
    "system": "...", // optional
    "question": "...",
    "completion": "..."
}

llama3.prompt_pairs

{
    "system": "...", // optional
    "prompt": "...",
    "completion": "..."
}

llama3.ultra

{
    "system": "...", // optional
    "prompt": "...",
    "completion": "..."
}

user_defined.default

For custom behaviors, use a config like the following:

rl: kto
datasets:
  - path: ...
    split: train
    type:
      field_prompt: "prompt"
      field_system: "system"
      field_completion: "completion"
      field_label: "label"
      prompt_format: "{prompt}"
      completion_format: "{completion}"

The input format is a simple JSON input with customizable fields based on the above config.

{
    "system": "...",  // optional
    "prompt": "...",
    "completion": "...",
    "label": "..."
}
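Since KTO consumes unpaired examples with a binary desirability label, existing chosen/rejected preference data can be split into two KTO rows. A sketch (the boolean label convention here is an assumption; check your dataset's label field):

```python
def pair_to_kto(example):
    """Split one chosen/rejected pair into two labeled KTO rows."""
    base = {k: example[k] for k in ("system", "prompt") if k in example}
    return [
        {**base, "completion": example["chosen"], "label": True},    # desirable
        {**base, "completion": example["rejected"], "label": False}, # undesirable
    ]

rows = pair_to_kto({"prompt": "2+2?", "chosen": "4", "rejected": "5"})
```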

GRPO

Tip

Check out our GRPO cookbook.

In the latest GRPO implementation, vLLM is used to significantly speed up trajectory generation during training. In this example, we use 4 GPUs: 2 for training and 2 for vLLM.

Important

Make sure you’ve installed the correct version of vLLM by including it as an extra when installing axolotl, e.g. pip install axolotl[vllm].

base_model: Qwen/Qwen2.5-1.5B-Instruct

vllm:
    host: 0.0.0.0
    port: 8000
    tensor_parallel_size: 2
    gpu_memory_utilization: 0.85
    dtype: auto
    # max_model_len: # you may find it useful to set the vLLM model context length if you know this beforehand

rl: grpo
trl:
    use_vllm: true
    vllm_server_host: 0.0.0.0
    vllm_server_port: 8000
    vllm_server_timeout: 300

CUDA_VISIBLE_DEVICES=2,3 axolotl vllm-serve grpo.yaml

Your vLLM instance will now spin up. Once it is ready, kick off training on the remaining two GPUs from another terminal:

CUDA_VISIBLE_DEVICES=0,1 axolotl train grpo.yaml --num-processes 2

Note

Due to TRL’s implementation with vLLM, the vLLM instance must use the last N GPUs instead of the first N GPUs. This is why in the example above, we use CUDA_VISIBLE_DEVICES=2,3 for the vLLM instance.

Reward functions

GRPO relies on user-provided reward functions and dataset transformations, which must be available as Python files in your local directory.

For example, to load OpenAI’s GSM8K and use a random reward for completions:

# rewards.py
import random

def rand_reward_func(completions, **kwargs) -> list[float]:
    return [random.uniform(0, 1) for _ in completions]

def oai_gsm8k_transform(cfg, *args, **kwargs):
    def transform_fn(example, tokenizer=None):
        label = example["answer"].split("####")[-1].strip().replace(",", "")
        return {
            "prompt": [{"role": "user", "content": example["question"]},],
            "answer": label,
        }
    return transform_fn, {"remove_columns": ["question"]}
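The transform above derives the label by stripping GSM8K's #### answer marker. The extraction logic in isolation, applied to a sample answer string:

```python
def extract_label(answer: str) -> str:
    # Same extraction as in oai_gsm8k_transform above: take the text after
    # '####', strip whitespace, and drop thousands separators
    return answer.split("####")[-1].strip().replace(",", "")

label = extract_label("3 + 1200 = 1203\n#### 1,203")
```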

rl: grpo

trl:
    beta: 0.001
    max_completion_length: 256
    use_vllm: True
    num_generations: 4
    reward_funcs: ["rewards.rand_reward_func"]    # format: '{file_name}.{fn_name}'
    reward_weights: [1.0]
datasets:
  - path: openai/gsm8k
    name: main
    type: rewards.oai_gsm8k_transform  # format: '{file_name}.{fn_name}'

To see other examples of custom reward functions, please see TRL GRPO Docs.

To see all configs, please see TRLConfig.

OpenEnv Rollout Functions

GRPO supports custom rollout functions for OpenEnv-style environments, enabling interactive tasks like web browsing, code execution, or tool use. This allows you to implement custom generation logic that interacts with external environments.

For example, to implement a simple math-solving environment with step-by-step verification:

# math_env.py
import re

def math_solver_rollout(model, processing_class, prompts, generation_config=None):
    """
    Custom rollout function that generates step-by-step math solutions.

    Args:
        model: The language model
        processing_class: The tokenizer/processing_class
        prompts: List of prompt dicts (with 'messages' key for chat format)
        generation_config: Optional generation configuration

    Returns:
        List of completion strings
    """
    completions = []

    for prompt in prompts:
        # Apply chat template to prompt
        messages = prompt.get("messages", [])
        formatted_prompt = processing_class.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        # Generate step-by-step solution
        full_response = ""
        for step in range(5):  # Max 5 reasoning steps
            current_input = formatted_prompt + full_response + "\nNext step:"
            inputs = processing_class(current_input, return_tensors="pt").to(model.device)

            outputs = model.generate(
                **inputs,
                max_new_tokens=100,
                generation_config=generation_config,
            )
            step_text = processing_class.decode(
                outputs[0][inputs.input_ids.shape[1]:],
                skip_special_tokens=True
            )

            # Check if solution is complete
            if "FINAL ANSWER:" in step_text:
                full_response += step_text
                break
            full_response += step_text + "\n"

        completions.append(full_response)

    return completions

def math_reward(prompts, completions, answers, **kwargs):
    """Reward function that checks mathematical correctness"""
    rewards = []
    for completion, correct_answer in zip(completions, answers):
        # Extract predicted answer
        match = re.search(r"FINAL ANSWER:\s*(.+)", completion)
        predicted = match.group(1).strip() if match else ""

        # Compare with correct answer
        reward = 1.0 if predicted == str(correct_answer) else 0.0
        rewards.append(reward)

    return rewards

def math_transform(cfg, *args, **kwargs):
    """Transform dataset to GRPO format with answer field"""
    def transform_fn(example, processing_class=None):
        return {
            "prompt": [{"role": "user", "content": example["question"]}],
            "answer": str(example["answer"]),
        }
    return transform_fn, {"remove_columns": ["question"]}

rl: grpo

trl:
  beta: 0.001
  max_completion_length: 512
  num_generations: 4
  rollout_func: "math_env.math_solver_rollout"  # Custom rollout function
  reward_funcs: ["math_env.math_reward"]
  reward_weights: [1.0]

datasets:
  - path: openai/gsm8k
    name: main
    type: math_env.math_transform

The rollout_func parameter accepts a fully qualified name (e.g., module_name.function_name) that points to a callable function in your local directory. The function receives:

  • model: The language model
  • processing_class: The tokenizer/processing class
  • prompts: List of prompt dictionaries
  • generation_config (optional): Generation configuration

The function should return a list of completion strings.

For more OpenEnv examples, see TRL OpenEnv Documentation.

GRPO with DAPO/Dr. GRPO loss

The DAPO paper, and subsequently the Dr. GRPO paper, proposed alternative loss formulations for GRPO that remove the length bias penalizing longer responses.

trl:
  loss_type: dr_grpo
  # Normalizes loss based on max completion length (default: 256)
  max_completion_length:

For more information, see GRPO docs.

Async GRPO

Async GRPO overlaps vLLM generation with training by producing rollouts in a background thread. While the model trains on the current batch, the next batch is already being generated. This can significantly reduce wall-clock time per step.

trl:
  use_data_producer: true     # Enable data producer protocol
  use_vllm: true
  async_prefetch: true         # Generate rollouts in background thread
  prefetch_depth: 1            # Number of rollouts to prefetch
  vllm_sync_interval: 2        # Sync weights to vLLM every N steps

Note

Because the background thread generates completions with slightly stale model weights, async GRPO uses importance sampling correction to account for the distribution shift. This is controlled by vllm_importance_sampling_correction: true (default when async is enabled).

vLLM LoRA Sync

By default, weight sync to vLLM merges the LoRA adapter into the base model and broadcasts all parameters via NCCL. LoRA sync is a faster alternative that saves only the adapter weights to the filesystem and has vLLM load them natively using Punica kernels.

adapter: lora
lora_r: 32
lora_alpha: 64
lora_target_linear: true

trl:
  vllm_lora_sync: true         # Enable native LoRA sync

When vllm_lora_sync: true is set, axolotl automatically selects the LoRA-aware vLLM serve module. Start vLLM as usual:

CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml

Then start training on a separate GPU:

CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml

Tip

LoRA sync is especially beneficial with multi-GPU training (FSDP/DeepSpeed), where NCCL merge-sync can cause GPU contention with vLLM generation.

Streaming Partial Batch

Instead of scoring the entire batch at once, streaming mode scores one prompt group at a time. This enables finer-grained zero-advantage skipping and reduces peak memory usage during scoring.

trl:
  streaming_partial_batch: true

Importance Sampling Correction

When using async prefetch, completions are generated from a slightly older version of the model. Importance sampling (IS) correction adjusts the policy gradient to account for this distribution shift.

trl:
  vllm_importance_sampling_correction: true   # Enable IS correction
  importance_sampling_level: token             # 'token' or 'sequence'
  off_policy_mask_threshold: 0.5              # Mask sequences with IS ratio below this

  • importance_sampling_level: token applies per-token IS ratios (recommended with Liger kernel)
  • importance_sampling_level: sequence applies per-sequence IS ratios
  • off_policy_mask_threshold masks out sequences where the IS ratio indicates they are too far off-policy
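A simplified illustration of these options on per-token log-probabilities (not TRL's implementation):

```python
import math

def is_ratios(logp_current, logp_behavior):
    """Per-token importance ratios pi_current / pi_behavior ('token' level)."""
    return [math.exp(c - b) for c, b in zip(logp_current, logp_behavior)]

def sequence_keep(logp_current, logp_behavior, threshold=0.5):
    """Keep a sequence only if its geometric-mean IS ratio is at least the
    off-policy mask threshold ('sequence' level masking)."""
    n = len(logp_current)
    seq_ratio = math.exp((sum(logp_current) - sum(logp_behavior)) / n)
    return seq_ratio >= threshold

# Completion generated from slightly stale weights: behavior log-probs
# differ a little from the current policy's
ratios = is_ratios([-1.0, -2.0], [-1.0, -2.5])
keep = sequence_keep([-1.0, -2.0], [-1.0, -2.5])
```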

Replay Buffer

The replay buffer caches rollout groups that had learning signal (non-zero reward variance) and uses them to replace zero-signal groups in later batches.

trl:
  replay_buffer_size: 100       # Max cached groups (0 = disabled)
  replay_recompute_logps: true  # Recompute log-probs for replayed data (recommended)

Note

When replay_recompute_logps: true (default), old log-probabilities are recomputed using the current model weights. This fixes the IS mismatch that would otherwise occur when replaying stale data.
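The mechanism can be sketched as a bounded cache keyed on reward variance (illustrative, not the axolotl implementation):

```python
from collections import deque

def has_signal(rewards):
    # A group carries learning signal iff its rewards are not all identical
    return len(set(rewards)) > 1

class ReplayBuffer:
    def __init__(self, size=100):
        self.groups = deque(maxlen=size)  # bounded: oldest groups drop off

    def offer(self, group):
        """Cache a rollout group only if it has nonzero reward variance."""
        if has_signal(group["rewards"]):
            self.groups.append(group)

    def replace_if_dead(self, group):
        """Swap a zero-signal group for a cached one, when available."""
        if has_signal(group["rewards"]) or not self.groups:
            return group
        return self.groups.popleft()

buf = ReplayBuffer(size=2)
buf.offer({"prompt": "p1", "rewards": [0.0, 1.0]})  # kept: has variance
buf.offer({"prompt": "p2", "rewards": [0.0, 0.0]})  # dropped: no signal
out = buf.replace_if_dead({"prompt": "p3", "rewards": [0.0, 0.0]})
```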

Deferred Re-rolling

Failed prompts (where the model produces zero reward for all generations) are buffered and re-injected into later batches when the model may be better equipped to solve them.

trl:
  reroll_start_fraction: 0.5    # Start re-rolling after 50% of training
  reroll_max_groups: 1          # Max groups to replace per batch

Zero-Advantage Batch Skipping

When all advantages in a micro-batch are zero (no learning signal), the forward/backward pass is skipped entirely. This is enabled by default and logged as skipped_zero_adv_batches=1.

trl:
  skip_zero_advantage_batches: true   # default

Parallel Reward Workers

Reward functions that use signal.alarm() (e.g., math_verify) must run in the main thread. Parallel reward workers use subprocesses to work around this limitation while enabling concurrent reward computation.

trl:
  reward_num_workers: 4         # Number of subprocess workers (1 = no parallelism)
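The idea can be sketched with a process pool: each subprocess has its own main thread, so signal.alarm-based timeouts work inside the workers. This is an illustration, with score_one standing in for an expensive reward function:

```python
from concurrent.futures import ProcessPoolExecutor

def score_one(completion: str) -> float:
    # Stand-in for an expensive reward (e.g. math verification); it runs in
    # the subprocess's main thread, so signal.alarm would be usable here.
    return 1.0 if "42" in completion else 0.0

def parallel_rewards(completions, num_workers=4):
    if num_workers <= 1:
        return [score_one(c) for c in completions]  # no parallelism
    with ProcessPoolExecutor(max_workers=num_workers) as pool:
        # map preserves the input order of completions
        return list(pool.map(score_one, completions))

rewards = parallel_rewards(["the answer is 42", "dunno"], num_workers=2)
```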

Full Async GRPO Example

base_model: Qwen/Qwen2.5-1.5B-Instruct

vllm:
    host: 0.0.0.0
    port: 8000
    gpu_memory_utilization: 0.35
    dtype: auto

adapter: lora
lora_r: 32
lora_alpha: 64
lora_target_linear: true

rl: grpo
trl:
  use_data_producer: true
  use_vllm: true
  async_prefetch: true
  prefetch_depth: 1
  vllm_sync_interval: 2
  vllm_lora_sync: true
  streaming_partial_batch: true
  vllm_importance_sampling_correction: true
  off_policy_mask_threshold: 0.5
  importance_sampling_level: token
  num_generations: 8
  max_completion_length: 512
  reward_funcs:
    - rewards.accuracy_reward
  reroll_start_fraction: 0.5
  replay_buffer_size: 100
  reward_num_workers: 4
  skip_zero_advantage_batches: true

datasets:
  - path: AI-MO/NuminaMath-TIR
    type: rewards.prompt_transform
    split: train

gradient_accumulation_steps: 4
micro_batch_size: 2
max_steps: 500
learning_rate: 1e-5
bf16: true
gradient_checkpointing: true

# Terminal 1: Start vLLM on GPU 0
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml

# Terminal 2: Train on GPU 1
CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml

Multi-GPU Async GRPO

Async GRPO supports FSDP and DeepSpeed ZeRO-3 for multi-GPU training. vLLM runs on one GPU while training is distributed across the remaining GPUs.

FSDP:

fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
gradient_checkpointing_kwargs:
  use_reentrant: false

DeepSpeed ZeRO-3:

deepspeed: deepspeed_configs/zero3_bf16.json
gradient_checkpointing_kwargs:
  use_reentrant: true   # Required for ZeRO-3

# Terminal 1: Start vLLM on GPU 0
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml

# Terminal 2: Train on GPUs 0,1
CUDA_VISIBLE_DEVICES=0,1 accelerate launch --num_processes 2 -m axolotl.cli.train config.yaml

Important

With multi-GPU async prefetch, only rank 0 generates completions in the background thread. Results are broadcast to all ranks on the main thread. This avoids FSDP/DeepSpeed collective deadlocks from unsynchronized background threads.

GDPO

GDPO (Group Reward-Decoupled Policy Optimization) extends GRPO for multi-reward training. It addresses the reward advantage collapse problem by normalizing each reward function independently before combining them.

Tip

Use GDPO when training with multiple reward functions. For single reward, GRPO and GDPO produce equivalent results.

Paper: https://arxiv.org/pdf/2501.05242

GDPO uses TRL’s native multi_objective_aggregation parameter under the hood. When you set rl: gdpo, axolotl automatically configures TRL to use normalize_then_sum aggregation.

base_model: Qwen/Qwen2.5-1.5B-Instruct

vllm:
    host: 0.0.0.0
    port: 8000
    tensor_parallel_size: 2
    gpu_memory_utilization: 0.85

rl: gdpo

trl:
    beta: 0.001
    max_completion_length: 256
    use_vllm: true
    num_generations: 4
    reward_funcs:
        - rewards.format_reward
        - rewards.correctness_reward
    reward_weights: [1.0, 2.0]

datasets:
    - path: openai/gsm8k
      name: main
      type: rewards.oai_gsm8k_transform

You can also use GRPO with explicit aggregation control:

rl: grpo
trl:
    multi_objective_aggregation: normalize_then_sum  # GDPO behavior
    # or: sum_then_normalize  # Default GRPO behavior

GDPO vs GRPO

Aspect          GRPO                     GDPO
Aggregation     sum_then_normalize       normalize_then_sum
Multi-reward    May collapse advantages  Preserves reward signals
Single reward   Standard behavior        Equivalent to GRPO

Why GDPO?

When using multiple rewards with GRPO, different reward combinations can produce identical advantages:

# Example: format + correctness rewards
[format=0, correct=3] → sum=3
[format=1, correct=2] → sum=3  ← GRPO sees these as equal!
[format=2, correct=1] → sum=3
[format=3, correct=0] → sum=3

GDPO normalizes each reward independently, preserving their relative differences.
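The difference can be shown numerically. A sketch of both aggregation orders with two rewards on different scales, using plain group mean/std normalization and equal weights (simplified; not TRL's exact computation):

```python
def normalize(xs):
    """Group-relative advantages: (x - mean) / std over one rollout group."""
    n = len(xs)
    mean = sum(xs) / n
    std = (sum((x - mean) ** 2 for x in xs) / n) ** 0.5
    if std == 0:
        return [0.0] * n  # constant reward: no learning signal
    return [(x - mean) / std for x in xs]

# 4 completions; format reward in {0,1}, correctness reward in {0,10}
format_r  = [1.0, 1.0, 0.0, 0.0]
correct_r = [0.0, 10.0, 0.0, 10.0]

# GRPO (sum_then_normalize): sum raw rewards first, then normalize the totals.
# The large-scale correctness reward dominates the format reward.
grpo_adv = normalize([f + c for f, c in zip(format_r, correct_r)])

# GDPO (normalize_then_sum): normalize each reward independently, then sum,
# so both rewards contribute on an equal scale.
f_norm, c_norm = normalize(format_r), normalize(correct_r)
gdpo_adv = [f + c for f, c in zip(f_norm, c_norm)]
```

Under GDPO, completions that differ only in format get advantages two units apart; under GRPO the same pair differs by only a fraction of the correctness scale.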

Reward Functions

GDPO uses the same reward function format as GRPO:

# rewards.py
def format_reward(completions, **kwargs) -> list[float]:
    return [1.0 if len(c) > 10 else 0.0 for c in completions]

def correctness_reward(completions, answers, **kwargs) -> list[float]:
    rewards = []
    for completion, answer in zip(completions, answers):
        # Your scoring logic here, e.g. exact-match on the reference answer
        score = 1.0 if str(answer) in completion else 0.0
        rewards.append(score)
    return rewards

Sequence Parallelism

GDPO supports sequence parallelism for long-context training:

rl: gdpo
context_parallel_size: 2

SimPO

SimPO uses TRL's CPOTrainer with an alternative loss function.

rl: simpo
rl_beta: 0.1  # default in CPOTrainer
cpo_alpha: 1.0  # default in CPOTrainer
simpo_gamma: 0.5  # default in CPOTrainer

This method uses the same dataset format as DPO.
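The SimPO objective drops the reference model and scores each response by its length-normalized average log-probability, with a target margin gamma. A minimal sketch following the SimPO paper (illustrative, not TRL's implementation):

```python
import math

def simpo_loss(logp_chosen, logp_rejected, beta=0.1, gamma=0.5):
    """SimPO loss on per-token log-prob lists; no reference model needed.

    Each response's implicit reward is beta times its average (length-
    normalized) log-probability; gamma is the target reward margin.
    """
    r_w = beta * sum(logp_chosen) / len(logp_chosen)
    r_l = beta * sum(logp_rejected) / len(logp_rejected)
    margin = r_w - r_l - gamma
    return -math.log(1 / (1 + math.exp(-margin)))  # -log sigmoid(margin)
```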

EBFT

EBFT (Energy-Based Fine-Tuning) fine-tunes language models by optimizing a feature-matching loss rather than relying on external reward functions. A frozen copy of the model extracts embeddings from both generated and ground-truth completions, and the generator is updated via REINFORCE to match the ground-truth feature moments.

Paper: “Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models” (Jelassi et al., 2026)

Key advantages:

  • No reward model or verifier required — works on any (prompt, completion) data
  • Applicable to non-verifiable tasks (code, translation, creative writing)
  • Operates on model rollouts (not teacher forcing), reducing distribution shift
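The objective can be sketched on plain feature vectors: reward rollouts whose features align with the ground-truth features (cosine similarity), and penalize rollouts whose features cluster together (pairwise dot products). This illustrates the alignment and diversity terms described above, not the actual EBFT feature-moment matching:

```python
import math

def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return dot / (nu * nv)

def ebft_reward(rollout_feats, gt_feat, alignment_coef=1.0, diversity_coef=1.0):
    """Per-rollout reward: alignment with the ground-truth feature minus the
    mean pairwise dot product with the other rollouts (diversity penalty)."""
    rewards = []
    for i, f in enumerate(rollout_feats):
        align = alignment_coef * cosine(f, gt_feat)
        others = [g for j, g in enumerate(rollout_feats) if j != i]
        div_pen = diversity_coef * sum(
            sum(a * b for a, b in zip(f, g)) for g in others
        ) / max(len(others), 1)
        rewards.append(align - div_pen)
    return rewards

# Two orthogonal rollout features; the first matches the ground truth
rewards = ebft_reward([[1.0, 0.0], [0.0, 1.0]], gt_feat=[1.0, 0.0])
```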

EBFT supports two modes:

  • Structured mode: For QA/instruction data with prompt + completion pairs. Uses vLLM for generation (like GRPO).
  • Strided mode: For unstructured text without prompt/completion splits. Uses strided block-parallel generation with flex_attention — no vLLM needed.

Structured Mode

base_model: Qwen/Qwen3-4B

rl: ebft

ebft:
  feature_layers: [0.25, 0.5, 0.75]    # Extract features at 25%, 50%, 75% depth
  embed_method: last_token
  use_whitening: false
  alignment_coef: 1.0                    # Cosine similarity reward weight
  diversity_coef: 1.0                    # Pairwise dot product penalty
  ce_coef: 0.0                          # Cross-entropy on GT tokens (0 = off)

trl:
  num_generations: 4
  max_completion_length: 256
  temperature: 0.7
  use_vllm: true
  vllm_server_host: 0.0.0.0
  vllm_server_port: 8000
  vllm_lora_sync: true                   # LoRA adapter sync (recommended)
  vllm_sync_interval: 3
  use_data_producer: true
  async_prefetch: true                   # Set false for sync mode
  scale_rewards: true
  loss_type: grpo
  epsilon: 0.2

vllm:
  gpu_memory_utilization: 0.5
  max_model_len: 2048

datasets:
  - path: nvidia/OpenCodeInstruct
    type: ebft_opencode.transform
    split: train[:500]

adapter: lora
lora_r: 16
lora_alpha: 32
lora_target_linear: true

# Terminal 1: Start vLLM
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml

# Terminal 2: Train
CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml

Strided Mode

For unstructured text (raw code, prose). No vLLM needed — runs on a single GPU.

base_model: meta-llama/Llama-3.2-1B

rl: ebft

ebft:
  mode: strided
  stride: 8
  context_length: 8
  generate_max_len: 8
  n_samples_per_prompt: 4
  temperature: 0.6
  feature_layers: [0.25, 0.5, 0.75]
  embed_method: last_token
  use_whitening: true
  alignment_coef: 1.0
  diversity_coef: 1.0
  rl_coef: 1.0
  ce_coef: 0.03
  advantage_estimator: rloo

datasets:
  - path: nvidia/OpenCodeInstruct
    type: ebft_strided_structured.transform
    split: train[:1%]

flash_attention: false
flex_attention: true     # Strided mode uses flex_attention
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: true    # Required for flex_attention

CUDA_VISIBLE_DEVICES=0 axolotl train config.yaml

Tip

See examples/ebft/ for complete example configs covering Llama 1B/3B/8B and Qwen3 4B/8B models in both modes.

EBFT Configuration Reference

Parameter                   Default            Description
ebft.feature_layers         [0.25, 0.5, 0.75]  Layer depths for feature extraction (fractional)
ebft.embed_method           last_token         Feature pooling: last_token, mean_pooling, concat
ebft.use_whitening          false              SVD whitening of feature dimensions
ebft.alignment_coef         1.0                Cosine similarity reward weight
ebft.diversity_coef         1.0                Pairwise dot product penalty weight
ebft.ce_coef                0.0                Cross-entropy loss on ground-truth tokens
ebft.mode                   structured         structured (vLLM) or strided (no vLLM)
ebft.stride                                    Tokens between anchor points (strided mode)
ebft.context_length                            Context window per block (strided mode)
ebft.generate_max_len                          Tokens to generate per block (strided mode)
ebft.n_samples_per_prompt                      Rollouts per document (strided mode)
ebft.advantage_estimator    grpo               grpo or rloo (strided mode)

NeMo Gym Integration

NeMo Gym provides 50+ verified RL environments (math, coding, tool-use, reasoning) with deterministic reward signals. The axolotl integration supports both single-turn (call /verify after generation) and multi-turn (agent-based tool execution via /run).

Single-Turn (Simplest)

For environments that only need answer verification (math, coding challenges). No agent server needed — the reward function calls /verify directly on the resource server.

base_model: Qwen/Qwen2.5-0.5B-Instruct

rl: grpo
chat_template: tokenizer_default

trl:
  use_vllm: false                          # Colocate mode (single GPU)
  num_generations: 4
  max_completion_length: 128
  temperature: 0.9
  reward_funcs:
    - axolotl.integrations.nemo_gym.rewards.reward_nemo_gym_verify

plugins:
  - axolotl.integrations.nemo_gym.NemoGymPlugin

nemo_gym_enabled: true
nemo_gym_dir: ~/Gym
nemo_gym_auto_start: false
nemo_gym_head_port: 11000
nemo_gym_datasets:
  - path: resources_servers/reasoning_gym/data/train_basic_arithmetic.jsonl
    server_name: reasoning_gym

datasets:
  - path: ~/Gym/resources_servers/reasoning_gym/data/train_basic_arithmetic.jsonl
    type: chat_template
    field_messages: responses_create_params.input
    message_field_content: content
    message_field_role: role

# Terminal 1: Start NeMo Gym resource server
cd ~/Gym && .venv/bin/ng_run \
    "+config_paths=[resources_servers/reasoning_gym/configs/resources_only.yaml]" \
    "+skip_venv_if_present=true"

# Terminal 2: Train
CUDA_VISIBLE_DEVICES=0 axolotl train config.yaml

Note

nemo_gym_datasets.path is relative to nemo_gym_dir. Don’t use absolute paths or they will be double-joined.

NeMo Gym Prerequisites

# Clone and set up NeMo Gym
git clone https://github.com/NVIDIA-NeMo/Gym.git ~/Gym
cd ~/Gym
uv venv --python 3.12 && source .venv/bin/activate && uv sync

# Fix pycosat build (GCC 13+)
CFLAGS="" uv pip install pycosat --python .venv/bin/python --no-build-isolation

NeMo Gym Configuration Reference

Parameter                Type  Default   Description
nemo_gym_enabled         bool            Enable the NeMo Gym integration
nemo_gym_dir             str   ~/Gym     Path to NeMo Gym repo
nemo_gym_auto_start      bool  true      Auto-start resource servers
nemo_gym_head_port       int   11000     Head server port
nemo_gym_multi_turn      bool  false     Enable multi-turn via agent /run
nemo_gym_verify_timeout  int   30        Per-request timeout (seconds)
nemo_gym_datasets        list  required  Dataset configs with path and server_name

Reward Functions

Function Mode Description
axolotl.integrations.nemo_gym.rewards.reward_nemo_gym_verify Single-turn Calls /verify, returns binary reward
axolotl.integrations.nemo_gym.rewards.reward_env Multi-turn Passthrough reward from agent /run

Using local dataset files

datasets:
  - ds_type: json
    data_files:
      - orca_rlhf.jsonl
    split: train
    type: chatml.intel

TRL auto-unwrapping for PEFT

TRL supports auto-unwrapping PEFT models for RL training methods that rely on a reference model. This significantly reduces memory pressure, as a separate reference model does not need to be loaded; reference-model log-probabilities are obtained by temporarily disabling the PEFT adapters. This is enabled by default. To turn it off, pass the following config:

# Load a separate reference model when training with an adapter.
rl_adapter_ref_model: true