Custom Spark Data Sources
Skill: spark-python-data-source
What You Can Build
When there is no native Spark connector for your target system, you build one. The PySpark DataSource API (Spark 4.0+) lets you write Python classes that Spark calls during read, write, readStream, and writeStream operations. Your AI coding assistant can scaffold the entire connector — DataSource entry point, batch and streaming reader/writer classes, partitioning strategy, and tests — from a description of the external system.
In Action
“Build a custom Spark data source that reads from a REST API with cursor-based pagination and writes transformed results to a second API endpoint.”
```python
from pyspark.sql.datasource import (
    DataSource,
    DataSourceReader,
    DataSourceWriter,
    WriterCommitMessage,
)
from pyspark.sql.types import StructType, StructField, StringType, IntegerType


class RestApiDataSource(DataSource):
    """Spark data source for a paginated REST API."""

    @classmethod
    def name(cls) -> str:
        return "rest-api"

    def schema(self) -> StructType:
        return StructType([
            StructField("id", IntegerType()),
            StructField("name", StringType()),
            StructField("status", StringType()),
        ])

    def reader(self, schema: StructType) -> "RestApiBatchReader":
        return RestApiBatchReader(self.options)

    def writer(self, schema: StructType, overwrite: bool) -> "RestApiBatchWriter":
        return RestApiBatchWriter(self.options)


class RestApiBatchReader(DataSourceReader):
    def __init__(self, options):
        self.url = options.get("url")
        self.page_size = int(options.get("page_size", "100"))

    def read(self, partition):
        import requests  # import inside the executor method

        cursor = None
        while True:
            params = {"limit": self.page_size}
            if cursor:
                params["cursor"] = cursor
            resp = requests.get(self.url, params=params)
            resp.raise_for_status()
            data = resp.json()
            for item in data["results"]:
                yield (item["id"], item["name"], item["status"])
            cursor = data.get("next_cursor")
            if not cursor:
                break


class RestApiBatchWriter(DataSourceWriter):
    def __init__(self, options):
        self.url = options.get("target_url")

    def write(self, iterator) -> WriterCommitMessage:
        import requests

        batch = []
        for row in iterator:
            batch.append({"id": row.id, "name": row.name, "status": row.status})
            if len(batch) >= 500:
                requests.post(self.url, json=batch).raise_for_status()
                batch = []
        if batch:
            requests.post(self.url, json=batch).raise_for_status()
        return WriterCommitMessage()
```

Key decisions:
- Flat, single-level inheritance — PySpark serializes reader/writer instances to ship them to executors. Deep class hierarchies break serialization. One shared base class mixed with the PySpark interface is all you need.
- Third-party imports inside executor methods — `read()` and `write()` run on remote executors that do not share the driver’s Python environment. Import `requests`, database drivers, etc. inside these methods, never at module top level.
- Cursor-based pagination — chosen over offset-based pagination because the source API supports it and it avoids the classic “skipped rows on concurrent inserts” problem.
- Batched writes at 500 rows — reduces HTTP round-trips without holding too much data in memory on executors.
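The cursor loop in `read()` can be exercised without Spark or a live endpoint by stubbing out the HTTP call. A minimal sketch against a hypothetical in-memory API (`fetch_page` and `read_all` are illustrative names, not part of the connector above):

```python
def fetch_page(rows, cursor, limit):
    """Stand-in for one GET: returns a page plus the next cursor (or None)."""
    start = cursor or 0
    nxt = start + limit
    return {
        "results": rows[start:nxt],
        "next_cursor": nxt if nxt < len(rows) else None,
    }


def read_all(rows, limit=2):
    """Mirrors the while-loop in RestApiBatchReader.read()."""
    out, cursor = [], None
    while True:
        data = fetch_page(rows, cursor, limit)
        out.extend(data["results"])
        cursor = data.get("next_cursor")
        if not cursor:
            break
    return out
```

Factoring the loop this way lets you unit-test termination and completeness before wiring in a real HTTP client.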
More Patterns
Streaming reader with offset tracking
“Create a streaming data source for a message queue that tracks consumer offsets for exactly-once processing.”
```python
from pyspark.sql.datasource import DataSourceStreamReader, InputPartition


class RangePartition(InputPartition):
    def __init__(self, start, end):
        self.start = start
        self.end = end


class QueueStreamReader(DataSourceStreamReader):
    def __init__(self, options):
        self.broker = options.get("broker")
        self.queue = options.get("queue")

    def initialOffset(self) -> dict:
        return {"offset": 0}

    def latestOffset(self) -> dict:
        import pika  # lazy import; this method runs on the driver
        conn = pika.BlockingConnection(pika.URLParameters(self.broker))
        ch = conn.channel()
        q = ch.queue_declare(queue=self.queue, passive=True)
        conn.close()
        return {"offset": q.method.message_count}

    def partitions(self, start: dict, end: dict):
        # One partition per micro-batch; split the range for parallel reads.
        return [RangePartition(start["offset"], end["offset"])]

    def read(self, partition):
        import pika  # import on the executor
        conn = pika.BlockingConnection(pika.URLParameters(self.broker))
        ch = conn.channel()
        for _ in range(partition.end - partition.start):
            method, _props, body = ch.basic_get(queue=self.queue)
            if method:
                yield (method.delivery_tag, body.decode("utf-8"))
        conn.close()

    def commit(self, end: dict):
        pass  # ACKs handled in read
```

Streaming offsets must be JSON-serializable dicts. The `latestOffset` method runs on the driver to discover new data; `partitions` slices each `(start, end)` range into work units, and `read` runs on executors to fetch them. Keep the boundary between driver and executor methods clean.
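How the driver turns successive `latestOffset()` results into micro-batch ranges can be sketched in plain Python. The function below is illustrative, not part of the PySpark API — it shows why half-open `[start, end)` ranges give each record to exactly one batch:

```python
def plan_batches(offsets):
    """Given the driver's successive latestOffset() values, return the
    half-open (start, end) range each micro-batch would read."""
    batches, start = [], offsets[0]
    for latest in offsets[1:]:
        if latest > start:  # only emit non-empty batches
            batches.append((start, latest))
            start = latest
    return batches
```

Because each batch starts exactly where the previous one ended, no record is read twice and none is skipped — the invariant behind exactly-once processing.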
Authentication with fallback chain
“Add multi-method auth to my data source — try Unity Catalog credentials first, then fall back to environment variables.”
```python
from pyspark.sql.datasource import DataSourceReader


class AuthenticatedReader(DataSourceReader):
    def __init__(self, options):
        self.api_key = options.get("api_key")
        self.uc_secret_scope = options.get("secret_scope")
        self.uc_secret_key = options.get("secret_key")

    def _resolve_credentials(self):
        """Priority: explicit option > UC secrets > env var."""
        if self.api_key:
            return self.api_key

        if self.uc_secret_scope and self.uc_secret_key:
            from pyspark.sql import SparkSession
            spark = SparkSession.getActiveSession()
            return spark.conf.get(
                f"spark.databricks.secrets.{self.uc_secret_scope}.{self.uc_secret_key}"
            )

        import os
        key = os.environ.get("API_KEY")
        if not key:
            raise ValueError(
                "No credentials found -- set api_key option, UC secret, or API_KEY env var"
            )
        return key

    def read(self, partition):
        import requests
        headers = {"Authorization": f"Bearer {self._resolve_credentials()}"}
        resp = requests.get("https://api.example.com/data", headers=headers)
        for row in resp.json():
            yield tuple(row.values())
```

The fallback chain keeps connectors portable — they work in notebooks with UC secrets and in CI with environment variables, with no code changes.
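The same priority chain can be factored into a plain function, decoupled from Spark so it is unit-testable. A sketch where `secret_lookup` stands in for whatever secret-store call you use (the function and parameter names are illustrative):

```python
import os


def resolve_credentials(api_key=None, secret_lookup=None, env_var="API_KEY"):
    """Priority: explicit value > secret store > environment variable."""
    if api_key:
        return api_key
    if secret_lookup is not None:
        secret = secret_lookup()  # e.g. a UC secrets fetch, injected as a callable
        if secret:
            return secret
    key = os.environ.get(env_var)
    if not key:
        raise ValueError("No credentials found in any source")
    return key
```

Injecting the secret lookup as a callable means the reader class stays thin and the fallback order can be verified without a cluster.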
Partitioned parallel reads
“My database has 100M rows. Partition reads across Spark executors using ID-range splits.”
```python
from pyspark.sql.datasource import DataSourceReader, InputPartition


class IdRangePartition(InputPartition):
    def __init__(self, start, end):
        self.start = start
        self.end = end


class PartitionedReader(DataSourceReader):
    def __init__(self, options):
        self.url = options.get("url")
        self.num_partitions = int(options.get("partitions", "8"))

    def partitions(self):
        import requests
        resp = requests.get(f"{self.url}/stats")
        max_id = resp.json()["max_id"]
        chunk = max_id // self.num_partitions
        # Half-open [start, end) ranges; the last partition absorbs the
        # remainder, and max_id + 1 keeps the top row inside the exclusive
        # id_lt bound.
        return [
            IdRangePartition(
                start=i * chunk,
                end=(i + 1) * chunk if i < self.num_partitions - 1 else max_id + 1,
            )
            for i in range(self.num_partitions)
        ]

    def read(self, partition):
        import requests
        resp = requests.get(
            f"{self.url}/records",
            params={"id_gte": partition.start, "id_lt": partition.end},
        )
        for row in resp.json():
            yield (row["id"], row["value"])
```

The `partitions()` method runs once on the driver and returns a list of `InputPartition` objects. Each executor then calls `read()` with one partition. Choose partition boundaries that produce roughly equal-sized chunks so no executor becomes a bottleneck.
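The range arithmetic is worth testing on its own. A sketch that pulls it into a plain function returning half-open `[start, end)` dicts (a hypothetical helper, separate from the connector above):

```python
def id_range_partitions(max_id, num_partitions):
    """Split ids 0..max_id into contiguous half-open [start, end) ranges."""
    chunk = max(1, max_id // num_partitions)
    parts = []
    for i in range(num_partitions):
        start = i * chunk
        if start > max_id:
            break  # fewer ids than requested partitions
        # The last partition absorbs the remainder; max_id + 1 keeps the
        # top id inside the exclusive id_lt bound.
        end = (i + 1) * chunk if i < num_partitions - 1 else max_id + 1
        parts.append({"start": start, "end": end})
    return parts
```

Checking contiguity here catches the classic off-by-one where the last row of the table is silently dropped.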
Watch Out For
- Top-level third-party imports — These pass on the driver but throw `ModuleNotFoundError` on executors. Always import external libraries inside `read()` and `write()` methods.
- Complex inheritance hierarchies — PySpark pickles reader/writer instances for shipping to executors. Abstract base classes, multiple inheritance layers, and closures over unpicklable objects fail at runtime with cryptic serialization errors. Stick to flat, single-level inheritance.
- Missing cluster-wide package installs — Every dependency must be installed on all executor nodes, not just the driver. Use `%pip install` in notebooks or init scripts on clusters. If a package is only on the driver, executors will fail silently or crash.
- Overlapping stream partitions — If `latestOffset` boundaries overlap, records get processed twice. Make sure offset ranges are strictly non-overlapping, with no gaps.
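A cheap guard against overlapping or gapped ranges is an invariant check over the planned partitions — a sketch (hypothetical helper) you could drop into the connector's test suite:

```python
def assert_contiguous(ranges):
    """ranges: iterable of (start, end) half-open intervals.

    Raises AssertionError if the sorted ranges leave a gap or overlap
    anywhere -- i.e. each range must end exactly where the next begins.
    """
    ordered = sorted(ranges)
    for (_, prev_end), (next_start, _) in zip(ordered, ordered[1:]):
        if prev_end != next_start:
            raise AssertionError(f"gap or overlap at {prev_end} vs {next_start}")
```

Run it against the output of your `partitions()` method and your streaming offset ranges before shipping the connector.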