Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions apps/predbat/load_ml_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ def _init_predictor(self):
self.model_valid = True
self.model_status = "active"
self.initial_training_done = True
self.last_train_time = self.predictor.training_timestamp
else:
self.log("ML Component: Loaded model is invalid ({}), will retrain".format(reason))
self.model_status = "fallback_" + reason
Expand Down
65 changes: 65 additions & 0 deletions apps/predbat/tests/test_load_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def test_load_ml(my_predbat=None):
("component_publish_entity", _test_component_publish_entity, "LoadMLComponent _publish_entity method"),
("car_subtraction_direct", _test_car_subtraction_direct, "Direct car_subtraction method with interpolation and smoothing"),
("component_run_data_merge", _test_component_run_data_merge, "LoadMLComponent run() data fetch, save and merge across two runs"),
("component_init_predictor_last_train_time", _test_component_init_predictor_sets_last_train_time, "LoadMLComponent _init_predictor sets last_train_time from embedded training_timestamp"),
# ("real_data_training", _test_real_data_training, "Train on real data with chart"),
# ("pretrained_model_prediction", _test_pretrained_model_prediction, "Load pre-trained model and generate predictions with chart"),
]
Expand Down Expand Up @@ -2881,3 +2882,67 @@ def get_arg(self, key, default=None, indirect=True, combine=False, attribute=Non
print("PASS")

print(" All car_subtraction direct tests passed!")


def _test_component_init_predictor_sets_last_train_time():
"""Test that _init_predictor populates last_train_time from the model's embedded training_timestamp."""
import tempfile
from load_ml_component import LoadMLComponent

class MockBase:
"""Minimal base required by ComponentBase and LoadMLComponent.initialize."""

def __init__(self, config_root):
self.prefix = "predbat"
self.config_root = config_root
self.now_utc = datetime.now(timezone.utc)
self.midnight_utc = self.now_utc.replace(hour=0, minute=0, second=0, microsecond=0)
self.minutes_now = (self.now_utc - self.midnight_utc).seconds // 60
self.local_tz = timezone.utc
self.args = {}
self.log_messages = []

def log(self, msg):
self.log_messages.append(msg)

def get_arg(self, key, default=None, indirect=True, combine=False, attribute=None, index=None, domain=None, can_override=True, required_unit=None):
return {
"load_today": ["sensor.load_today"],
"load_power": None,
"car_charging_energy": None,
"load_scaling": 1.0,
"car_charging_energy_scale": 1.0,
}.get(key, default)

with tempfile.TemporaryDirectory() as tmpdir:
model_path = os.path.join(tmpdir, "predbat_ml_model.npz")
np.random.seed(42)
now_utc = datetime.now(timezone.utc)
load_data = _create_synthetic_load_data(n_days=7, now_utc=now_utc)

# --- Part 1: model with a known timestamp → last_train_time set to that timestamp ---
predictor = LoadPredictor(learning_rate=0.01)
predictor.train(load_data, now_utc, is_initial=True, epochs=2, time_decay_days=7)
known_timestamp = predictor.training_timestamp
assert known_timestamp is not None, "Predictor should have training_timestamp after training"
predictor.validation_mae = 0.1 # Force a passing validation score
predictor.save(model_path)

component = LoadMLComponent(MockBase(config_root=tmpdir), load_ml_enable=True)

assert component.last_train_time is not None, "last_train_time should be set after loading a cached model"
assert component.last_train_time == known_timestamp, "last_train_time should equal the model's embedded training_timestamp (not filesystem mtime); " "got {} expected {}".format(component.last_train_time, known_timestamp)
assert component.model_valid is True, "Model should be marked valid"
assert component.initial_training_done is True, "initial_training_done should be True"

# --- Part 2: model with no timestamp → last_train_time stays None (triggers retrain) ---
predictor2 = LoadPredictor(learning_rate=0.01)
predictor2.train(load_data, now_utc, is_initial=True, epochs=2, time_decay_days=7)
predictor2.training_timestamp = None # Simulate a pre-timestamp model
predictor2.validation_mae = 0.1
predictor2.save(model_path)

component2 = LoadMLComponent(MockBase(config_root=tmpdir), load_ml_enable=True)

assert component2.last_train_time is None, "last_train_time should remain None when model has no embedded timestamp (triggers safe retrain)"
assert component2.model_valid is True, "Model without timestamp should still be considered valid by is_valid()"
Loading