Dictionary-encode coordinate columns #217
This PR follows @jayendra13's idea in zarr-datafusion.
Thanks, added a "Prior art" note to the description crediting @jayendra13's dictionary-encoding idea in zarr-datafusion. Happy to point at a specific file/commit there if you'd like a more precise reference.
Generated by Claude Code
alxmrs
left a comment
LGTM so far. Will have an agent review this too.
Additional optimisation if you need it.
Thanks @jayendra13, good call. The PR already picked the index width adaptively per coordinate, but it capped at int32, which (as your doc notes) would silently overflow the index for a >2.1B-cardinality axis. Added the
Generated by Claude Code

Hey @ghostiee-11 (and Jayendra if available), may I have your review on this PR?
ghostiee-11
left a comment
Pulled the branch and ran it against datafusion 54. Mostly good, but test_streaming_aggregation_does_not_explode fails for me every time:
DataFusion error: Arrow error: Dictionary key bigger than the key type
Boiled it down to float32 coordinates. This crashes; flip the coords to float64 and it passes:
import numpy as np, xarray as xr
from xarray_sql import XarrayContext
ds = xr.Dataset(
{"air": (("time", "lat", "lon"), np.random.rand(2920, 25, 53).astype("float32"))},
coords={
"time": np.arange(2920),
"lat": np.linspace(-90, 90, 25).astype("float32"),
"lon": np.linspace(0, 359, 53).astype("float32"),
},
)
ctx = XarrayContext()
ctx.from_dataset("air", ds, chunks={"time": 24})
ctx.sql("SELECT lat, lon, AVG(air) AS m FROM air GROUP BY lat, lon").to_pandas()

One partition is fine, ~120 blows up, so it's float32 keys plus enough batches. _coord_index_type picks the key width from one dimension's size (int8 for lat=25), which is right for a single read, but the grouped aggregate builds a Dictionary(_, Float32) across batches and the index runs past int8 instead of unifying. jayendra13's doc only covers single-read width too, so the prior art doesn't help here.
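The failure mode described above can be modelled in a few lines of plain Python. This is only a sketch of the arithmetic, not arrow's or DataFusion's actual concatenation code: `concat_without_unifying` is a hypothetical helper that appends each batch's dictionary values and shifts that batch's keys by the running offset, which is what happens when dictionaries are concatenated but not merged.

```python
INT8_MAX = 127  # largest index a signed 8-bit dictionary key can hold


def concat_without_unifying(batches):
    """Concatenate (values, keys) batches by appending values, never merging.

    Hypothetical model: each batch's keys are shifted by the number of
    dictionary values already stored, so repeated values are duplicated.
    """
    all_values, all_keys = [], []
    for values, keys in batches:
        offset = len(all_values)
        all_values.extend(values)
        all_keys.extend(k + offset for k in keys)
    return all_values, all_keys


# 25 latitudes from -90 to 90, as in the repro; keys 0..24 fit int8 easily.
lat = [-90.0 + 7.5 * i for i in range(25)]
one_batch = (lat, list(range(25)))

for n_batches in (1, 5, 6, 122):
    _, keys = concat_without_unifying([one_batch] * n_batches)
    print(n_batches, max(keys), "fits int8" if max(keys) <= INT8_MAX else "overflows int8")
```

Under this model the int8 key already overflows at the sixth batch, well before the roughly 120 partitions in the repro.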
Real grids are almost always float32 (air_temperature, ERA5) and GROUP BY lat, lon is the main thing people do, so this one bites. The test that catches it pulls air_temperature over the network, so I'm guessing CI skipped it and that's why it looked green. What datafusion/arrow versions did you run against?
Smaller thing: with show_statistics on, the scan reports Bytes=Exact(40000) but the encoded coords are int8 indices, so the real size is ~8x smaller. Fine as an upper bound, just not "exact" anymore after #201.
Rest holds up: pruning still cuts WHERE time = 5 from 2000 rows to 200, min/max decode through the dictionary, round-trip is exact, and the int8/16/32/64 boundaries are right.
Thanks @ghostiee-11, this was a real bug and a great catch. Fixed in 6f57863. Root cause (exactly your diagnosis): a narrow key can't survive DataFusion concatenating the per-batch coordinate dictionaries across a streaming aggregate. arrow only sometimes unifies (it's a size heuristic in
Fix: make it overflow-proof and only encode where it actually pays:
Honest consequence on the headline: the big multiples came from the narrow keys, which were the unsafe part. Safe encoding is 2× on 8-byte coordinate columns and ~1.33× on a realistic float32 ERA5 grid (only the
On the stats point: good call, fixed too:
Generated by Claude Code
A dense grid repeats every coordinate value across the whole partition (a chunk of shape (time, lat, lon) carries each latitude time×lon times). The reader materialized coordinate columns as dense, fully-repeated Arrow arrays, so every GROUP BY / JOIN on a coordinate re-hashed a hugely redundant column and the pivot moved far more bytes than the data itself.

Encode coordinate columns as Arrow dictionaries: the distinct coordinate values are the dictionary and the strided per-row indices we already compute are the dictionary indices, so no repeated values are broadcast. The index type is sized to the dimension length (int8/int16/int32), so a 6-step time chunk uses 1 byte/row and a 721/1440-point lat/lon uses 2. On an ERA5-shaped chunk the coordinate columns shrink ~4.8x (and equality GROUP BY / JOIN keys become small integers).

- df.py: _parse_schema declares dimension coordinates as dictionary(index, value) with the value type/metadata preserved; iter_record_batches and dataset_to_record_batch emit DictionaryArrays.
- src/lib.rs: keep partition pruning and exact statistics working on the new encoding. DataFusion coerces a coordinate filter to either a Dictionary literal (timestamp) or Cast(col AS value_type) (float), so compare_to_scalar unwraps Dictionary scalars, the pruning matchers strip decode casts, and bound_to_scalar / total_byte_size unwrap the dictionary value type.
- tests: pin the encoding contract and the group-by round-trip; update schema and memory-characterization expectations to the (smaller) encoded columns.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_019VuSeCio99NcME5eubcN3N
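As a rough illustration of "the strided per-row indices are the dictionary indices", here is a minimal pure-Python sketch. It assumes a row-major (C-order) flattening of the chunk; `strided_indices` is a hypothetical name, not a function from df.py.

```python
def strided_indices(shape, axis):
    """Per-row dictionary index for dimension `axis` of a row-major grid."""
    stride = 1
    for size in shape[axis + 1:]:
        stride *= size          # rows spanned by one step along `axis`
    n_rows = 1
    for size in shape:
        n_rows *= size
    return [(row // stride) % shape[axis] for row in range(n_rows)]


shape = (2, 3, 4)                    # (time, lat, lon), kept tiny for readability
lat_values = [-10.0, 0.0, 10.0]      # the dictionary: one entry per distinct latitude
lat_idx = strided_indices(shape, axis=1)

# Decoding looks each index up in the dictionary; each latitude appears
# time * lon = 8 times, but only three floats are ever stored.
decoded = [lat_values[i] for i in lat_idx]
print(lat_idx[:12])
print(decoded.count(0.0))
```

Only the small integer indices grow with the row count; the coordinate values themselves are stored once.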
Cap coordinate dictionary indices at the correct width for any cardinality: the selection now tiers int8 → int16 → int32 → int64 at exact `MAX + 1` boundaries (indices run 0..n-1, so a signed max M holds n = M+1). The int64 fallback keeps astronomically large coordinate axes representable instead of silently overflowing a 32-bit index. Follows the adaptive-key-width note in jayendra13's zarr-datafusion.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_019VuSeCio99NcME5eubcN3N
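The tiering rule in this commit can be sketched as follows. This is an illustration of the `MAX + 1` boundary logic only; `narrowest_index_type` is a hypothetical name, and the commit that follows replaces this selection with an int32 floor.

```python
# Largest value each signed integer type can hold, narrowest first.
SIGNED_MAX = {
    "int8": 2**7 - 1,
    "int16": 2**15 - 1,
    "int32": 2**31 - 1,
    "int64": 2**63 - 1,
}


def narrowest_index_type(cardinality):
    """Pick the narrowest signed type that can index `cardinality` values.

    Indices run 0..n-1, so a type with signed max M can address n = M + 1.
    """
    for name, max_value in SIGNED_MAX.items():
        if cardinality <= max_value + 1:
            return name
    raise ValueError(f"cardinality {cardinality} does not fit in int64")


for n in (6, 128, 129, 721, 1440, 2**31, 2**31 + 1):
    print(n, narrowest_index_type(n))
```

The boundary cases are the interesting ones: 128 distinct values still fit int8 because the largest index is 127, while 129 needs int16.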
A narrow dictionary key (int8/int16) can overflow under DataFusion streaming aggregation: it concatenates the per-batch coordinate dictionaries across the aggregate and does not always unify them (arrow merges dictionary values on a size heuristic, not a guarantee), so the combined index for an unchunked coordinate repeated across N partitions can reach card × N and blow past the key type, raising "Dictionary key bigger than the key type" (reported by @ghostiee-11 on a float32 GROUP BY lat, lon).

Make the encoding overflow-proof and only apply it where it pays off:

- `_coord_index_type` floors the key at int32 (~2.1B combined entries covers any realistic grid; int64 backstops the rest). int8/int16 are gone.
- `_as_dictionary_field` encodes a coordinate only when the int32 key is strictly narrower than the value type: 8-byte float64/int64/timestamp coordinates (a safe 2x) and variable-width strings, leaving 4-byte float32/int32 coordinates dense (where a dictionary is pure overhead and the only way to win would be an overflow-prone narrow key).
- `total_byte_size` reports Inexact when any column is dictionary-encoded, since it sizes those by the value type (a safe upper bound) not the narrower index; honest now that the index is smaller (addresses the "not exact" note).

Regression test: float32 GROUP BY lat, lon over 100 partitions matches xarray and no longer overflows. Stats byte-size expectation updated to Inexact.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_019VuSeCio99NcME5eubcN3N
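The "encode only where it pays" rule from the second bullet can be sketched as a width comparison. The type names and the `should_dictionary_encode` helper are illustrative assumptions; the real `_as_dictionary_field` works on Arrow types rather than strings.

```python
INT32_KEY_BYTES = 4  # the key is floored at int32, so the index costs 4 bytes per row

# Fixed widths in bytes per value; None marks a variable-width type.
VALUE_BYTES = {
    "float32": 4,
    "int32": 4,
    "float64": 8,
    "int64": 8,
    "timestamp": 8,
    "string": None,
}


def should_dictionary_encode(value_type):
    """Encode only when the int32 key is strictly narrower than the value."""
    width = VALUE_BYTES[value_type]
    if width is None:
        return True  # variable-width strings are always encoded
    return INT32_KEY_BYTES < width


for value_type in VALUE_BYTES:
    print(value_type, should_dictionary_encode(value_type))
```

Under this rule a float32 or int32 coordinate stays dense, since a 4-byte key over a 4-byte value saves nothing and adds a dictionary lookup.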
Coordinate columns are dictionary-encoded internally; DataFusion surfaces those as pandas Categorical, which sorts by category order and trips dtype checks. Decode them back to their value dtype at the to_pandas boundary so callers see the same plain columns as before the encoding.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_019VuSeCio99NcME5eubcN3N
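One way to do this kind of decode at the pandas boundary is sketched below. It is an assumption about the approach, not the code in this commit: `decode_categoricals` is a hypothetical helper that casts every Categorical column back to the dtype of its categories.

```python
import pandas as pd


def decode_categoricals(df):
    """Return a copy of `df` with Categorical columns cast to their value dtype."""
    out = df.copy()
    for name in out.columns:
        col = out[name]
        if isinstance(col.dtype, pd.CategoricalDtype):
            # The categories index holds the distinct values, so its dtype
            # is the plain dtype the column had before dictionary encoding.
            out[name] = col.astype(col.cat.categories.dtype)
    return out


# A stand-in for a query result whose coordinate came back as Categorical.
result = pd.DataFrame({"lat": pd.Categorical([10.0, -10.0, 10.0]), "m": [1, 2, 3]})
plain = decode_categoricals(result)
print(plain.dtypes)
```

After the cast, sorting and dtype checks on `lat` behave as they would on an ordinary float column.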
6f57863 to 44087ed
Encode wide coordinate columns as Arrow dictionaries so the query engine moves fewer bytes for them and groups/joins on integer keys instead of rehashing repeated values.
Prior art
This follows @jayendra13's dictionary-encoding idea in zarr-datafusion.
Why
A dense grid repeats every coordinate value across the whole partition: a chunk of shape `(time, lat, lon)` carries each latitude `time × lon` times. The reader materialized coordinate columns as dense, fully-repeated Arrow arrays (`iter_record_batches`), so every `GROUP BY` / `JOIN` on a coordinate re-hashed a hugely redundant column and the pivot moved more bytes for coordinates than for the data itself.

This is the pivot/round-trip tax that the recent exact-statistics work (#201) couldn't touch: stats fixed plan choice; this attacks the actual data volume.
What
- `df.py`: `_parse_schema` dictionary-encodes a dimension coordinate when it is safe and worthwhile:
  - The key is floored at int32 (int64 as a backstop), because under streaming aggregation the combined index for an unchunked coordinate repeated across N partitions can reach `card × N` and blow past a narrow key.
  - Only coordinates wider than the int32 key are encoded: 8-byte `float64`/`int64`/`timestamp` (and variable-width strings). 4-byte `float32`/`int32` coordinates stay dense (a dictionary there would be pure overhead, and a narrower key is the unsafe one). Value type and cftime metadata are preserved.
  - `iter_record_batches` / `dataset_to_record_batch` emit `DictionaryArray`s for the encoded columns; the strided per-row indices they already compute are the dictionary indices, so no dense array of repeated values is built.
- `src/lib.rs`: keep partition pruning and statistics correct on the encoding. DataFusion coerces a coordinate filter to either a `Dictionary` literal (timestamps) or `Cast(col AS value_type)` (numerics), so `compare_to_scalar` unwraps `ScalarValue::Dictionary`, the pruning matchers `strip_cast` the lossless decode cast, and `bound_to_scalar` unwraps `DataType::Dictionary`. `total_byte_size` is reported `Inexact` when any column is dictionary-encoded (it sizes those by the value type, a safe upper bound, not the narrower index).
- Tests: `test_dict_coords.py` pins which coordinates get encoded, a group-by round-trip, and the overflow regression (float32 `GROUP BY lat, lon` over 100 partitions). cftime schema assertions, memory bounds, and the stats byte-size expectation updated.

Measured
Coordinate-column bytes on an ERA5-shaped chunk (`time=6, lat=721, lon=1440`, ~6.2M rows):

The win is on the 8-byte coordinate columns (int32 key = half the width). float32 grids benefit only on their timestamp column; that's the safe ceiling. The larger multiples from narrow (int8/int16) keys were the part that overflowed under streaming aggregation.
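The multiples quoted in this thread (~4.8x with narrow keys, 2x and ~1.33x with the safe int32 key) can be reproduced with back-of-the-envelope arithmetic. The sketch below is not the PR's measurement script; it counts per-row bytes only, ignores the small dictionary value buffers, and assumes the "float32 grid" has an 8-byte timestamp plus 4-byte lat/lon.

```python
rows = 6 * 721 * 1440  # ERA5-shaped chunk: time=6, lat=721, lon=1440


def column_bytes(per_row_widths):
    """Total bytes across all rows for columns with the given byte widths."""
    return rows * sum(per_row_widths)


# All-8-byte coordinates (time, lat, lon) stored dense.
dense_8byte = column_bytes([8, 8, 8])
# Narrow keys sized to each dimension: int8 time, int16 lat, int16 lon.
narrow_keys = column_bytes([1, 2, 2])
# Safe encoding: every 8-byte coordinate behind an int32 key.
safe_8byte = column_bytes([4, 4, 4])

# float32 grid: 8-byte timestamp, 4-byte lat/lon; only time gets encoded.
dense_f32_grid = column_bytes([8, 4, 4])
safe_f32_grid = column_bytes([4, 4, 4])

print(f"rows: {rows:,}")
print(f"narrow keys:       {dense_8byte / narrow_keys:.2f}x")
print(f"safe, 8-byte grid: {dense_8byte / safe_8byte:.2f}x")
print(f"safe, float32 grid: {dense_f32_grid / safe_f32_grid:.2f}x")
```

The gap between the first and second ratio is the cost of giving up int8/int16 keys in exchange for overflow safety.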
Correctness / risk
The make-or-break risks were verified: partition pruning (a `time` filter still prunes to a single partition), the round-trip back to xarray (values decode via `to_numpy`), and the streaming-aggregate overflow reported by @ghostiee-11 (float32 `GROUP BY lat, lon`), now regression-tested and network-free so CI covers it. Full suite green (186) plus `cargo fmt/clippy/test` and `ruff`/`mypy`.

Possible follow-up: recover a larger win for chunked dimensions safely. There the combined dictionary is bounded by the full cardinality, so a narrow key can't overflow.
🤖 Generated with Claude Code
https://claude.ai/code/session_019VuSeCio99NcME5eubcN3N