diff --git a/apps/predbat/load_ml_component.py b/apps/predbat/load_ml_component.py index d0bc62358..859e23836 100644 --- a/apps/predbat/load_ml_component.py +++ b/apps/predbat/load_ml_component.py @@ -155,6 +155,7 @@ def _init_predictor(self): self.model_valid = True self.model_status = "active" self.initial_training_done = True + self.last_train_time = self.predictor.training_timestamp else: self.log("ML Component: Loaded model is invalid ({}), will retrain".format(reason)) self.model_status = "fallback_" + reason diff --git a/apps/predbat/tests/test_load_ml.py b/apps/predbat/tests/test_load_ml.py index 2a5ab241f..c8eea6dd7 100644 --- a/apps/predbat/tests/test_load_ml.py +++ b/apps/predbat/tests/test_load_ml.py @@ -61,6 +61,7 @@ def test_load_ml(my_predbat=None): ("component_publish_entity", _test_component_publish_entity, "LoadMLComponent _publish_entity method"), ("car_subtraction_direct", _test_car_subtraction_direct, "Direct car_subtraction method with interpolation and smoothing"), ("component_run_data_merge", _test_component_run_data_merge, "LoadMLComponent run() data fetch, save and merge across two runs"), + ("component_init_predictor_last_train_time", _test_component_init_predictor_sets_last_train_time, "LoadMLComponent _init_predictor sets last_train_time from embedded training_timestamp"), # ("real_data_training", _test_real_data_training, "Train on real data with chart"), # ("pretrained_model_prediction", _test_pretrained_model_prediction, "Load pre-trained model and generate predictions with chart"), ] @@ -2881,3 +2882,67 @@ def get_arg(self, key, default=None, indirect=True, combine=False, attribute=Non print("PASS") print(" All car_subtraction direct tests passed!") + + +def _test_component_init_predictor_sets_last_train_time(): + """Test that _init_predictor populates last_train_time from the model's embedded training_timestamp.""" + import tempfile + from load_ml_component import LoadMLComponent + + class MockBase: + """Minimal base required by ComponentBase and LoadMLComponent.initialize.""" + + def __init__(self, config_root): + self.prefix = "predbat" + self.config_root = config_root + self.now_utc = datetime.now(timezone.utc) + self.midnight_utc = self.now_utc.replace(hour=0, minute=0, second=0, microsecond=0) + self.minutes_now = (self.now_utc - self.midnight_utc).seconds // 60 + self.local_tz = timezone.utc + self.args = {} + self.log_messages = [] + + def log(self, msg): + self.log_messages.append(msg) + + def get_arg(self, key, default=None, indirect=True, combine=False, attribute=None, index=None, domain=None, can_override=True, required_unit=None): + return { + "load_today": ["sensor.load_today"], + "load_power": None, + "car_charging_energy": None, + "load_scaling": 1.0, + "car_charging_energy_scale": 1.0, + }.get(key, default) + + with tempfile.TemporaryDirectory() as tmpdir: + model_path = os.path.join(tmpdir, "predbat_ml_model.npz") + np.random.seed(42) + now_utc = datetime.now(timezone.utc) + load_data = _create_synthetic_load_data(n_days=7, now_utc=now_utc) + + # --- Part 1: model with a known timestamp → last_train_time set to that timestamp --- + predictor = LoadPredictor(learning_rate=0.01) + predictor.train(load_data, now_utc, is_initial=True, epochs=2, time_decay_days=7) + known_timestamp = predictor.training_timestamp + assert known_timestamp is not None, "Predictor should have training_timestamp after training" + predictor.validation_mae = 0.1 # Force a passing validation score + predictor.save(model_path) + + component = LoadMLComponent(MockBase(config_root=tmpdir), load_ml_enable=True) + + assert component.last_train_time is not None, "last_train_time should be set after loading a cached model" + assert component.last_train_time == known_timestamp, "last_train_time should equal the model's embedded training_timestamp (not filesystem mtime); " "got {} expected {}".format(component.last_train_time, known_timestamp) + assert component.model_valid is True, "Model should be marked valid" + assert component.initial_training_done is True, "initial_training_done should be True" + + # --- Part 2: model with no timestamp → last_train_time stays None (triggers retrain) --- + predictor2 = LoadPredictor(learning_rate=0.01) + predictor2.train(load_data, now_utc, is_initial=True, epochs=2, time_decay_days=7) + predictor2.training_timestamp = None # Simulate a pre-timestamp model + predictor2.validation_mae = 0.1 + predictor2.save(model_path) + + component2 = LoadMLComponent(MockBase(config_root=tmpdir), load_ml_enable=True) + + assert component2.last_train_time is None, "last_train_time should remain None when model has no embedded timestamp (triggers safe retrain)" + assert component2.model_valid is True, "Model without timestamp should still be considered valid by is_valid()"