Skip to content

Custom Spark Data Sources

Skill: spark-python-data-source

When there is no native Spark connector for your target system, you build one. The PySpark DataSource API (Spark 4.0+) lets you write Python classes that Spark calls during read, write, readStream, and writeStream operations. Your AI coding assistant can scaffold the entire connector — DataSource entry point, batch and streaming reader/writer classes, partitioning strategy, and tests — from a description of the external system.

“Build a custom Spark data source that reads from a REST API with cursor-based pagination and writes transformed results to a second API endpoint.”

from pyspark.sql.datasource import DataSource, DataSourceReader, DataSourceWriter
from pyspark.sql.types import StructType, StructField, StringType, IntegerType
class RestApiDataSource(DataSource):
    """Spark data source for a paginated REST API.

    Registered under the short name ``rest-api``; Spark instantiates the
    batch reader/writer below with the user-supplied options.
    """

    @classmethod
    def name(cls) -> str:
        # Short name used in spark.read.format("rest-api").
        return "rest-api"

    def schema(self) -> StructType:
        """Fixed schema of the rows produced by the reader."""
        columns = [
            ("id", IntegerType()),
            ("name", StringType()),
            ("status", StringType()),
        ]
        return StructType([StructField(col, dtype) for col, dtype in columns])

    def reader(self, schema: StructType) -> "RestApiBatchReader":
        """Build the batch reader; options carry url/page_size."""
        return RestApiBatchReader(self.options)

    def writer(self, schema: StructType, overwrite: bool) -> "RestApiBatchWriter":
        """Build the batch writer; options carry target_url."""
        return RestApiBatchWriter(self.options)
class RestApiBatchReader(DataSourceReader):
    """Batch reader that pages through a cursor-based REST API.

    Options:
        url: endpoint to GET (required).
        page_size: rows requested per page (default 100).
        timeout: per-request timeout in seconds (default 30) -- without it a
            stalled server would hang the executor task forever.
    """

    def __init__(self, options):
        self.url = options.get("url")
        self.page_size = int(options.get("page_size", "100"))
        self.timeout = float(options.get("timeout", "30"))

    def read(self, partition):
        """Yield (id, name, status) tuples; runs on an executor."""
        import requests  # import inside executor method

        cursor = None
        while True:
            params = {"limit": self.page_size}
            if cursor:
                params["cursor"] = cursor
            resp = requests.get(self.url, params=params, timeout=self.timeout)
            resp.raise_for_status()
            data = resp.json()
            for item in data["results"]:
                yield (item["id"], item["name"], item["status"])
            cursor = data.get("next_cursor")
            if not cursor:  # last page: no continuation cursor
                break
class RestApiBatchWriter(DataSourceWriter):
    """Batch writer that POSTs rows to a REST endpoint in chunks.

    Options:
        target_url: endpoint receiving JSON batches (required).
        batch_size: rows per POST (default 500) -- bounds executor memory
            while amortizing HTTP round-trips.
        timeout: per-request timeout in seconds (default 30).
    """

    def __init__(self, options):
        self.url = options.get("target_url")
        self.batch_size = int(options.get("batch_size", "500"))
        self.timeout = float(options.get("timeout", "30"))

    def _flush(self, batch):
        """POST one batch; raises on HTTP errors so Spark retries the task."""
        import requests  # import inside executor method
        requests.post(self.url, json=batch, timeout=self.timeout).raise_for_status()

    def write(self, iterator):
        """Consume the partition's rows, flushing every batch_size rows."""
        batch = []
        for row in iterator:
            batch.append({"id": row.id, "name": row.name, "status": row.status})
            if len(batch) >= self.batch_size:
                self._flush(batch)
                batch = []
        if batch:  # final partial batch
            self._flush(batch)

Key decisions:

  • Flat single-level inheritance — PySpark serializes reader/writer instances to ship them to executors. Deep class hierarchies break serialization. One shared base mixed with the PySpark interface is all you need.
  • Third-party imports inside executor methods — read() and write() run on remote executors that do not share the driver’s Python environment. Import requests, database drivers, etc. inside these methods, never at module top level.
  • Cursor-based pagination — chosen over offset-based because the source API supports it and it avoids the classic “skipped rows on concurrent inserts” problem.
  • Batched writes at 500 rows — reduces HTTP round-trips without holding too much data in memory on executors.

“Create a streaming data source for a message queue that tracks consumer offsets for exactly-once processing.”

from pyspark.sql.datasource import DataSourceStreamReader
import json
class QueueStreamReader(DataSourceStreamReader):
    """Streaming reader that drains a message queue between offset bounds.

    Offsets are JSON-serializable dicts ({"offset": <int>}) as the streaming
    API requires; Spark checkpoints them for recovery.
    """

    def __init__(self, options):
        self.broker = options.get("broker")
        self.queue = options.get("queue")

    def initialOffset(self) -> dict:
        """Starting offset for a fresh query with no checkpoint."""
        return {"offset": 0}

    def latestOffset(self) -> dict:
        # Runs on the DRIVER to discover new data. A passive declare reports
        # the queue depth without creating or modifying the queue.
        import pika
        conn = pika.BlockingConnection(pika.URLParameters(self.broker))
        try:
            ch = conn.channel()
            q = ch.queue_declare(queue=self.queue, passive=True)
        finally:
            # always release the broker connection, even if the declare fails
            conn.close()
        return {"offset": q.method.message_count}

    def read(self, start: dict, end: dict):
        """Fetch end - start messages; runs on executors."""
        import pika
        conn = pika.BlockingConnection(pika.URLParameters(self.broker))
        try:
            ch = conn.channel()
            count = end["offset"] - start["offset"]
            for _ in range(count):
                # auto_ack=True: pika's basic_get does NOT acknowledge by
                # default, so without this every message would stay unacked
                # and be redelivered -- contradicting the commit() contract.
                method, props, body = ch.basic_get(queue=self.queue, auto_ack=True)
                if method:  # None when the queue drained early
                    yield (method.delivery_tag, body.decode("utf-8"))
        finally:
            conn.close()

    def commit(self, end: dict):
        """No-op: messages are acknowledged as they are consumed in read()."""
        pass

