From 6816ad4f3c2232bf2b61736f8aac4a760e1396d2 Mon Sep 17 00:00:00 2001 From: Trefor Southwell Date: Tue, 16 Jun 2026 20:26:03 +0100 Subject: [PATCH 1/3] fix(ml): set last_train_time from model's embedded training_timestamp on load When a cached model was loaded at startup, last_train_time was left as None, causing the training-age check to treat the model as infinitely old and force an immediate retrain. Use predictor.training_timestamp (already restored from the .npz metadata by LoadPredictor.load) rather than the filesystem mtime, which can be wrong after file copies or backup restores. If training_timestamp is None (the model was saved before that field existed), last_train_time stays None and a retrain correctly triggers as a safe fallback. Co-Authored-By: Claude Sonnet 4.6 --- apps/predbat/load_ml_component.py | 1 + 1 file changed, 1 insertion(+) diff --git a/apps/predbat/load_ml_component.py b/apps/predbat/load_ml_component.py index d0bc62358..859e23836 100644 --- a/apps/predbat/load_ml_component.py +++ b/apps/predbat/load_ml_component.py @@ -155,6 +155,7 @@ def _init_predictor(self): self.model_valid = True self.model_status = "active" self.initial_training_done = True + self.last_train_time = self.predictor.training_timestamp else: self.log("ML Component: Loaded model is invalid ({}), will retrain".format(reason)) self.model_status = "fallback_" + reason From 2d20759680ccad86d748ad6dc8f429063a28eb52 Mon Sep 17 00:00:00 2001 From: Trefor Southwell Date: Tue, 16 Jun 2026 20:29:42 +0100 Subject: [PATCH 2/3] Add test --- apps/predbat/tests/test_load_ml.py | 68 ++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/apps/predbat/tests/test_load_ml.py b/apps/predbat/tests/test_load_ml.py index 2a5ab241f..db49c4ca1 100644 --- a/apps/predbat/tests/test_load_ml.py +++ b/apps/predbat/tests/test_load_ml.py @@ -61,6 +61,7 @@ def test_load_ml(my_predbat=None): ("component_publish_entity", _test_component_publish_entity, "LoadMLComponent _publish_entity method"), 
("car_subtraction_direct", _test_car_subtraction_direct, "Direct car_subtraction method with interpolation and smoothing"), ("component_run_data_merge", _test_component_run_data_merge, "LoadMLComponent run() data fetch, save and merge across two runs"), + ("component_init_predictor_last_train_time", _test_component_init_predictor_sets_last_train_time, "LoadMLComponent _init_predictor sets last_train_time from embedded training_timestamp"), # ("real_data_training", _test_real_data_training, "Train on real data with chart"), # ("pretrained_model_prediction", _test_pretrained_model_prediction, "Load pre-trained model and generate predictions with chart"), ] @@ -2881,3 +2882,70 @@ def get_arg(self, key, default=None, indirect=True, combine=False, attribute=Non print("PASS") print(" All car_subtraction direct tests passed!") + + +def _test_component_init_predictor_sets_last_train_time(): + """Test that _init_predictor populates last_train_time from the model's embedded training_timestamp.""" + import tempfile + from load_ml_component import LoadMLComponent + + class MockBase: + """Minimal base required by ComponentBase and LoadMLComponent.initialize.""" + + def __init__(self, config_root): + self.prefix = "predbat" + self.config_root = config_root + self.now_utc = datetime.now(timezone.utc) + self.midnight_utc = self.now_utc.replace(hour=0, minute=0, second=0, microsecond=0) + self.minutes_now = (self.now_utc - self.midnight_utc).seconds // 60 + self.local_tz = timezone.utc + self.args = {} + self.log_messages = [] + + def log(self, msg): + self.log_messages.append(msg) + + def get_arg(self, key, default=None, indirect=True, combine=False, attribute=None, index=None, domain=None, can_override=True, required_unit=None): + return { + "load_today": ["sensor.load_today"], + "load_power": None, + "car_charging_energy": None, + "load_scaling": 1.0, + "car_charging_energy_scale": 1.0, + }.get(key, default) + + with tempfile.TemporaryDirectory() as tmpdir: + model_path = 
os.path.join(tmpdir, "predbat_ml_model.npz") + np.random.seed(42) + now_utc = datetime.now(timezone.utc) + load_data = _create_synthetic_load_data(n_days=7, now_utc=now_utc) + + # --- Part 1: model with a known timestamp → last_train_time set to that timestamp --- + predictor = LoadPredictor(learning_rate=0.01) + predictor.train(load_data, now_utc, is_initial=True, epochs=2, time_decay_days=7) + known_timestamp = predictor.training_timestamp + assert known_timestamp is not None, "Predictor should have training_timestamp after training" + predictor.validation_mae = 0.1 # Force a passing validation score + predictor.save(model_path) + + component = LoadMLComponent(MockBase(config_root=tmpdir), load_ml_enable=True) + + assert component.last_train_time is not None, "last_train_time should be set after loading a cached model" + assert component.last_train_time == known_timestamp, ( + "last_train_time should equal the model's embedded training_timestamp (not filesystem mtime); " + "got {} expected {}".format(component.last_train_time, known_timestamp) + ) + assert component.model_valid is True, "Model should be marked valid" + assert component.initial_training_done is True, "initial_training_done should be True" + + # --- Part 2: model with no timestamp → last_train_time stays None (triggers retrain) --- + predictor2 = LoadPredictor(learning_rate=0.01) + predictor2.train(load_data, now_utc, is_initial=True, epochs=2, time_decay_days=7) + predictor2.training_timestamp = None # Simulate a pre-timestamp model + predictor2.validation_mae = 0.1 + predictor2.save(model_path) + + component2 = LoadMLComponent(MockBase(config_root=tmpdir), load_ml_enable=True) + + assert component2.last_train_time is None, "last_train_time should remain None when model has no embedded timestamp (triggers safe retrain)" + assert component2.model_valid is True, "Model without timestamp should still be considered valid by is_valid()" From 95cd4dec6ba67a05cead2083d1c8961cf13741ed Mon Sep 17 00:00:00 
2001 From: "pre-commit-ci-lite[bot]" <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com> Date: Tue, 16 Jun 2026 19:31:42 +0000 Subject: [PATCH 3/3] [pre-commit.ci lite] apply automatic fixes --- apps/predbat/tests/test_load_ml.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/apps/predbat/tests/test_load_ml.py b/apps/predbat/tests/test_load_ml.py index db49c4ca1..c8eea6dd7 100644 --- a/apps/predbat/tests/test_load_ml.py +++ b/apps/predbat/tests/test_load_ml.py @@ -2931,10 +2931,7 @@ def get_arg(self, key, default=None, indirect=True, combine=False, attribute=Non component = LoadMLComponent(MockBase(config_root=tmpdir), load_ml_enable=True) assert component.last_train_time is not None, "last_train_time should be set after loading a cached model" - assert component.last_train_time == known_timestamp, ( - "last_train_time should equal the model's embedded training_timestamp (not filesystem mtime); " - "got {} expected {}".format(component.last_train_time, known_timestamp) - ) + assert component.last_train_time == known_timestamp, "last_train_time should equal the model's embedded training_timestamp (not filesystem mtime); " "got {} expected {}".format(component.last_train_time, known_timestamp) assert component.model_valid is True, "Model should be marked valid" assert component.initial_training_done is True, "initial_training_done should be True"