Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions spras/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import copy as copy
import functools
import itertools as it
import re
import warnings
from pathlib import Path
from typing import Any
Expand All @@ -28,6 +29,15 @@
from spras.config.util import AlgorithmName, get_valid_algorithm_names
from spras.util import LoosePathLike, NpHashEncoder, hash_params_sha1_base32

# Extend the YAML float resolution that applies when a YAML file is parsed directly.
# By default, scientific notation is only recognized as a float if it contains a decimal point (1.0e-3 but not 1e-3): https://yaml.org/type/float.html

@tristan-f-r tristan-f-r Jun 14, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While the YAML specification itself doesn't support this, Pydantic handles coercion of floats from strings just fine, including with scientific notation. This error was instead caused by my assumption that this coercion would happen before any BeforeValidator was run, which is notably the opposite of how BeforeValidator is characterized. I've proposed an alternative fix instead.

@ntalluri ntalluri Jun 15, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The alternative fix is very clean; I am curious whether it will pass the tests Tony set up.

# Register an extra pattern for the YAML float tag on SafeLoader so that plain scalars written in
# scientific notation (with or without a decimal point, e.g. 1e-3) are resolved as floats.
# add_implicit_resolver is described at https://pyyaml.org/wiki/PyYAMLDocumentation
yaml.SafeLoader.add_implicit_resolver(
    'tag:yaml.org,2002:float',
    re.compile(r'^[-+]?([0-9]+(\.[0-9]*)?|\.[0-9]+)[eE][-+]?[0-9]+$'),
    # Characters a scalar may start with for this pattern to be tried
    [*'-+0123456789.'],
)

config = None

# This will get called in the Snakefile, instantiating the singleton with the raw config
Expand All @@ -39,6 +49,24 @@ def init_from_file(filepath):
global config
config = Config.from_file(filepath)

# Scientific notation with an optional sign and an optional decimal point in the coefficient
# (e.g. 1e-3, 2.5E4, .5e+2). Compiled once because the sanitizer visits every string in the config.
_SCIENTIFIC_NOTATION_RE = re.compile(r'[-+]?([0-9]+(\.[0-9]*)?|\.[0-9]+)[eE][-+]?[0-9]+')


def sanitize_scientific_notation(data: Any) -> Any:
    """
    Recursively walks a YAML configuration parsed externally by Snakemake and converts
    scientific notation strings back to floats.

    Default YAML 1.1 requires decimal point for scientific notation: https://yaml.org/type/float.html

    @param data: a parsed configuration value; dicts and lists are traversed recursively
    @return: a new dict or list with the same structure (dictionary keys are left untouched) in which
             every string value that is entirely scientific notation has been replaced by a float;
             any other value is returned unchanged
    """
    if isinstance(data, dict):
        return {k: sanitize_scientific_notation(v) for k, v in data.items()}
    if isinstance(data, list):
        return [sanitize_scientific_notation(item) for item in data]
    # fullmatch is used instead of match with a trailing `$` because `$` also matches just before a
    # final newline, which would turn a string such as "1e3\n" into a float.
    # Every string accepted by the pattern is valid input for float(), so no ValueError handling is needed.
    if isinstance(data, str) and _SCIENTIFIC_NOTATION_RE.fullmatch(data):
        return float(data)
    return data

class Config:
def __init__(self, raw_config: dict[str, Any]):
Expand All @@ -47,6 +75,8 @@ def __init__(self, raw_config: dict[str, Any]):
if raw_config == {}:
raise ValueError("Config file cannot be empty. Use --configfile <filename> to set a config file.")

raw_config = sanitize_scientific_notation(raw_config)

parsed_raw_config = RawConfig.model_validate(raw_config)

# Member vars populated by process_config. Any values that don't have quick initial values are set to None
Expand Down
82 changes: 82 additions & 0 deletions test/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Iterable

import pytest
import yaml
from pydantic import BaseModel

import spras.config.config as config
Expand Down Expand Up @@ -451,3 +452,84 @@ def test_eval_summary_coupling(self, eval_include, summary_include, expected_eva
assert config.config.analysis_include_evaluation == expected_eval
assert config.config.analysis_include_summary == expected_summary


@pytest.mark.parametrize("numeric_string, expected_float", [
    (".5", 0.5),
    ("6.", 6.0),
    ("3.1", 3.1),
    ("-4.7", -4.7),
    ("1e-3", 0.001),
    ("1E-3", 0.001),
    ("2.5e4", 25000.0),
    ("2e4", 20000.0),
    ("+5e+2", 500.0),
    ("-3e-2", -0.03),
])
def test_scientific_notation_variants(self, numeric_string, expected_float):
    """
    Checks that yaml.safe_load resolves each numeric spelling, including scientific
    notation without a decimal point, to the expected Python float.
    Covers YAML files that are parsed directly.
    """
    # The float resolver under test is registered when spras.config.config is imported
    document = yaml.safe_load(f"value: {numeric_string}")
    loaded_value = document["value"]

    assert isinstance(loaded_value, float)
    assert loaded_value == expected_float


@pytest.mark.parametrize("non_numeric_string", [
    "range1e3",
    "e-3",  # No coefficient before the exponent
    "1e",  # No digits after the exponent marker
])
def test_scientific_notation_does_not_match_strings(self, non_numeric_string):
    """
    Checks that scalars which merely contain an 'e' next to digits are not
    scientific notation and therefore stay strings after yaml.safe_load.
    Covers YAML files that are parsed directly.
    """
    # The float resolver under test is registered when spras.config.config is imported
    document = yaml.safe_load(f"label: {non_numeric_string}")
    loaded_label = document["label"]

    assert isinstance(loaded_label, str)
    assert loaded_label == non_numeric_string

def test_sanitize_scientific_notation(self):
    """
    Checks that sanitize_scientific_notation converts scientific notation strings
    to floats at every nesting level (dict values and list items) while leaving
    other strings and non-string values alone.
    Covers YAML files loaded externally by Snakemake.
    """
    # Arbitrary nesting for this test only; this is not valid SPRAS config syntax
    raw = {
        "f": "1e-3",
        "algorithms": {
            "omicsintegrator2": {
                "params": {
                    "sweep": ["2.5e4", "5e-2"],
                    "x": "-6e2",
                },
                "nodes": "terminals",
            },
        },
        "threads": 4,
    }

    sanitized = config.sanitize_scientific_notation(raw)

    assert sanitized == {
        "f": 0.001,
        "algorithms": {
            "omicsintegrator2": {
                "params": {
                    "sweep": [25000.0, 0.05],
                    "x": -600.0,
                },
                "nodes": "terminals",
            },
        },
        "threads": 4,
    }
Loading