Partitioning Patterns

Skill: spark-python-data-source

Partitioning is how your custom data source tells Spark to parallelize reads. Without it, Spark sends all work to a single executor. Your AI coding assistant can generate partitioning strategies that match your data source’s access patterns — time ranges for log APIs, ID ranges for paginated REST endpoints, token ranges for distributed databases like Cassandra.

“Add partition pruning to my custom data source so Spark only reads the date ranges it needs. The API supports start_time and end_time query parameters. Use Python with a configurable partition duration.”

from pyspark.sql.datasource import InputPartition, DataSourceReader
from datetime import datetime, timedelta


class TimeRangePartition(InputPartition):
    def __init__(self, start_time, end_time):
        self.start_time = start_time
        self.end_time = end_time


class TimeBasedReader(DataSourceReader):
    def __init__(self, options, schema):
        self.url = options.get("url")
        self.start_time = datetime.fromisoformat(options["start_time"])
        self.end_time = datetime.fromisoformat(options["end_time"])
        self.partition_seconds = int(options.get("partition_duration", "3600"))
        self.schema = schema

    def partitions(self):
        """Split the full time range into fixed-duration partitions."""
        result = []
        current = self.start_time
        delta = timedelta(seconds=self.partition_seconds)
        while current < self.end_time:
            next_time = min(current + delta, self.end_time)
            result.append(TimeRangePartition(current, next_time))
            current = next_time
        return result

    def read(self, partition):
        import requests  # imported on the executor, not the driver

        response = requests.get(self.url, params={
            "start_time": partition.start_time.isoformat(),
            "end_time": partition.end_time.isoformat(),
        })
        response.raise_for_status()
        for item in response.json()["events"]:
            yield tuple(item[f.name] for f in self.schema.fields)

Key decisions your AI coding assistant made:

  • partitions() returns the work units — Spark calls this on the driver, then ships each InputPartition to a different executor. More partitions means more parallelism, up to the number of available cores.
  • Fixed-duration windows — splitting by time works for APIs with temporal access patterns (logs, events, metrics). Each executor queries a non-overlapping time range.
  • Configurable granularity — partition_duration lets you tune parallelism without code changes. A 24-hour query with 1-hour partitions creates 24 parallel reads.
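The splitting math is easy to sanity-check outside Spark. Here is a minimal sketch that mirrors the `partitions()` logic above as a standalone function (`split_time_range` is a hypothetical helper, not part of the data source API):

```python
from datetime import datetime, timedelta

def split_time_range(start, end, partition_seconds):
    """Mirror of TimeBasedReader.partitions(): fixed-duration windows,
    with the final window clamped to the overall end time."""
    windows = []
    current = start
    delta = timedelta(seconds=partition_seconds)
    while current < end:
        next_time = min(current + delta, end)
        windows.append((current, next_time))
        current = next_time
    return windows

# A 24-hour query with 1-hour partitions yields 24 parallel reads.
day = split_time_range(datetime(2024, 1, 1), datetime(2024, 1, 2), 3600)
print(len(day))  # 24
```

Note the `min(...)` clamp: a range that does not divide evenly by the duration gets a shorter final window rather than one that overshoots the requested end time.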

Auto-subdivide when the API returns too much data

“My API returns a size-limit error when a time range contains more than 10,000 results. Add recursive subdivision to my reader so it splits large partitions automatically.”

from datetime import timedelta


class AutoSubdivideReader(DataSourceReader):
    # Reuses TimeRangePartition and the partitions() logic
    # from the time-based reader above.
    def __init__(self, options, schema):
        self.url = options.get("url")
        self.min_partition_seconds = int(options.get("min_partition_seconds", "60"))
        self.schema = schema

    def read(self, partition):
        import requests

        response = requests.get(self.url, params={
            "start": partition.start_time.isoformat(),
            "end": partition.end_time.isoformat(),
        })
        if response.status_code == 413 or "too large" in response.text.lower():
            yield from self._subdivide_and_read(partition)
            return
        response.raise_for_status()
        for item in response.json()["results"]:
            yield tuple(item[f.name] for f in self.schema.fields)

    def _subdivide_and_read(self, partition):
        duration = (partition.end_time - partition.start_time).total_seconds()
        if duration <= self.min_partition_seconds:
            raise RuntimeError(
                f"Cannot subdivide further at {duration}s. "
                f"Increase min_partition_seconds or filter the query."
            )
        # Split the window in half and read each half; halves that are
        # still too large subdivide again recursively.
        midpoint = partition.start_time + timedelta(seconds=duration / 2)
        yield from self.read(TimeRangePartition(partition.start_time, midpoint))
        yield from self.read(TimeRangePartition(midpoint, partition.end_time))

Recursive subdivision handles the unpredictable case where some time windows have far more data than others. The min_partition_seconds guard prevents infinite recursion.
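The guard's worst case is easy to reason about: each level halves the window, so a window of duration d can be halved at most ⌈log₂(d / min)⌉ times before the floor is reached. A standalone sketch (`max_subdivisions` is a hypothetical helper for illustration, not part of the reader):

```python
def max_subdivisions(duration_seconds, min_partition_seconds):
    """Number of times a window can be halved before it reaches the
    min_partition_seconds floor, at which point the guard raises."""
    depth = 0
    while duration_seconds > min_partition_seconds:
        duration_seconds /= 2
        depth += 1
    return depth

# A 1-hour window with the default 60s floor halves at most 6 times
# (3600 -> 1800 -> 900 -> 450 -> 225 -> 112.5 -> 56.25).
print(max_subdivisions(3600, 60))  # 6
```

This bounds both the recursion depth and the number of extra API calls a pathological hot window can trigger.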

ID-range partitioning for paginated REST APIs

“Generate ID-range partitioning for a REST API that supports offset and limit query parameters. Query the total count first, then split into parallel ranges.”

class IdRangePartition(InputPartition):
    def __init__(self, partition_id, start_id, end_id):
        self.partition_id = partition_id
        self.start_id = start_id
        self.end_id = end_id


class IdRangeReader(DataSourceReader):
    def __init__(self, options, schema):
        self.url = options.get("url")
        self.num_partitions = int(options.get("num_partitions", "4"))
        self.page_size = int(options.get("page_size", "1000"))
        self.schema = schema

    def partitions(self):
        import requests

        response = requests.get(f"{self.url}/count")
        response.raise_for_status()
        total = response.json()["total"]
        chunk = total // self.num_partitions
        # The last partition absorbs the remainder so every ID is covered.
        return [
            IdRangePartition(
                i,
                i * chunk,
                (i + 1) * chunk if i < self.num_partitions - 1 else total,
            )
            for i in range(self.num_partitions)
        ]

    def read(self, partition):
        import requests

        current = partition.start_id
        while current < partition.end_id:
            # Clamp the final page so this partition never reads past
            # end_id into a neighboring partition's range.
            limit = min(self.page_size, partition.end_id - current)
            response = requests.get(self.url, params={
                "offset": current, "limit": limit,
            })
            response.raise_for_status()
            for item in response.json()["items"]:
                yield tuple(item[f.name] for f in self.schema.fields)
            current += limit

Each executor pages through its own ID range independently. The trade-off: you need an API endpoint that returns the total count, and the data distribution across IDs should be roughly uniform.
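The chunking arithmetic is worth checking at the edges. Here is a minimal sketch of the boundary computation from `partitions()` as a plain function (`id_ranges` is a hypothetical helper for illustration):

```python
def id_ranges(total, num_partitions):
    """Mirror of IdRangeReader.partitions(): equal chunks, with the
    remainder folded into the last partition."""
    chunk = total // num_partitions
    return [
        (i * chunk, (i + 1) * chunk if i < num_partitions - 1 else total)
        for i in range(num_partitions)
    ]

# 1,000,003 rows across 4 partitions: the last range absorbs the
# 3-row remainder, so coverage is exact and non-overlapping.
ranges = id_ranges(1_000_003, 4)
print(ranges[-1])  # (750000, 1000003)
```

Because each range is half-open (start inclusive, end exclusive from the pager's point of view), adjacent ranges abut without sharing rows.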

Token-range partitioning for distributed databases

“Generate token-range partitioning for a Cassandra data source that queries by the cluster’s token ring, creating one partition per vnode range.”

class TokenRangePartition(InputPartition):
    def __init__(self, partition_id, start_token, end_token, pk_columns):
        self.partition_id = partition_id
        self.start_token = start_token
        self.end_token = end_token
        self.pk_columns = pk_columns


class TokenRangeReader(DataSourceReader):
    # Assumes self.token_map, self.table, self.pk_columns, and
    # self._execute_query are set up from a driver session (elided here).
    def partitions(self):
        ring = sorted(self.token_map.ring)
        partitions = []
        for i, token in enumerate(ring):
            # The modulo pairs the highest token back with the lowest,
            # producing the ring's wrap-around range.
            next_token = ring[(i + 1) % len(ring)]
            partitions.append(TokenRangePartition(
                i,
                token.value if hasattr(token, "value") else str(token),
                next_token.value if hasattr(next_token, "value") else str(next_token),
                self.pk_columns,
            ))
        return partitions

    def read(self, partition):
        pk_cols = ", ".join(partition.pk_columns)
        # Note: the wrap-around range (start_token > end_token) matches
        # nothing with a single predicate like this; production readers
        # split that one range into two queries.
        where = (
            f"token({pk_cols}) > {partition.start_token} "
            f"AND token({pk_cols}) <= {partition.end_token}"
        )
        query = f"SELECT * FROM {self.table} WHERE {where}"
        for row in self._execute_query(query):
            yield row

Token-range partitioning is the natural fit for Cassandra, ScyllaDB, and other ring-based databases. Each partition maps to a vnode range, so reads distribute evenly across the database cluster without hotspots.
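The ring-walking logic in `partitions()` can be illustrated with plain integers (a minimal sketch; real drivers expose token objects rather than ints, and `token_ranges` is a hypothetical helper):

```python
def token_ranges(ring):
    """Pair each sorted token with its successor; the final range
    wraps from the highest token back to the lowest."""
    ring = sorted(ring)
    return [
        (ring[i], ring[(i + 1) % len(ring)])
        for i in range(len(ring))
    ]

ranges = token_ranges([100, -100, 0])
print(ranges)  # [(-100, 0), (0, 100), (100, -100)]
# The last range wraps (start > end) and needs special handling
# when translated into a CQL predicate.
```

Spotting wrap-around ranges (where start exceeds end) is the usual trigger for splitting that one range into two queries.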

Common pitfalls

  • Too few partitions — if you return one partition from partitions(), Spark reads everything on a single executor regardless of cluster size. Start with 2-4x the number of executor cores and tune from there.
  • Too many partitions for rate-limited APIs — 100 partitions means 100 concurrent API calls. If the API rate-limits at 50 requests/second, half your executors will be retrying. Match partition count to what the external system can handle.
  • Skewed partition sizes — if one time range has 10x more data than others, that executor becomes the bottleneck. Auto-subdivision (shown above) or adaptive partition sizing helps, but monitor for skew in production.
  • Overlapping partition boundaries — if partition N covers [0, 100] and partition N+1 covers [100, 200], row 100 gets read twice. Use exclusive start bounds ((100, 200]) or non-overlapping ranges to prevent duplicates.
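The overlapping-boundary pitfall is easy to demonstrate by counting how many ranges each row falls into under closed versus half-open bounds (a small sketch with hypothetical helper names; exclusive-start bounds like (100, 200] work the same way):

```python
from collections import Counter

def rows_in_closed(ranges, rows):
    """Closed bounds [start, end]: boundary rows match two ranges."""
    return Counter(r for r in rows for s, e in ranges if s <= r <= e)

def rows_in_half_open(ranges, rows):
    """Half-open bounds [start, end): each row matches exactly one range."""
    return Counter(r for r in rows for s, e in ranges if s <= r < e)

rows = range(0, 200)
ranges = [(0, 100), (100, 200)]
print(rows_in_closed(ranges, rows)[100])     # 2  (row 100 read twice)
print(rows_in_half_open(ranges, rows)[100])  # 1
```

Whichever convention you pick, the invariant to enforce is that adjacent ranges share a boundary value on exactly one side, never both.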