"""
Random Search (PRS) optimizer.
Andriushchenko et al., "Jailbreaking Leading Safety-Aligned LLMs with Simple
Adaptive Attacks" (2024). https://arxiv.org/abs/2404.02151
Zeroth-order (gradient-free) token optimizer that mutates contiguous blocks
of tokens. A coarse-to-fine schedule shrinks the block size over time
(exploration -> exploitation), and a patience mechanism triggers random
restarts when the search stalls.
"""
from __future__ import annotations
import logging
from typing import Literal, Optional
import torch
from jaxtyping import Int
from torch import Tensor
from tropt.common import DEFAULT_INIT_TRIGGER, Targets, TextTemplates
from tropt.loss import BaseLoss
from tropt.model import BaseModel, LossTextAccessMixin
from tropt.model.model_base import BaseTokenizer
from tropt.optimizer import BaseOptimizer, OptimizerResult
from tropt.optimizer.utils.running_best import RunningBest
from tropt.optimizer.utils.token_constraints import TokenConstraints
from tropt.optimizer.utils.token_initializers import get_printable_random_trigger
# Module-level logger; used below to report PRS restarts.
logger = logging.getLogger(__name__)
class RandomSearchOptimizer(BaseOptimizer):
    """RandomSearch: batched zeroth-order token optimization with block mutation.

    Per step:

    1. Compute the block size from the coarse-to-fine schedule.
    2. For each candidate, pick a random start position and replace a
       contiguous block with random tokens from the allowed set.
    3. Decode candidates to strings and evaluate them via
       ``compute_loss_from_texts``.
    4. Keep the best candidate if it improves on the current loss.
    5. If there is no improvement for ``patience`` steps, restart from a
       random init.

    Implementation Notes:
        - Candidate evaluation is always text-based (``compute_loss_from_texts``);
          even for HF models, we decode to strings and re-encode for model input.
        - A tokenizer is needed for the optimizer's token-level mutations; it
          should either be provided, or we fall back to the model's tokenizer
          if it has one.
        - The original implementation employs a "warm" initial trigger (e.g.,
          another GCG suffix) and uses it as the starting point for all
          restarts. Here, we sample random triggers for all restarts for
          diversity.
        - The original implementation employs an LLM judge for early stopping;
          here we use a simple patience counter for restarts.
        - The original implementation mostly uses a loss-based scheduler. For
          generality (e.g., different potential loss values) we avoid using it.

    Reference implementation:
        - The original implementation:
          https://github.com/tml-epfl/llm-adaptive-attacks/blob/main/main.py
        - Another (more simplified) implementation:
          https://github.com/romovpa/claudini/blob/main/claudini/methods/original/prs/optimizer.py
    """

    model_requirements = (LossTextAccessMixin,)

    # Accepted values for the string-valued constructor options.
    _MUTATION_MODES = ("block_random", "single_cyclic")
    _SCHEDULES = ("fixed", "none")

    def __init__(
        self,
        model: BaseModel,
        loss: BaseLoss,
        tracker=None,
        seed: Optional[int] = None,
        # optimization parameters:
        num_steps: int = 500,
        n_candidates: int = 128,
        mutation_mode: Literal["block_random", "single_cyclic"] = "block_random",
        # Block parameters
        schedule: Literal["fixed", "none"] = "fixed",
        initial_block_len: int = 4,
        # misc:
        patience: int = 25,
        token_constraints: Optional[TokenConstraints] = None,
        # external tokenizer (required when model has no tokenizer):
        tokenizer: Optional[BaseTokenizer] = None,
    ):
        """
        Args:
            model: Model to optimize against; forwarded to ``BaseOptimizer``.
            loss: Loss forwarded to ``BaseOptimizer``.
            tracker: Optional tracker forwarded to ``BaseOptimizer``.
            seed: Optional seed forwarded to ``BaseOptimizer``.
            num_steps: Total optimization steps.
            n_candidates: Number of mutated candidates per step (must be >= 1).
            mutation_mode:
                ``"block_random"`` for random contiguous block mutation
                (original PRS),
                ``"single_cyclic"`` for single-token mutations spread across
                positions (candidate ``i`` mutates position ``i % trigger_len``).
            schedule: Schedule for block size decay. Relevant for block
                mutation mode(s).
                ``"fixed"`` for step-based coarse-to-fine decay,
                ``"none"`` to keep block size constant.
            initial_block_len: Initial block size for mutation (must be >= 1).
                Relevant for block mutation mode(s).
            patience: Restart from random init after this many steps without
                improvement. Set to 0 to disable restarts.
            token_constraints: Constraints whose whitelist
                (``get_whitelist_ids``) defines the tokens sampled during
                mutation. Defaults to a fresh ``TokenConstraints()`` per
                optimizer instance. Random restarts are drawn with
                ``get_printable_random_trigger``, which is not given these
                constraints.
            tokenizer: Tokenizer for encoding/decoding trigger tokens.
                If None, uses the model's tokenizer if it has one, else raises
                an error.

        Raises:
            ValueError: If ``mutation_mode`` or ``schedule`` is unsupported, if
                ``n_candidates`` or ``initial_block_len`` is below 1, or if no
                tokenizer is available.
        """
        # Validate with exceptions (not ``assert``) so checks survive ``python -O``,
        # and do it before any base-class side effects.
        if schedule not in self._SCHEDULES:
            raise ValueError(
                f"Unsupported schedule type: {schedule!r} "
                f"(expected one of {self._SCHEDULES})"
            )
        if mutation_mode not in self._MUTATION_MODES:
            raise ValueError(
                f"Unsupported mutation_mode: {mutation_mode!r} "
                f"(expected one of {self._MUTATION_MODES})"
            )
        if n_candidates < 1:
            raise ValueError(f"n_candidates must be >= 1, got {n_candidates}")
        if initial_block_len < 1:
            raise ValueError(
                f"initial_block_len must be >= 1, got {initial_block_len}"
            )

        # Resolve tokenizer: prefer explicit, fall back to the model's (if set).
        if tokenizer is None:
            tokenizer = getattr(model, "tokenizer", None)
        if tokenizer is None:
            raise ValueError(
                "No tokenizer available. Pass an external tokenizer when the "
                "model does not expose one."
            )

        super().__init__(model, loss=loss, tracker=tracker, seed=seed)
        self.num_steps = num_steps
        self.n_candidates = n_candidates
        self.mutation_mode = mutation_mode
        # Block parameters:
        self.initial_block_len = initial_block_len
        self.schedule = schedule
        self.patience = patience
        # Fresh instance per optimizer: avoids sharing one default object.
        self.token_constraints = (
            token_constraints if token_constraints is not None else TokenConstraints()
        )
        self.tokenizer = tokenizer

    # ------------------------------------------------------------------ #
    # Main entry point
    # ------------------------------------------------------------------ #
    def optimize_trigger(
        self,
        templates: TextTemplates,
        initial_trigger: Optional[str] = DEFAULT_INIT_TRIGGER,
        targets: Optional[Targets] = None,
    ) -> OptimizerResult:
        """Run PRS for ``num_steps`` steps and return the best trigger seen.

        Args:
            templates: Text templates passed to ``model.set_inputs_from_texts``.
            initial_trigger: Starting trigger string, encoded with the
                optimizer's tokenizer.
            targets: Optional targets passed to ``model.set_inputs_from_texts``.

        Returns:
            ``RunningBest.to_result()`` over all per-step current triggers,
            including the initial one.

        Raises:
            ValueError: If the initial trigger encodes to zero tokens, or the
                token constraints leave no token ids to sample from.
        """
        # --- Setup ---
        self.model.set_inputs_from_texts(templates=templates, targets=targets)
        tokenizer = self.tokenizer
        device = self.model.device

        trigger_ids: Int[Tensor, "trigger_seq_len"] = tokenizer.encode_trigger(
            initial_trigger
        ).to(device)
        trigger_len = trigger_ids.shape[0]
        if trigger_len == 0:
            raise ValueError("The initial trigger must encode to at least one token.")

        valid_token_ids = self.token_constraints.get_whitelist_ids(
            tokenizer, tokenizer.vocab_size, device, return_tensor=True
        )
        if len(valid_token_ids) == 0:
            raise ValueError("Token constraints leave no valid token ids to sample.")

        best = RunningBest()

        # Initial loss
        trigger_str = tokenizer.decode_trigger(trigger_ids)
        current_loss = self._text_loss(trigger_str)
        self.log(loss=current_loss, trigger_str=trigger_str)
        best.update(loss=current_loss, trigger_ids=trigger_ids, trigger_str=trigger_str)

        # Restart bookkeeping
        steps_without_improvement = 0  # consecutive steps without improvement
        step_of_curr_restart = 0  # the step of the most recent restart
        restart_count = 0  # number of restarts (for logging purposes)

        # --- Optimization loop ---
        for step_i in self.track_steps(range(self.num_steps)):
            # Check patience — restart from a fresh random trigger if stuck.
            if self.patience > 0 and steps_without_improvement >= self.patience:
                restart_count += 1
                trigger_ids = get_printable_random_trigger(
                    trigger_len, return_ids=True, tokenizer=tokenizer
                ).to(device)
                trigger_str = tokenizer.decode_trigger(trigger_ids)
                current_loss = self._text_loss(trigger_str)
                steps_without_improvement = 0
                step_of_curr_restart = step_i
                logger.info(
                    "PRS restart #%d at step %d (loss=%.4f)",
                    restart_count,
                    step_i,
                    current_loss,
                )

            # --- Sample candidates: n_candidates copies of the current trigger ---
            candidates = trigger_ids.unsqueeze(0).expand(self.n_candidates, -1).clone()
            # The block schedule restarts at 0 after every random restart.
            block_len = self._mutate(
                candidates, valid_token_ids, step_i - step_of_curr_restart
            )

            # --- Evaluate candidates (as texts) ---
            candidate_strs = tokenizer.decode_triggers(candidates)
            losses = self.model.compute_loss_from_texts(
                candidate_strs, loss_func=self.loss_func
            )

            # If improved (strictly), update current trigger; else count a stall.
            best_idx = int(losses.argmin().item())
            candidate_loss = losses[best_idx].item()
            if candidate_loss < current_loss:
                # Clone so the kept trigger does not share storage with the batch.
                trigger_ids = candidates[best_idx].clone()
                current_loss = candidate_loss
                steps_without_improvement = 0
            else:
                steps_without_improvement += 1

            # logging stuff:
            trigger_str = tokenizer.decode_trigger(trigger_ids)
            self.log(
                loss=current_loss,
                trigger_str=trigger_str,
                n_token_flip=block_len,
                restarts=restart_count,
                patience_counter=steps_without_improvement,
            )
            best.update(
                loss=current_loss, trigger_ids=trigger_ids, trigger_str=trigger_str
            )

        # --- Finalize ---
        return best.to_result()

    def _text_loss(self, text: str) -> float:
        """Evaluate the loss of a single trigger string and return it as a float."""
        return self.model.compute_loss_from_texts(
            [text], loss_func=self.loss_func
        ).item()

    # ------------------------------------------------------------------ #
    # Mutation strategies
    # ------------------------------------------------------------------ #
    def _mutate(
        self,
        candidates: Int[Tensor, "n_candidates trigger_len"],
        valid_token_ids: Int[Tensor, "n_valid"],
        local_step: int,
    ) -> int:
        """Mutate ``candidates`` in place per ``mutation_mode``; return block length."""
        if self.mutation_mode == "single_cyclic":
            return self._mutate_single_cyclic(candidates, valid_token_ids)
        # Use the current trigger length (taken from the batch itself) so the
        # block always fits, even after a restart.
        return self._mutate_block_random(
            candidates, valid_token_ids, local_step, candidates.shape[1]
        )

    def _mutate_single_cyclic(
        self,
        candidates: Int[Tensor, "n_candidates trigger_len"],
        valid_token_ids: Int[Tensor, "n_valid"],
    ) -> int:
        """Single-token mutation spread across positions round-robin.

        Candidate ``i`` overwrites position ``i % trigger_len`` with a token
        drawn uniformly from ``valid_token_ids``. ``candidates`` is modified
        in place.

        Returns:
            The block length, which is always 1.
        """
        n_rows, trigger_len = candidates.shape
        device = candidates.device
        rows = torch.arange(n_rows, device=device)
        positions = rows % trigger_len
        random_tokens = valid_token_ids[
            torch.randint(len(valid_token_ids), (n_rows,), device=device)
        ]
        candidates[rows, positions] = random_tokens
        return 1

    def _mutate_block_random(
        self,
        candidates: Int[Tensor, "n_candidates trigger_len"],
        valid_token_ids: Int[Tensor, "n_valid"],
        local_step: int,
        trigger_len: int,
    ) -> int:
        """Contiguous block mutation at random positions.

        Each candidate gets its own uniformly random start position. The
        ``block_len`` tokens starting there are replaced with tokens drawn
        uniformly from ``valid_token_ids``. ``candidates`` is modified in place.

        Returns:
            The block length used: the scheduled length, capped at ``trigger_len``.
        """
        n_rows = candidates.shape[0]
        device = candidates.device
        block_len = min(self._get_curr_block_len(local_step), trigger_len)
        # Start positions are chosen so that the whole block fits in the trigger.
        block_starts = torch.randint(
            0, trigger_len - block_len + 1, (n_rows,), device=device
        )
        random_tokens = valid_token_ids[
            torch.randint(len(valid_token_ids), (n_rows, block_len), device=device)
        ]
        # (n_rows, block_len) column indices of each candidate's block; rows
        # broadcast against them, so one indexed write replaces the whole block.
        cols = block_starts.unsqueeze(1) + torch.arange(
            block_len, device=device
        ).unsqueeze(0)
        rows = torch.arange(n_rows, device=device).unsqueeze(1)
        candidates[rows, cols] = random_tokens
        return block_len

    # ------------------------------------------------------------------ #
    # Coarse-to-fine block-len schedule
    # ------------------------------------------------------------------ #
    def _get_curr_block_len(self, local_step: int) -> int:
        """Block size at a given step within the current restart.

        ``local_step`` is 0 at the start of the run and right after each
        restart. With ``schedule="none"`` the initial block length is used
        throughout. Otherwise the length is divided by successive powers of
        two (floored at 1) as the step passes 10, 25, 50, 100 and 500.
        Follows ``schedule_n_to_change_fixed`` from the paper's official code.
        """
        if self.schedule == "none":
            return self.initial_block_len
        m = self.initial_block_len
        if local_step <= 10:
            return m
        elif local_step <= 25:
            return max(m // 2, 1)
        elif local_step <= 50:
            return max(m // 4, 1)
        elif local_step <= 100:
            return max(m // 8, 1)
        elif local_step <= 500:
            return max(m // 16, 1)
        else:
            return max(m // 32, 1)