# Scientific notation literal such as "1e-3", "2.5e4" or "+5e+2": an optional sign,
# a coefficient with at least one digit, then a mandatory exponent.
# Compiled once because the sanitizer below recurses over every value in the config.
_SCIENTIFIC_NOTATION_RE = re.compile(r'[-+]?([0-9]+(\.[0-9]*)?|\.[0-9]+)[eE][-+]?[0-9]+')


def sanitize_scientific_notation(data: Any) -> Any:
    """
    Recursively convert scientific notation strings in a parsed configuration back to floats.

    Intended for a YAML configuration that was parsed externally (by Snakemake), where such
    values arrive as strings.
    Default YAML 1.1 requires decimal point for scientific notation: https://yaml.org/type/float.html

    @param data: a parsed configuration value; dicts and lists are traversed recursively
    @return: a new dict/list with the same structure in which every string value that is
             entirely a scientific notation literal is replaced by the corresponding float.
             Dictionary keys and all other values (including strings that only contain or
             are followed by extra characters, e.g. a trailing newline) are returned unchanged.
    """
    if isinstance(data, dict):
        return {key: sanitize_scientific_notation(value) for key, value in data.items()}
    if isinstance(data, list):
        return [sanitize_scientific_notation(item) for item in data]
    # fullmatch (rather than match with a trailing '$') so that a string such as "1e3\n"
    # is not mistaken for a number: '$' also matches just before a final newline.
    # Every string accepted by the pattern is a valid float literal, so float() cannot raise here.
    if isinstance(data, str) and _SCIENTIFIC_NOTATION_RE.fullmatch(data):
        return float(data)
    return data
@pytest.mark.parametrize("numeric_string, expected_float", [
    (".5", 0.5),
    ("6.", 6.0),
    ("3.1", 3.1),
    ("-4.7", -4.7),
    ("1e-3", 0.001),
    ("1E-3", 0.001),
    ("2.5e4", 25000.0),
    ("2e4", 20000.0),
    ("+5e+2", 500.0),
    ("-3e-2", -0.03),
])
def test_scientific_notation_variants(self, numeric_string, expected_float):
    """
    Each decimal or scientific notation spelling must come out of
    yaml.safe_load as a Python float with the expected value.
    Covers YAML that is parsed directly; relies on the float handling
    that is modified in spras.config.config.
    """
    loaded_value = yaml.safe_load(f"value: {numeric_string}")["value"]

    assert isinstance(loaded_value, float)
    assert loaded_value == expected_float


@pytest.mark.parametrize("non_numeric_string", [
    "range1e3",
    "e-3",  # no coefficient before the exponent
    "1e",  # no digits after the exponent marker
])
def test_scientific_notation_does_not_match_strings(self, non_numeric_string):
    """
    Text that merely resembles scientific notation must stay a string
    when loaded with yaml.safe_load instead of being turned into a float.
    Covers YAML that is parsed directly; relies on the float handling
    that is modified in spras.config.config.
    """
    loaded_label = yaml.safe_load(f"label: {non_numeric_string}")["label"]

    assert isinstance(loaded_label, str)
    assert loaded_label == non_numeric_string


def test_sanitize_scientific_notation(self):
    """
    sanitize_scientific_notation must walk nested dictionaries and lists,
    replacing scientific notation strings with floats and leaving every
    other value alone.
    Covers configuration that was loaded externally by Snakemake.
    """
    # The structure below is arbitrary and does not follow the SPRAS config file syntax
    raw = {
        "f": "1e-3",
        "algorithms": {
            "omicsintegrator2": {
                "params": {
                    "sweep": ["2.5e4", "5e-2"],
                    "x": "-6e2",
                },
                "nodes": "terminals",
            },
        },
        "threads": 4,
    }

    expected = {
        "f": 0.001,
        "algorithms": {
            "omicsintegrator2": {
                "params": {
                    "sweep": [25000.0, 0.05],
                    "x": -600.0,
                },
                "nodes": "terminals",
            },
        },
        "threads": 4,
    }

    assert config.sanitize_scientific_notation(raw) == expected