Allow named evaluation sets via --eval_files ("name=file1,file2" or a plain
file path, auto-named from the file name), tokenize and report each set
separately, and derive the per-dataset metric prefix in on_log.

NOTE(review): this patch was reconstructed from a whitespace-collapsed copy.
Leading indentation and the position of blank context lines are inferred;
verify them against src/proxyz/train.py before applying.

diff --git a/src/proxyz/train.py b/src/proxyz/train.py
index d80bb2f..47c7ed8 100644
--- a/src/proxyz/train.py
+++ b/src/proxyz/train.py
@@ -26,7 +26,10 @@
 @click.command(context_settings={'show_default': True})
 @click.argument("data_files", type=click.Path(), nargs=-1)
 @click.option(
-    "--eval_files", type=click.Path(), multiple=True, help="evaluate data files"
+    "--eval_files", type=str, multiple=True,
+    help="Evaluation dataset(s). Format: 'name=file1,file2' for named datasets "
+         "(metrics traced separately per name), or plain 'file' (auto-named from filename). "
+         "Repeat for multiple eval sets, e.g. --eval_files casp=casp15.txt --eval_files cameo=cameo.txt",
 )
 @click.option(
     "--dataset_name",
@@ -432,9 +435,30 @@ def data_generator(data_files):
     )
     eval_dataset = None
     if args.eval_files:
-        eval_dataset = Dataset.from_generator(
-            functools.partial(data_generator, args.eval_files)
-        )
+        # Parse --eval_files into named groups: "name=file1,file2" or plain "file"
+        eval_groups = {}
+        for spec in args.eval_files:
+            if "=" in spec:
+                name, paths_str = spec.split("=", 1)
+                name = name.strip()
+                files = tuple(p.strip() for p in paths_str.split(",") if p.strip())
+            else:
+                files = (spec,)
+                name = os.path.splitext(os.path.basename(spec))[0]
+            # De-duplicate names by appending a suffix
+            base_name = name
+            counter = 2
+            while name in eval_groups:
+                name = f"{base_name}_{counter}"
+                counter += 1
+            eval_groups[name] = files
+
+        eval_dataset = {
+            name: Dataset.from_generator(
+                functools.partial(data_generator, files)
+            )
+            for name, files in eval_groups.items()
+        }
 
     # Apply tokenization
     text_col = args.text_column if args.dataset_name else "text"
@@ -443,13 +467,20 @@ def tokenize_dataset(dataset):
         return dataset.map(tokenize_function, batched=True, remove_columns=[text_col])
 
     train_dataset = tokenize_dataset(train_dataset)
-    if eval_dataset:
-        eval_dataset = tokenize_dataset(eval_dataset)
+    if eval_dataset is not None:
+        if isinstance(eval_dataset, dict):
+            eval_dataset = {k: tokenize_dataset(v) for k, v in eval_dataset.items()}
+        else:
+            eval_dataset = tokenize_dataset(eval_dataset)
 
     if args.verbose:
         print(f"--- Train dataset ---")
         print(f"Examples: {len(train_dataset):,}")
-        if eval_dataset:
+        if isinstance(eval_dataset, dict):
+            for name, ds in eval_dataset.items():
+                print(f"--- Eval dataset [{name}] ---")
+                print(f"Examples: {len(ds):,} Files: {eval_groups.get(name, '?')}")
+        elif eval_dataset:
             print(f"--- Eval dataset ---")
             print(f"Examples: {len(eval_dataset):,}")
 
@@ -563,10 +594,22 @@ def on_log(self, args, state, control, logs=None, **kwargs):
         self.trainer._fim_cache = None
         labels, logits = cache
 
-        # Detect if this is eval or training based on log keys
-        is_eval = "eval_loss" in logs
-        prefix = "eval_" if is_eval else ""
-        loss_key = f"{prefix}loss"
+        # Detect eval vs training and determine metric prefix.
+        # With multiple eval datasets the Trainer logs per-dataset keys
+        # like "eval_casp15_loss", "eval_cameo_loss"; with a single eval
+        # dataset it logs "eval_loss". Training logs use "loss".
+        if "eval_loss" in logs:
+            prefix = "eval_"
+        else:
+            for key in logs:
+                if key.startswith("eval_") and key.endswith("_loss"):
+                    # prefix = "eval_<name>_" e.g. "eval_casp15_"
+                    prefix = key[: -len("loss")]  # keeps trailing "_"
+                    break
+            else:
+                prefix = ""
+
+        loss_key = f"{prefix}loss" if prefix else "loss"
 
         # Detect FIM examples: labels start with -100
         is_fim = labels[:, 0] == -100