Streaming Patterns
Skill: spark-python-data-source
What You Can Build
A streaming data source extends your batch reader with offset management — Spark asks “what’s new since last time?” and your reader answers with partitions covering only the new data. Your AI coding assistant can generate the offset tracking, non-overlapping partition logic, and exactly-once write semantics that make a custom streaming source production-ready.
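That question-and-answer loop can be sketched in plain Python. Here a toy `FakeReader` stands in for a real `DataSourceStreamReader` (no Spark required, all names illustrative):

```python
# A pure-Python sketch of the micro-batch loop Spark runs each trigger.
class FakeReader:
    def __init__(self):
        self._clock = 0

    def initialOffset(self):
        return 0

    def latestOffset(self):
        self._clock += 10          # pretend 10 new records arrived
        return self._clock

    def partitions(self, start, end):
        return [(start, end)]      # one partition covering the new range

    def read(self, partition):
        start, end = partition
        return list(range(start, end))

reader = FakeReader()
offset = reader.initialOffset()
seen = []
for _ in range(3):                  # three micro-batches
    end = reader.latestOffset()     # "what's new since last time?"
    for part in reader.partitions(offset, end):
        seen.extend(reader.read(part))
    offset = end                    # in real Spark, checkpointed for you

assert seen == list(range(30))      # every record read exactly once
```

Each trigger advances the offset, and the reader only ever describes data between the previous offset and the new one.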
In Action
“Convert my batch custom data source to support Spark Structured Streaming with offset tracking. The API returns events with timestamps, and I need non-overlapping time-based partitions across micro-batches. Use Python.”
```python
import json
from datetime import datetime, timedelta, timezone

from pyspark.sql.datasource import DataSourceStreamReader, InputPartition


class TimeRangePartition(InputPartition):
    def __init__(self, start_time, end_time):
        self.start_time = start_time
        self.end_time = end_time


class EventStreamReader(DataSourceStreamReader):
    def __init__(self, options, schema):
        self.url = options.get("url")
        self.partition_seconds = int(options.get("partition_duration", "3600"))
        start = options.get("start_time", "latest")

        if start == "latest":
            self._start = datetime.now(timezone.utc)
        else:
            self._start = datetime.fromisoformat(start.replace("Z", "+00:00"))

    def initialOffset(self):
        # Subtract 1us to compensate for the +1us in partitions()
        adjusted = self._start - timedelta(microseconds=1)
        return json.dumps({"ts": adjusted.isoformat()})

    def latestOffset(self):
        return json.dumps({"ts": datetime.now(timezone.utc).isoformat()})

    def partitions(self, start, end):
        start_ts = datetime.fromisoformat(json.loads(start)["ts"])
        end_ts = datetime.fromisoformat(json.loads(end)["ts"])

        # +1us prevents overlap with the previous batch's last partition
        start_ts = start_ts + timedelta(microseconds=1)

        partitions = []
        current = start_ts
        delta = timedelta(seconds=self.partition_seconds)

        while current < end_ts:
            next_time = min(current + delta, end_ts)
            partitions.append(TimeRangePartition(current, next_time))
            current = next_time + timedelta(microseconds=1)

        return partitions if partitions else [TimeRangePartition(start_ts, end_ts)]

    def commit(self, end):
        pass  # Spark handles checkpointing

    def read(self, partition):
        import requests

        response = requests.get(self.url, params={
            "start": partition.start_time.isoformat(),
            "end": partition.end_time.isoformat(),
        })
        response.raise_for_status()
        for event in response.json()["events"]:
            yield tuple(event.values())
```

Key decisions your AI coding assistant made:
- `DataSourceStreamReader` vs batch reader — the streaming reader adds four methods: `initialOffset()`, `latestOffset()`, `partitions(start, end)`, and `commit(end)`. The batch reader’s `partitions()` takes no arguments. That difference is the entire streaming contract.
- Offset as JSON — offsets must be JSON-serializable strings. Spark stores them in the checkpoint directory. You can put anything in there — timestamps, sequence IDs, cursor tokens — as long as it roundtrips through `json.dumps`/`json.loads`.
- Microsecond adjustment prevents overlap — `initialOffset()` subtracts 1 microsecond, and `partitions()` adds 1 microsecond back. This ensures batch N’s last record isn’t re-read in batch N+1.
- `commit()` is a no-op here — Spark’s checkpoint mechanism tracks offsets. You only need `commit()` logic if your external system requires explicit acknowledgment (e.g., marking messages as consumed).
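The non-overlap guarantee can be verified outside Spark. This sketch extracts the partition-planning logic into a standalone function (`plan_partitions` is a hypothetical mirror of `partitions()` above) and runs two consecutive micro-batches that share a boundary offset:

```python
import json
from datetime import datetime, timedelta, timezone

def plan_partitions(start_offset, end_offset, partition_seconds=3600):
    # Mirror of EventStreamReader.partitions(): slice the window into
    # time ranges, skipping the already-read start boundary by 1us.
    start_ts = datetime.fromisoformat(json.loads(start_offset)["ts"])
    end_ts = datetime.fromisoformat(json.loads(end_offset)["ts"])
    start_ts += timedelta(microseconds=1)

    ranges = []
    current = start_ts
    delta = timedelta(seconds=partition_seconds)
    while current < end_ts:
        next_time = min(current + delta, end_ts)
        ranges.append((current, next_time))
        current = next_time + timedelta(microseconds=1)
    return ranges

def off(ts):
    return json.dumps({"ts": ts.isoformat()})

# Two consecutive micro-batches sharing the boundary offset t1
t0 = datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc)
t1 = t0 + timedelta(hours=2)
t2 = t1 + timedelta(hours=1)

batch1 = plan_partitions(off(t0), off(t1))   # covers (t0, t1]
batch2 = plan_partitions(off(t1), off(t2))   # covers (t1, t2]

# batch2 starts 1us after batch1 ends, so no range is read twice
assert batch1[-1][1] < batch2[0][0]
```

Because each batch treats its start offset as already read, the ranges tile the timeline with no gaps larger than a microsecond and no overlap.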
More Patterns
Multi-field offset for complex sources
“My streaming source needs to track both a timestamp and a sequence ID to guarantee ordering. Generate a multi-field offset class.”
```python
import json


class CompositeOffset:
    def __init__(self, timestamp, sequence_id):
        self.timestamp = timestamp
        self.sequence_id = sequence_id

    def json(self):
        return json.dumps({
            "timestamp": self.timestamp,
            "sequence_id": self.sequence_id,
        })

    @staticmethod
    def from_json(raw):
        data = json.loads(raw)
        return CompositeOffset(data["timestamp"], data["sequence_id"])

    def __lt__(self, other):
        if self.timestamp != other.timestamp:
            return self.timestamp < other.timestamp
        return self.sequence_id < other.sequence_id
```

When timestamps alone aren’t unique (e.g., multiple events per millisecond), adding a sequence ID gives you a total ordering. The `__lt__` method lets Spark compare offsets to determine progress.
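A quick standalone check of that ordering, using plain `(timestamp, sequence_id)` tuples as a stand-in for `CompositeOffset` (Python compares tuples lexicographically, which matches the `__lt__` above):

```python
import json

# Minimal stand-in for CompositeOffset, demonstrating the ordering and
# JSON-roundtrip contract without the class itself.
def to_json(timestamp, sequence_id):
    return json.dumps({"timestamp": timestamp, "sequence_id": sequence_id})

def from_json(raw):
    data = json.loads(raw)
    return (data["timestamp"], data["sequence_id"])

a = from_json(to_json(1700000000000, 1))
b = from_json(to_json(1700000000000, 2))   # same millisecond, later sequence
c = from_json(to_json(1700000000001, 0))   # later millisecond

assert a < b < c                     # sequence ID breaks the timestamp tie
assert from_json(to_json(*a)) == a   # offsets roundtrip through JSON
```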
Exactly-once streaming writer
“Generate a streaming writer with idempotency keys so that retries after failures don’t produce duplicate records in the target system.”
```python
import hashlib
import json

from pyspark import TaskContext
from pyspark.sql.datasource import DataSourceStreamWriter


class IdempotentStreamWriter(DataSourceStreamWriter):
    def __init__(self, options, schema):
        self.url = options.get("url")

    def write(self, iterator):
        import requests

        context = TaskContext.get()
        partition_id = context.partitionId()

        for row in iterator:
            row_dict = row.asDict()
            # Deterministic key from partition + row content
            key_data = json.dumps({
                "partition": partition_id,
                "row": row_dict,
            }, sort_keys=True)
            idempotency_key = hashlib.sha256(key_data.encode()).hexdigest()

            requests.post(self.url, json={
                "data": row_dict,
                "idempotency_key": idempotency_key,
            }).raise_for_status()

    def commit(self, messages, batchId):
        pass  # Target system uses idempotency keys for dedup

    def abort(self, messages, batchId):
        pass  # Idempotent writes make abort safe — retried writes are no-ops
```

The idempotency key is a hash of the partition ID and row content. If Spark retries a failed batch, the target system sees the same key and skips the duplicate. This requires the target system to support idempotency (most modern APIs do via Idempotency-Key headers or upsert semantics).
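The dedup behavior can be exercised without Spark. This sketch reuses the same key derivation and fakes the target system’s dedup store (`receive` and `seen` are illustrative, not a real API):

```python
import hashlib
import json

def idempotency_key(partition_id, row_dict):
    # Key depends only on partition + row content, so a retried
    # micro-batch regenerates the exact same key.
    key_data = json.dumps({"partition": partition_id, "row": row_dict},
                          sort_keys=True)
    return hashlib.sha256(key_data.encode()).hexdigest()

seen = set()           # stand-in for the target system's dedup store

def receive(key, data):
    if key in seen:    # duplicate delivery: skip
        return False
    seen.add(key)
    return True

row = {"id": 42, "value": "hello"}
key = idempotency_key(0, row)
assert receive(key, row) is True     # first delivery is applied
assert receive(key, row) is False    # retried delivery is skipped

# sort_keys makes the key independent of dict insertion order
assert idempotency_key(0, {"value": "hello", "id": 42}) == key
```

Note the `sort_keys=True`: without it, two dicts with the same content but different insertion order would hash differently and defeat the dedup.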
Track streaming progress with partition metrics
“Add per-partition metrics to my stream reader so I can monitor throughput and detect stalls.”
```python
import time

from pyspark.sql.datasource import DataSourceStreamReader


class MonitoredStreamReader(DataSourceStreamReader):
    def read(self, partition):
        start = time.time()
        count = 0

        for row in self._fetch_partition(partition):
            count += 1
            yield row

        duration = time.time() - start
        throughput = count / duration if duration > 0 else 0
        print(f"Partition [{partition.start_time} - {partition.end_time}]: "
              f"{count} rows in {duration:.1f}s ({throughput:.0f} rows/s)")
```

These metrics show up in executor logs. For production monitoring, emit them to your metrics system instead of print().
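A hedged sketch of what a logging-based replacement for `print()` might look like; `report_partition` and its stall threshold are hypothetical, and a real deployment would swap the logger for a metrics client:

```python
import logging

logger = logging.getLogger("event_stream")

def report_partition(count, duration, stall_threshold=1.0):
    # Emit throughput as a structured log line and flag suspected
    # stalls (hypothetical helper; threshold is in rows/s).
    throughput = count / duration if duration > 0 else 0.0
    logger.info("rows=%d duration=%.1fs throughput=%.0f rows/s",
                count, duration, throughput)
    return throughput >= stall_threshold

assert report_partition(3600, 10.0) is True   # 360 rows/s: healthy
assert report_partition(0, 10.0) is False     # 0 rows/s: stalled
```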
Watch Out For
- Overlapping partitions across micro-batches — if `partitions(start, end)` doesn’t properly exclude the previous batch’s end boundary, you’ll get duplicate rows on every micro-batch. The microsecond adjustment pattern shown above prevents this.
- Non-deterministic `latestOffset()` — Spark may call `latestOffset()` multiple times per micro-batch for planning. If it returns a different value each time (e.g., `datetime.now()` with no rounding), Spark may skip or duplicate data. Consider rounding to the nearest second.
- Forgetting that `commit()` can be a no-op — Spark’s checkpoint mechanism tracks offsets for you. Only implement `commit()` if your external system needs explicit acknowledgment (Kafka consumer offsets, message queue ACKs). For REST APIs, leave it empty.
- Streaming writer without idempotency — if a micro-batch fails after partial writes and Spark retries, you’ll get duplicates in the target system. Either implement idempotency keys (shown above) or use a target system that supports upsert semantics.
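For the non-deterministic `latestOffset()` pitfall, a minimal sketch of second-level rounding (`stable_latest_offset` is a hypothetical helper, not part of the PySpark API):

```python
from datetime import datetime, timezone

def stable_latest_offset(now=None):
    # Truncate to whole seconds so repeated latestOffset() calls
    # during micro-batch planning agree within the same second.
    now = now or datetime.now(timezone.utc)
    return now.replace(microsecond=0).isoformat()

# Two calls within the same wall-clock second yield the same offset
t1 = datetime(2024, 1, 1, 12, 0, 0, 123456, tzinfo=timezone.utc)
t2 = datetime(2024, 1, 1, 12, 0, 0, 999999, tzinfo=timezone.utc)
assert stable_latest_offset(t1) == stable_latest_offset(t2)
```

The trade-off is up to one second of added latency per micro-batch, which is usually negligible against trigger intervals.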