Streaming offsets must be JSON-serializable dicts. The latestOffset method runs on the driver to discover new data, while read runs on executors to fetch it. Keep the boundary between these two methods clean.

“Add multi-method auth to my data source — try Unity Catalog credentials first, then fall back to environment variables.”

class AuthenticatedReader(DataSourceReader):
    """Reader with a layered credential lookup.

    Priority: explicit api_key option > Unity Catalog secret > API_KEY
    environment variable, so the same connector works in notebooks (UC)
    and in CI (env vars) with no code changes.
    """

    def __init__(self, options):
        self.api_key = options.get("api_key")
        self.uc_secret_scope = options.get("secret_scope")
        self.uc_secret_key = options.get("secret_key")

    def _resolve_credentials(self):
        """Return the API key, trying each source in priority order.

        Raises:
            ValueError: when no source yields a credential.
        """
        if self.api_key:
            return self.api_key
        if self.uc_secret_scope and self.uc_secret_key:
            from pyspark.sql import SparkSession
            spark = SparkSession.getActiveSession()
            # getActiveSession() can return None (e.g. on a bare executor);
            # fall through to the env var instead of raising AttributeError.
            if spark is not None:
                return spark.conf.get(
                    f"spark.databricks.secrets.{self.uc_secret_scope}.{self.uc_secret_key}"
                )
        import os
        key = os.environ.get("API_KEY")
        if not key:
            raise ValueError("No credentials found -- set api_key option, UC secret, or API_KEY env var")
        return key

    def read(self, partition):
        """GET the data endpoint with a bearer token; runs on executors."""
        import requests  # import inside executor method
        headers = {"Authorization": f"Bearer {self._resolve_credentials()}"}
        # timeout prevents a stalled server from hanging the task forever
        resp = requests.get("https://api.example.com/data", headers=headers, timeout=30)
        resp.raise_for_status()  # fail fast on HTTP errors instead of parsing an error body
        for row in resp.json():
            yield tuple(row.values())

The fallback chain keeps connectors portable — they work in notebooks with UC secrets and in CI with environment variables, with no code changes.

“My database has 100M rows. Partition reads across Spark executors using ID-range splits.”

class PartitionedReader(DataSourceReader):
    """Reader that splits a large table into contiguous ID ranges.

    Each partition spec is {"start": lo, "end": hi}, interpreted as the
    half-open interval [lo, hi) to match the id_gte/id_lt query params.
    """

    def __init__(self, options):
        self.url = options.get("url")
        self.num_partitions = int(options.get("partitions", "8"))

    def partitions(self):
        """Compute ID-range splits; runs once on the driver."""
        import requests
        resp = requests.get(f"{self.url}/stats", timeout=30)
        resp.raise_for_status()
        max_id = resp.json()["max_id"]
        chunk = max_id // self.num_partitions
        last = self.num_partitions - 1
        return [
            {
                "start": i * chunk,
                # Ranges are half-open and read() filters with id_lt, so the
                # final bound must be max_id + 1 -- using max_id here would
                # silently drop the row whose id equals max_id.
                "end": (i + 1) * chunk if i < last else max_id + 1,
            }
            for i in range(self.num_partitions)
        ]

    def read(self, partition):
        """Fetch one ID range; each executor receives a single spec."""
        import requests  # import inside executor method
        resp = requests.get(
            f"{self.url}/records",
            params={"id_gte": partition["start"], "id_lt": partition["end"]},
            timeout=30,
        )
        resp.raise_for_status()
        for row in resp.json():
            yield (row["id"], row["value"])

The partitions() method runs once on the driver and returns partition specs. Each executor then calls read() with one spec. Choose partition boundaries that produce roughly equal-sized chunks so no executor becomes a bottleneck.

  • Top-level third-party imports — These will pass on the driver but throw ModuleNotFoundError on executors. Always import external libraries inside read() and write() methods.
  • Complex inheritance hierarchies — PySpark pickles reader/writer instances for shipping to executors. Abstract base classes, multiple inheritance layers, and closures over unpicklable objects will fail at runtime with cryptic serialization errors. Stick to flat, single-level inheritance.
  • Missing cluster-wide package installs — Every dependency must be installed on all executor nodes, not just the driver. Use %pip install in notebooks or init scripts on clusters. If a package is only on the driver, executors will fail silently or crash.
  • Overlapping stream offset ranges — If the offset ranges handed to successive microbatches overlap, records get processed twice. Make sure each microbatch reads a strictly non-overlapping [start, end) range with no gaps.