From b9d978cecbefa248e939bdfe1c2b0368a4dddd7d Mon Sep 17 00:00:00 2001 From: chungongyu Date: Wed, 1 Jul 2026 11:41:22 +0800 Subject: [PATCH] refactor: eval with weighted samples --- src/proxyz/train.py | 104 +++++++++++++++++++++++++++++--------------- 1 file changed, 70 insertions(+), 34 deletions(-) diff --git a/src/proxyz/train.py b/src/proxyz/train.py index d80bb2f..88872f2 100644 --- a/src/proxyz/train.py +++ b/src/proxyz/train.py @@ -217,6 +217,12 @@ help="Clustering files for cluster-based sampling. Each file has two columns: " "cluster_id and data_row_id. Sampling weight = n / (1 + log(n)) where n is cluster size.", ) +@click.option( + "--eval_cluster_files", + type=click.Path(), + multiple=True, + help="Clustering files for eval per-sample loss weighting. Same format as --cluster_files.", +) @click.option("-v", "--verbose", is_flag=True, help="verbose output.") def main(**args): args = dict2object(**args) @@ -453,12 +459,17 @@ def tokenize_dataset(dataset): print(f"--- Eval dataset ---") print(f"Examples: {len(eval_dataset):,}") - # Load cluster information and compute sampling weights if cluster files provided - train_sampler = None - if args.cluster_files: + def _compute_cluster_weights(cluster_files, dataset_len): + """Load cluster files and compute per-sample weights. + + Returns (cluster_weights, cluster_map): + cluster_weights: dict mapping cluster_id -> weight (1 / (1 + log(n))) + cluster_map: dict mapping data_row_id -> cluster_id + """ + # Load all cluster files and build mapping: data_row_id -> cluster_id cluster_map = {} # data_row_id -> cluster_id - for cluster_file in args.cluster_files: + for cluster_file in cluster_files: with open(cluster_file, 'r') as f: for line in f: parts = line.strip().split() @@ -466,58 +477,79 @@ def tokenize_dataset(dataset): cluster_id = parts[0] data_row_id = int(parts[1]) cluster_map[data_row_id] = cluster_id - + # Count cluster sizes cluster_sizes = defaultdict(int) for data_row_id, cluster_id in cluster_map.items(): - if data_row_id < len(train_dataset): # Only count valid indices + if data_row_id < dataset_len: # Only count valid indices cluster_sizes[cluster_id] += 1 - + # Compute sampling weights: n / (1 + log(n)) for each cluster cluster_weights = {} for cluster_id, n in cluster_sizes.items(): cluster_weights[cluster_id] = 1 / (1 + math.log(n)) - + # Assign weights to each sample - sample_weights = [] - for i in range(len(train_dataset)): - if i in cluster_map: - cluster_id = cluster_map[i] - weight = cluster_weights[cluster_id] - sample_weights.append(weight) - else: - # If no cluster info, use uniform weight (1.0) - sample_weights.append(1.0) - + sample_weights = [ + cluster_weights.get(cluster_map[i], 1.0) if i in cluster_map else 1.0 + for i in range(dataset_len) + ] + + return sample_weights, cluster_map, cluster_sizes + + + # Load cluster information and compute sampling weights if cluster files provided + train_sampler = None + if args.cluster_files: + sample_weights, _, cluster_sizes = _compute_cluster_weights( + args.cluster_files, len(train_dataset) + ) + # Create WeightedRandomSampler train_sampler = WeightedRandomSampler( weights=sample_weights, num_samples=len(sample_weights), replacement=True ) - + if args.verbose: print(f"--- Cluster-based sampling ---") print(f"Clusters: {len(cluster_sizes):,}") print(f"Samples with cluster info: {sum(1 for w in sample_weights if w != 1.0):,}") + # Load eval cluster weights for per-sample loss weighting during evaluation + if args.eval_cluster_files and eval_dataset: + eval_sample_weights, _, eval_cluster_sizes = _compute_cluster_weights( + args.eval_cluster_files, len(eval_dataset) + ) + eval_dataset = eval_dataset.add_column("sample_weight", eval_sample_weights) + if args.verbose: + print(f"--- Eval cluster weighting ---") + print(f"Eval clusters: {len(eval_cluster_sizes):,}") + print(f"Eval samples with weight info: {sum(1 for w in eval_sample_weights if w != 1.0):,}") + # Data collator that pads input_ids, attention_mask, and labels uniformly pad_token_id = tokenizer.pad_token_id def data_collator(examples): max_len = max(len(ex["input_ids"]) for ex in examples) input_ids, attention_mask, labels = [], [], [] + sample_weights = [] for ex in examples: pad_len = max_len - len(ex["input_ids"]) input_ids.append(ex["input_ids"] + [pad_token_id] * pad_len) attention_mask.append(ex["attention_mask"] + [0] * pad_len) label_seq = ex.get("labels", ex["input_ids"]) labels.append(label_seq + [-100] * pad_len) - return { + sample_weights.append(ex.get("sample_weight", 1.0)) + batch = { "input_ids": torch.tensor(input_ids, dtype=torch.long), "attention_mask": torch.tensor(attention_mask, dtype=torch.long), "labels": torch.tensor(labels, dtype=torch.long), } + if any(w != 1.0 for w in sample_weights): + batch["sample_weights"] = torch.tensor(sample_weights, dtype=torch.float32) + return batch # ========================================== # 4. TRAINING ARGUMENTS & EXECUTION @@ -527,14 +559,18 @@ def data_collator(examples): class FIMTrainer(Trainer): def __init__(self, train_sampler=None, **kwargs): super().__init__(**kwargs) + self.train_sampler = train_sampler - + def _get_train_sampler(self, train_dataset: Dataset = None): if self.train_sampler is not None: return self.train_sampler return super()._get_train_sampler(train_dataset) - + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + # Extract sample weights for eval (not a model input) + sample_weights = inputs.pop("sample_weights", None) + # Always cache data for FIM loss tracking (training and eval) # Cache data BEFORE calling super (which may modify inputs) labels = inputs["labels"].detach().clone() @@ -545,7 +581,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N ) # Cache logits after forward pass - self._fim_cache = (labels, outputs.logits.detach().clone()) + self._fim_cache = (labels, outputs.logits.detach().clone(), sample_weights) return (loss, outputs) if return_outputs else loss @@ -561,7 +597,7 @@ def on_log(self, args, state, control, logs=None, **kwargs): if cache is None: return self.trainer._fim_cache = None - labels, logits = cache + labels, logits, sample_weights = cache # Detect if this is eval or training based on log keys is_eval = "eval_loss" in logs @@ -577,12 +613,9 @@ def on_log(self, args, state, control, logs=None, **kwargs): logs[f"{prefix}n_fim"] = n_fim logs[f"{prefix}n_std"] = n_std - has_fim = is_fim.any().item() - has_std = (~is_fim).any().item() - - if has_fim and has_std: - loss_fct = torch.nn.CrossEntropyLoss(reduction="none") - for tag, mask in [(f"{prefix}loss_fim", is_fim), (f"{prefix}loss_std", ~is_fim)]: + loss_fct = torch.nn.CrossEntropyLoss(reduction="none") + for tag, mask in [(f"{prefix}loss_fim", is_fim), (f"{prefix}loss_std", ~is_fim)]: + if mask.any(): shift_logits = logits[mask][..., :-1, :].contiguous() shift_labels = labels[mask][..., 1:].contiguous() loss_per_token = loss_fct( @@ -592,11 +625,14 @@ def on_log(self, args, state, control, logs=None, **kwargs): valid = shift_labels.view(-1) != -100 if valid.any(): logs[tag] = loss_per_token[valid].mean().item() - elif has_fim: - logs[f"{prefix}loss_fim"] = logs.get(loss_key) - else: - logs[f"{prefix}loss_std"] = logs.get(loss_key) + # Apply per-sample loss weighting if available + if sample_weights: + # Reshape to (batch_size, seq_len) + loss_per_token = loss_per_token.view(shift_labels.size(0), -1) + weighted_loss = ( + loss_per_token * valid * sample_weights[mask][..., None] + ).sum() / (valid * sample_weights[mask][..., None]).sum() # Parse report_to: "swanlab,tensorboard" -> ["swanlab", "tensorboard"] report_to = [r.strip() for r in args.report_to.split(",") if r.strip()]