
Streaming Patterns

Skill: spark-python-data-source

A streaming data source extends your batch reader with offset management — Spark asks “what’s new since last time?” and your reader answers with partitions covering only the new data. Your AI coding assistant can generate the offset tracking, non-overlapping partition logic, and exactly-once write semantics that make a custom streaming source production-ready.

“Convert my batch custom data source to support Spark Structured Streaming with offset tracking. The API returns events with timestamps, and I need non-overlapping time-based partitions across micro-batches. Use Python.”

import json
from datetime import datetime, timedelta, timezone

from pyspark.sql.datasource import DataSourceStreamReader, InputPartition


class TimeRangePartition(InputPartition):
    def __init__(self, start_time, end_time):
        self.start_time = start_time
        self.end_time = end_time


class EventStreamReader(DataSourceStreamReader):
    def __init__(self, options, schema):
        self.url = options.get("url")
        self.partition_seconds = int(options.get("partition_duration", "3600"))
        start = options.get("start_time", "latest")
        if start == "latest":
            self._start = datetime.now(timezone.utc)
        else:
            self._start = datetime.fromisoformat(start.replace("Z", "+00:00"))

    def initialOffset(self):
        # Subtract 1us to compensate for the +1us in partitions()
        adjusted = self._start - timedelta(microseconds=1)
        return json.dumps({"ts": adjusted.isoformat()})

    def latestOffset(self):
        return json.dumps({"ts": datetime.now(timezone.utc).isoformat()})

    def partitions(self, start, end):
        start_ts = datetime.fromisoformat(json.loads(start)["ts"])
        end_ts = datetime.fromisoformat(json.loads(end)["ts"])
        # +1us prevents overlap with the previous batch's last partition
        start_ts = start_ts + timedelta(microseconds=1)
        partitions = []
        current = start_ts
        delta = timedelta(seconds=self.partition_seconds)
        while current < end_ts:
            next_time = min(current + delta, end_ts)
            partitions.append(TimeRangePartition(current, next_time))
            current = next_time + timedelta(microseconds=1)
        # Fall back to a single (possibly empty) range so this never returns []
        return partitions if partitions else [TimeRangePartition(start_ts, end_ts)]

    def commit(self, end):
        pass  # Spark handles checkpointing

    def read(self, partition):
        import requests
        response = requests.get(self.url, params={
            "start": partition.start_time.isoformat(),
            "end": partition.end_time.isoformat(),
        })
        response.raise_for_status()
        for event in response.json()["events"]:
            yield tuple(event.values())

Key decisions your AI coding assistant made:

  • DataSourceStreamReader vs batch reader — the streaming reader adds four methods: initialOffset(), latestOffset(), partitions(start, end), and commit(end). The batch reader’s partitions() takes no arguments. That difference is the entire streaming contract.
  • Offset as JSON — offsets must be JSON-serializable strings. Spark stores them in the checkpoint directory. You can put anything in there — timestamps, sequence IDs, cursor tokens — as long as it roundtrips through json.dumps/json.loads.
  • Microsecond adjustment prevents overlap — initialOffset() subtracts 1 microsecond, and partitions() adds 1 microsecond back. This ensures batch N’s last record isn’t re-read in batch N+1.
  • commit() is a no-op here — Spark’s checkpoint mechanism tracks offsets. You only need commit() logic if your external system requires explicit acknowledgment (e.g., marking messages as consumed).
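The boundary math behind the microsecond adjustment can be checked in plain Python, without a Spark session. Here `plan_partitions` is a hypothetical standalone replica of the `partitions()` logic above, not part of the Spark API:

```python
import json
from datetime import datetime, timedelta, timezone

def plan_partitions(start_offset, end_offset, partition_seconds=3600):
    """Replicates the partitions() boundary math: step 1us past the
    previous batch's end, then cut fixed-width windows up to the new end."""
    start_ts = datetime.fromisoformat(json.loads(start_offset)["ts"])
    end_ts = datetime.fromisoformat(json.loads(end_offset)["ts"])
    start_ts += timedelta(microseconds=1)
    windows, current = [], start_ts
    delta = timedelta(seconds=partition_seconds)
    while current < end_ts:
        next_time = min(current + delta, end_ts)
        windows.append((current, next_time))
        current = next_time + timedelta(microseconds=1)
    return windows

# Simulate two consecutive micro-batches sharing a boundary offset
batch1_end = json.dumps({"ts": "2024-01-01T02:00:00+00:00"})
batch2_end = json.dumps({"ts": "2024-01-01T04:30:00+00:00"})
first = plan_partitions(json.dumps({"ts": "2024-01-01T00:00:00+00:00"}), batch1_end)
second = plan_partitions(batch1_end, batch2_end)
# Batch 2 starts 1us after batch 1's last window ends: no range overlaps
assert first[-1][1] < second[0][0]
```

Because each micro-batch re-derives its start from the committed end offset, the non-overlap guarantee holds across restarts as well.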

“My streaming source needs to track both a timestamp and a sequence ID to guarantee ordering. Generate a multi-field offset class.”

class CompositeOffset:
    def __init__(self, timestamp, sequence_id):
        self.timestamp = timestamp
        self.sequence_id = sequence_id

    def json(self):
        return json.dumps({
            "timestamp": self.timestamp,
            "sequence_id": self.sequence_id,
        })

    @staticmethod
    def from_json(raw):
        data = json.loads(raw)
        return CompositeOffset(data["timestamp"], data["sequence_id"])

    def __lt__(self, other):
        if self.timestamp != other.timestamp:
            return self.timestamp < other.timestamp
        return self.sequence_id < other.sequence_id

When timestamps alone aren’t unique (e.g., multiple events per millisecond), adding a sequence ID gives you a total ordering. Spark itself treats the offset strings as opaque checkpoint data; the __lt__ method lets your own reader logic compare offsets to determine progress.
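A quick self-contained check of that ordering and the JSON roundtrip, using the class as defined above:

```python
import json

class CompositeOffset:
    def __init__(self, timestamp, sequence_id):
        self.timestamp = timestamp
        self.sequence_id = sequence_id

    def json(self):
        return json.dumps({"timestamp": self.timestamp,
                           "sequence_id": self.sequence_id})

    @staticmethod
    def from_json(raw):
        data = json.loads(raw)
        return CompositeOffset(data["timestamp"], data["sequence_id"])

    def __lt__(self, other):
        # Timestamp first; sequence ID breaks ties within the same instant
        if self.timestamp != other.timestamp:
            return self.timestamp < other.timestamp
        return self.sequence_id < other.sequence_id

# Two events share a timestamp: the sequence ID decides the order
events = [CompositeOffset(1700000000, 7),
          CompositeOffset(1700000000, 3),
          CompositeOffset(1699999999, 9)]
ordered = sorted(events)
assert [(o.timestamp, o.sequence_id) for o in ordered] == \
    [(1699999999, 9), (1700000000, 3), (1700000000, 7)]

# Roundtrip through the checkpoint-safe JSON form
restored = CompositeOffset.from_json(events[0].json())
assert (restored.timestamp, restored.sequence_id) == (1700000000, 7)
```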

“Generate a streaming writer with idempotency keys so that retries after failures don’t produce duplicate records in the target system.”

import hashlib
import json

from pyspark import TaskContext
from pyspark.sql.datasource import DataSourceStreamWriter, WriterCommitMessage


class IdempotentStreamWriter(DataSourceStreamWriter):
    def __init__(self, options, schema):
        self.url = options.get("url")

    def write(self, iterator):
        import requests
        context = TaskContext.get()
        partition_id = context.partitionId()
        for row in iterator:
            row_dict = row.asDict()
            # Deterministic key from partition + content
            key_data = json.dumps({
                "partition": partition_id,
                "row": row_dict,
            }, sort_keys=True)
            idempotency_key = hashlib.sha256(key_data.encode()).hexdigest()
            requests.post(self.url, json={
                "data": row_dict,
                "idempotency_key": idempotency_key,
            }).raise_for_status()
        # Each task must report back a commit message
        return WriterCommitMessage()

    def commit(self, messages, batchId):
        pass  # Target system uses idempotency keys for dedup

    def abort(self, messages, batchId):
        pass  # Idempotent writes make abort safe; retried writes are no-ops

The idempotency key is a hash of the partition ID and row content. If Spark retries a failed batch, the target system sees the same key and skips the duplicate. This requires the target system to support idempotency (most modern APIs do via Idempotency-Key headers or upsert semantics).
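The determinism is easy to verify in isolation. `idempotency_key` below is a hypothetical extraction of the hashing logic from the writer: `sort_keys=True` canonicalizes the JSON, so the same content always produces the same key regardless of dict insertion order:

```python
import hashlib
import json

def idempotency_key(partition_id, row_dict):
    # sort_keys=True canonicalizes field order, so the hash depends only
    # on the partition ID and the row's content
    key_data = json.dumps({"partition": partition_id, "row": row_dict},
                          sort_keys=True)
    return hashlib.sha256(key_data.encode()).hexdigest()

original = {"user": "alice", "amount": 42}
retried = {"amount": 42, "user": "alice"}  # same content, different field order
assert idempotency_key(3, original) == idempotency_key(3, retried)

# A different partition or different content yields a different key
assert idempotency_key(4, original) != idempotency_key(3, original)
assert idempotency_key(3, {"user": "bob", "amount": 42}) != idempotency_key(3, original)
```

One caveat: because the key covers only partition ID and content, two genuinely identical rows in the same partition would collide; include a batch ID or row index in the key if your data can contain exact duplicates that must both land.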

Track streaming progress with partition metrics


“Add per-partition metrics to my stream reader so I can monitor throughput and detect stalls.”

import time

class MonitoredStreamReader(DataSourceStreamReader):
    def read(self, partition):
        start = time.time()
        count = 0
        for row in self._fetch_partition(partition):
            count += 1
            yield row
        duration = time.time() - start
        throughput = count / duration if duration > 0 else 0
        print(f"Partition [{partition.start_time} - {partition.end_time}]: "
              f"{count} rows in {duration:.1f}s ({throughput:.0f} rows/s)")

These metrics show up in executor logs. For production monitoring, emit them to your metrics system instead of print().
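For example, the same counters can be routed through a callback instead of print(). Here `read_with_metrics` and `emit_metric` are hypothetical stand-ins for your reader's fetch loop and your metrics client (StatsD, Prometheus, CloudWatch, etc.):

```python
import time

def read_with_metrics(fetch_rows, emit_metric):
    """Generator wrapper: counts rows as they stream through and reports
    totals and throughput once the partition is exhausted. fetch_rows is
    any iterable of rows; emit_metric(name, value) is a placeholder hook
    for a real metrics client."""
    start = time.time()
    count = 0
    for row in fetch_rows:
        count += 1
        yield row
    duration = time.time() - start
    emit_metric("partition.rows", count)
    emit_metric("partition.rows_per_sec",
                count / duration if duration > 0 else 0)

# Capture emitted metrics in a list to show the hook firing
captured = []
rows = list(read_with_metrics(iter(range(5)),
                              lambda name, value: captured.append((name, value))))
assert rows == [0, 1, 2, 3, 4]
assert captured[0] == ("partition.rows", 5)
```

Note the metrics fire only after the iterator is fully consumed; a stalled downstream consumer shows up as a partition that never reports, which is itself a useful signal.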

Common pitfalls

  • Overlapping partitions across micro-batches — if partitions(start, end) doesn’t properly exclude the previous batch’s end boundary, you’ll get duplicate rows on every micro-batch. The microsecond adjustment pattern shown above prevents this.
  • Non-deterministic latestOffset() — Spark may call latestOffset() multiple times per micro-batch for planning. If it returns a different value each time (e.g., datetime.now() with no rounding), Spark may skip or duplicate data. Consider rounding to the nearest second.
  • Forgetting that commit() can be a no-op — Spark’s checkpoint mechanism tracks offsets for you. Only implement commit() if your external system needs explicit acknowledgment (Kafka consumer offsets, message queue ACKs). For REST APIs, leave it empty.
  • Streaming writer without idempotency — if a micro-batch fails after partial writes and Spark retries, you’ll get duplicates in the target system. Either implement idempotency keys (shown above) or use a target system that supports upsert semantics.
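One mitigation for the non-deterministic latestOffset() pitfall above is truncating to whole seconds before serializing, so repeated planning calls within the same second see an identical offset. `latest_offset_rounded` is a hypothetical sketch of that fix:

```python
import json
from datetime import datetime, timezone

def latest_offset_rounded():
    # Truncate to the whole second so repeated calls within the same
    # second return byte-identical offset strings
    now = datetime.now(timezone.utc).replace(microsecond=0)
    return json.dumps({"ts": now.isoformat()})

offset = json.loads(latest_offset_rounded())
parsed = datetime.fromisoformat(offset["ts"])
assert parsed.microsecond == 0  # sub-second jitter is gone
```

The trade-off is up to one second of added latency per micro-batch, which is usually negligible next to trigger intervals.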