Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 70 additions & 34 deletions src/proxyz/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,12 @@
help="Clustering files for cluster-based sampling. Each file has two columns: "
"cluster_id and data_row_id. Sampling weight = n / (1 + log(n)) where n is cluster size.",
)
@click.option(
"--eval_cluster_files",
type=click.Path(),
multiple=True,
help="Clustering files for eval per-sample loss weighting. Same format as --cluster_files.",
)
@click.option("-v", "--verbose", is_flag=True, help="verbose output.")
def main(**args):
args = dict2object(**args)
Expand Down Expand Up @@ -453,71 +459,97 @@ def tokenize_dataset(dataset):
print(f"--- Eval dataset ---")
print(f"Examples: {len(eval_dataset):,}")

# Load cluster information and compute sampling weights if cluster files provided
train_sampler = None
if args.cluster_files:
def _compute_cluster_weights(cluster_files, dataset_len):
"""Load cluster files and compute per-sample weights.

Returns (cluster_weights, cluster_map):
cluster_weights: dict mapping cluster_id -> weight (1 / (1 + log(n)))
cluster_map: dict mapping data_row_id -> cluster_id
"""

# Load all cluster files and build mapping: data_row_id -> cluster_id
cluster_map = {} # data_row_id -> cluster_id
for cluster_file in args.cluster_files:
for cluster_file in cluster_files:
with open(cluster_file, 'r') as f:
for line in f:
parts = line.strip().split()
if len(parts) >= 2:
cluster_id = parts[0]
data_row_id = int(parts[1])
cluster_map[data_row_id] = cluster_id

# Count cluster sizes
cluster_sizes = defaultdict(int)
for data_row_id, cluster_id in cluster_map.items():
if data_row_id < len(train_dataset): # Only count valid indices
if data_row_id < dataset_len: # Only count valid indices
cluster_sizes[cluster_id] += 1

# Compute sampling weights: n / (1 + log(n)) for each cluster
cluster_weights = {}
for cluster_id, n in cluster_sizes.items():
cluster_weights[cluster_id] = 1 / (1 + math.log(n))

# Assign weights to each sample
sample_weights = []
for i in range(len(train_dataset)):
if i in cluster_map:
cluster_id = cluster_map[i]
weight = cluster_weights[cluster_id]
sample_weights.append(weight)
else:
# If no cluster info, use uniform weight (1.0)
sample_weights.append(1.0)

sample_weights = [
cluster_weights.get(cluster_map[i], 1.0) if i in cluster_map else 1.0
for i in range(dataset_len)
]

return sample_weights, cluster_map, cluster_sizes


# Load cluster information and compute sampling weights if cluster files provided
train_sampler = None
if args.cluster_files:
sample_weights, _, cluster_sizes = _compute_cluster_weights(
args.cluster_files, len(train_dataset)
)

# Create WeightedRandomSampler
train_sampler = WeightedRandomSampler(
weights=sample_weights,
num_samples=len(sample_weights),
replacement=True
)

if args.verbose:
print(f"--- Cluster-based sampling ---")
print(f"Clusters: {len(cluster_sizes):,}")
print(f"Samples with cluster info: {sum(1 for w in sample_weights if w != 1.0):,}")

# Load eval cluster weights for per-sample loss weighting during evaluation
if args.eval_cluster_files and eval_dataset:
eval_sample_weights, _, eval_cluster_sizes = _compute_cluster_weights(
args.eval_cluster_files, len(eval_dataset)
)
eval_dataset = eval_dataset.add_column("sample_weight", eval_sample_weights)
if args.verbose:
print(f"--- Eval cluster weighting ---")
print(f"Eval clusters: {len(eval_cluster_sizes):,}")
print(f"Eval samples with weight info: {sum(1 for w in eval_sample_weights if w != 1.0):,}")

# Data collator that pads input_ids, attention_mask, and labels uniformly
pad_token_id = tokenizer.pad_token_id

