Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 54 additions & 11 deletions src/proxyz/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@
@click.command(context_settings={'show_default': True})
@click.argument("data_files", type=click.Path(), nargs=-1)
@click.option(
"--eval_files", type=click.Path(), multiple=True, help="evaluate data files"
"--eval_files", type=str(), multiple=True,
help="Evaluation dataset(s). Format: 'name=file1,file2' for named datasets "
"(metrics traced separately per name), or plain 'file' (auto-named from filename). "
"Repeat for multiple eval sets, e.g. --eval_files casp=casp15.txt --eval_files cameo=cameo.txt",
)
@click.option(
"--dataset_name",
Expand Down Expand Up @@ -432,9 +435,30 @@ def data_generator(data_files):
)
eval_dataset = None
if args.eval_files:
eval_dataset = Dataset.from_generator(
functools.partial(data_generator, args.eval_files)
)
# Parse --eval_files into named groups: "name=file1,file2" or plain "file"
eval_groups = {}
for spec in args.eval_files:
if "=" in spec:
name, paths_str = spec.split("=", 1)
name = name.strip()
files = tuple(p.strip() for p in paths_str.split(",") if p.strip())
else:
files = (spec,)
name = os.path.splitext(os.path.basename(spec))[0]
# De-duplicate names by appending a suffix
base_name = name
counter = 2
while name in eval_groups:
name = f"{base_name}_{counter}"
counter += 1
eval_groups[name] = files

eval_dataset = {
name: Dataset.from_generator(
functools.partial(data_generator, files)
)
for name, files in eval_groups.items()
}

# Apply tokenization
text_col = args.text_column if args.dataset_name else "text"
Expand All @@ -443,13 +467,20 @@ def tokenize_dataset(dataset):
return dataset.map(tokenize_function, batched=True, remove_columns=[text_col])

train_dataset = tokenize_dataset(train_dataset)
if eval_dataset:
eval_dataset = tokenize_dataset(eval_dataset)
if eval_dataset is not None:
if isinstance(eval_dataset, dict):
eval_dataset = {k: tokenize_dataset(v) for k, v in eval_dataset.items()}
else:
eval_dataset = tokenize_dataset(eval_dataset)

if args.verbose:
print(f"--- Train dataset ---")
print(f"Examples: {len(train_dataset):,}")
if eval_dataset:
if isinstance(eval_dataset, dict):
for name, ds in eval_dataset.items():
print(f"--- Eval dataset [{name}] ---")
print(f"Examples: {len(ds):,} Files: {eval_groups.get(name, '?')}")
elif eval_dataset:
print(f"--- Eval dataset ---")
print(f"Examples: {len(eval_dataset):,}")

Expand Down Expand Up @@ -563,10 +594,22 @@ def on_log(self, args, state, control, logs=None, **kwargs):
self.trainer._fim_cache = None
labels, logits = cache

# Detect if this is eval or training based on log keys
is_eval = "eval_loss" in logs
prefix = "eval_" if is_eval else ""
loss_key = f"{prefix}loss"
# Detect eval vs training and determine metric prefix.
# With multiple eval datasets the Trainer logs per-dataset keys
# like "eval_casp15_loss", "eval_cameo_loss"; with a single eval
# dataset it logs "eval_loss". Training logs use "loss".
if "eval_loss" in logs:
prefix = "eval_"
else:
for key in logs:
if key.startswith("eval_") and key.endswith("_loss"):
# prefix = "eval_<name>_" e.g. "eval_casp15_"
prefix = key[: -len("loss")] # keeps trailing "_"
break
else:
prefix = ""

loss_key = f"{prefix}loss" if prefix else "loss"

# Detect FIM examples: labels start with -100
is_fim = labels[:, 0] == -100
Expand Down