Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions modeling/CHANGES.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Modeling Changes

## Branch: feature/modeling-improvements

---

### 1. AUC-based checkpointing (`train.py`)

**What changed:** `ModelCheckpoint` and `EarlyStopping` now monitor `val_auc` (maximise) instead of `val_loss` (minimise). Added `AUC` as a compiled metric. `ReduceLROnPlateau` stays on `val_loss`, because the learning-rate schedule benefits from the smoother, more continuous signal that the loss provides.

**Why:** `val_loss` (binary cross-entropy) rewards confidence, not ranking quality. A model can drive loss down by being very certain about easy negatives while still fumbling borderline WORTH cases. AUC directly measures whether the model ranks WORTH above NOT_WORTH across all thresholds — a much stronger signal when sensitivity is the priority. Checkpointing on best AUC gives us the most flexibility when choosing a decision threshold that meets the safety floor.

---

### 2. Validation-set threshold optimisation (`evaluate.py`)

**What changed:** Added `find_optimal_threshold(model, val_df, image_dir, ...)` to `evaluate.py`, exported from `__init__.py`. The 0.5 default in `evaluate_baseline` is intentionally untouched — callers are expected to run `find_optimal_threshold` on the validation set first and pass the result in explicitly.

**Why:** The sigmoid output is a ranking score, not a calibrated probability. Defaulting to 0.5 ignores class imbalance and the asymmetric cost of missing a WORTH case. Instead, we sweep the ROC curve on the validation set and pick the highest threshold at which sensitivity still meets `WORTH_SENSITIVITY_FLOOR` (0.80). Because specificity only falls as the threshold is lowered, this is the operating point with the highest specificity among those that satisfy the floor — it catches the required share of WORTH cases while keeping patient callbacks to a minimum. The test set is never touched during this step.

**Usage:**

```python
threshold = find_optimal_threshold(model, val_df, image_dir="data/images/")
results = evaluate_baseline(model, test_df, image_dir="data/images/", threshold=threshold)
```
1 change: 1 addition & 0 deletions modeling/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
WORTH_SENSITIVITY_FLOOR,
WORTH_WEIGHT_MULTIPLIER,
)
from .evaluate import find_optimal_threshold
2 changes: 1 addition & 1 deletion modeling/baseline_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,4 +130,4 @@ def compute_class_weights(labels: list[int]) -> dict[int, float]:
# Extra penalty for missing a WORTH case (false reassurance risk).
weights[POSITIVE_CLASS_INDEX] *= WORTH_WEIGHT_MULTIPLIER

return {i: float(w) for i, w in enumerate(weights)}
return {i: float(w) for i, w in enumerate(weights)}
63 changes: 63 additions & 0 deletions modeling/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
classification_report,
confusion_matrix,
recall_score,
roc_curve,
)

from config.constants import INPUT_SIZE
Expand All @@ -29,6 +30,68 @@
from modeling.train import _build_dataset


def find_optimal_threshold(
    model: tf.keras.Model,
    val_df: pd.DataFrame,
    image_dir: str,
    image_col: str = "image_path",
    label_col: str = "label",
    input_size: tuple = INPUT_SIZE,
    batch_size: int = 32,
) -> float:
    """Find the decision threshold that maximises specificity while meeting the sensitivity floor.

    Sweeps the ROC curve on the validation set and picks the operating point where
    WORTH_SECOND_LOOK sensitivity >= WORTH_SENSITIVITY_FLOOR and FPR (false-alarm rate)
    is minimised. Because ``roc_curve`` returns thresholds in decreasing order, this is
    the highest threshold that still meets the floor.

    The returned value follows the ``roc_curve`` convention: each operating point is
    computed by treating ``score >= threshold`` as a positive prediction, so the
    threshold should be applied with ``>=`` to reproduce the reported sensitivity.

    Falls back to 0.5 with a printed warning when the validation labels contain no
    positive samples (sensitivity is undefined) or when no operating point meets the
    floor.

    Args:
        model: Trained Keras model (or path string to a saved .keras file).
        val_df: Validation split DataFrame — must NOT be the test set.
        image_dir: Root directory containing image files.
        image_col: Column with image filenames.
        label_col: Column with binary labels (int 0 or 1).
        input_size: Must match the size used during training.
        batch_size: Inference batch size.

    Returns:
        The selected threshold as a float, or 0.5 when falling back.
    """
    if isinstance(model, str):
        model = tf.keras.models.load_model(model)

    # shuffle=False so predictions can be paired with val_df rows by position.
    # NOTE(review): assumes _build_dataset keeps row order and drops no rows — confirm.
    val_ds = _build_dataset(
        val_df, image_dir, image_col, label_col, input_size, batch_size, shuffle=False
    )
    true_labels = np.asarray([int(y) for y in val_df[label_col]])
    probabilities = model.predict(val_ds, verbose=0).ravel()

    # Without any positive sample the true-positive rate is undefined, so no
    # threshold can be validated against the floor. Report the real cause here
    # instead of letting the generic fallback below suggest retraining.
    if not np.any(true_labels == POSITIVE_CLASS_INDEX):
        print(
            f"WARNING: Validation labels contain no samples of class {POSITIVE_CLASS_INDEX}, "
            "so sensitivity is undefined. Falling back to 0.5. "
            "Check the validation split."
        )
        return 0.5

    # roc_curve returns fpr, tpr (=sensitivity), and the corresponding thresholds.
    fpr, tpr, thresholds = roc_curve(true_labels, probabilities, pos_label=POSITIVE_CLASS_INDEX)

    # Keep only operating points that satisfy the sensitivity floor.
    valid_idx = np.flatnonzero(tpr >= WORTH_SENSITIVITY_FLOOR)
    if valid_idx.size == 0:
        # Defensive: with positives present the curve ends at tpr == 1.0, so this
        # should only trigger if the configured floor is above 1.0.
        print(
            f"WARNING: No threshold achieves sensitivity >= {WORTH_SENSITIVITY_FLOOR}. "
            "Falling back to 0.5. Consider retraining with stronger class weighting."
        )
        return 0.5

    # Among valid points, choose the one with the lowest FPR (fewest false alarms).
    # argmin returns the first minimum, i.e. the highest qualifying threshold.
    best_idx = valid_idx[np.argmin(fpr[valid_idx])]
    optimal = float(thresholds[best_idx])

    achieved_sensitivity = float(tpr[best_idx])
    achieved_specificity = float(1.0 - fpr[best_idx])
    print(
        f"\nOptimal threshold: {optimal:.3f} "
        f"(val sensitivity: {achieved_sensitivity:.3f}, "
        f"val specificity: {achieved_specificity:.3f})"
    )
    return optimal


def evaluate_baseline(
model: tf.keras.Model,
test_df: pd.DataFrame,
Expand Down
8 changes: 5 additions & 3 deletions modeling/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def train_baseline(
model.compile(
optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
loss="binary_crossentropy",
metrics=["accuracy"],
metrics=["accuracy", tf.keras.metrics.AUC(name="auc")],
)

class_weights = compute_class_weights(list(train_df[label_col]))
Expand Down Expand Up @@ -155,12 +155,14 @@ def _build_callbacks(checkpoint_dir: str) -> list:
return [
tf.keras.callbacks.ModelCheckpoint(
filepath=os.path.join(checkpoint_dir, "best.keras"),
monitor="val_loss",
monitor="val_auc",
mode="max",
save_best_only=True,
verbose=1,
),
tf.keras.callbacks.EarlyStopping(
monitor="val_loss",
monitor="val_auc",
mode="max",
patience=7,
restore_best_weights=True,
verbose=1,
Expand Down