def data_collator(examples):
max_len = max(len(ex["input_ids"]) for ex in examples)
input_ids, attention_mask, labels = [], [], []
sample_weights = []
for ex in examples:
pad_len = max_len - len(ex["input_ids"])
input_ids.append(ex["input_ids"] + [pad_token_id] * pad_len)
attention_mask.append(ex["attention_mask"] + [0] * pad_len)
label_seq = ex.get("labels", ex["input_ids"])
labels.append(label_seq + [-100] * pad_len)
return {
sample_weights.append(ex.get("sample_weight", 1.0))
batch = {
"input_ids": torch.tensor(input_ids, dtype=torch.long),
"attention_mask": torch.tensor(attention_mask, dtype=torch.long),
"labels": torch.tensor(labels, dtype=torch.long),
}
if any(w != 1.0 for w in sample_weights):
batch["sample_weights"] = torch.tensor(sample_weights, dtype=torch.float32)
return batch

# ==========================================
# 4. TRAINING ARGUMENTS & EXECUTION
Expand All @@ -527,14 +559,18 @@ def data_collator(examples):
class FIMTrainer(Trainer):
def __init__(self, train_sampler=None, **kwargs):
super().__init__(**kwargs)

self.train_sampler = train_sampler

def _get_train_sampler(self, train_dataset: Dataset = None):
if self.train_sampler is not None:
return self.train_sampler
return super()._get_train_sampler(train_dataset)

def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
# Extract sample weights for eval (not a model input)
sample_weights = inputs.pop("sample_weights", None)

# Always cache data for FIM loss tracking (training and eval)
# Cache data BEFORE calling super (which may modify inputs)
labels = inputs["labels"].detach().clone()
Expand All @@ -545,7 +581,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
)

# Cache logits after forward pass
self._fim_cache = (labels, outputs.logits.detach().clone())
self._fim_cache = (labels, outputs.logits.detach().clone(), sample_weights)

return (loss, outputs) if return_outputs else loss

Expand All @@ -561,7 +597,7 @@ def on_log(self, args, state, control, logs=None, **kwargs):
if cache is None:
return
self.trainer._fim_cache = None
labels, logits = cache
labels, logits, sample_weights = cache

# Detect if this is eval or training based on log keys
is_eval = "eval_loss" in logs
Expand All @@ -577,12 +613,9 @@ def on_log(self, args, state, control, logs=None, **kwargs):
logs[f"{prefix}n_fim"] = n_fim
logs[f"{prefix}n_std"] = n_std

has_fim = is_fim.any().item()
has_std = (~is_fim).any().item()

if has_fim and has_std:
loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
for tag, mask in [(f"{prefix}loss_fim", is_fim), (f"{prefix}loss_std", ~is_fim)]:
loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
for tag, mask in [(f"{prefix}loss_fim", is_fim), (f"{prefix}loss_std", ~is_fim)]:
if mask.any():
shift_logits = logits[mask][..., :-1, :].contiguous()
shift_labels = labels[mask][..., 1:].contiguous()
loss_per_token = loss_fct(
Expand All @@ -592,11 +625,14 @@ def on_log(self, args, state, control, logs=None, **kwargs):
valid = shift_labels.view(-1) != -100
if valid.any():
logs[tag] = loss_per_token[valid].mean().item()
elif has_fim:
logs[f"{prefix}loss_fim"] = logs.get(loss_key)
else:
logs[f"{prefix}loss_std"] = logs.get(loss_key)

# Apply per-sample loss weighting if available
if sample_weights:
# Reshape to (batch_size, seq_len)
loss_per_token = loss_per_token.view(shift_labels.size(0), -1)
weighted_loss = (
loss_per_token * valid * sample_weights[mask][..., None]
).sum() / (valid * sample_weights[mask][..., None]).sum()

# Parse report_to: "swanlab,tensorboard" -> ["swanlab", "tensorboard"]
report_to = [r.strip() for r in args.report_to.split(",") if r.strip()]
Expand Down