Implementation Template

Skill: spark-python-data-source

The PySpark DataSource API lets you teach Spark how to read from and write to any external system — REST APIs, databases, message queues, custom protocols. You define a DataSource class that returns readers and writers, and Spark handles distribution across executors. Your AI coding assistant can generate the full skeleton so you focus on connector logic, not API plumbing.

“Generate a custom Spark data source in Python that reads from a REST API with pagination. The source should support both batch and streaming reads, plus batch writes. Use requests for HTTP calls and include proper executor-side imports.”

from pyspark.sql.datasource import (
    DataSource, DataSourceReader, DataSourceWriter,
    DataSourceStreamReader, DataSourceStreamWriter,
    InputPartition, WriterCommitMessage,
)

class ApiPartition(InputPartition):
    def __init__(self, partition_id, start, end):
        self.partition_id = partition_id
        self.start = start
        self.end = end

# 1. Entry point — Spark calls this to get readers and writers
class RestApiDataSource(DataSource):
    @classmethod
    def name(cls):
        return "rest-api"

    def schema(self):
        # DDL string; swap in schema inference against the API if you need it
        return "id INT, name STRING"

    def reader(self, schema):
        return RestApiBatchReader(self.options, schema)

    def writer(self, schema, overwrite):
        return RestApiBatchWriter(self.options, schema)

    def streamReader(self, schema):
        return RestApiStreamReader(self.options, schema)

# 2. Shared reader logic — plain class, not a PySpark interface
class RestApiReader:
    def __init__(self, options, schema):
        self.url = options.get("url")
        assert self.url, "url is required"
        self.page_size = int(options.get("page_size", "100"))
        self.schema = schema

    def read(self, partition):
        import requests  # Import here — runs on executors
        offset = partition.start
        while offset < partition.end:  # page through the partition's range
            response = requests.get(
                self.url,
                params={
                    "offset": offset,
                    "limit": min(self.page_size, partition.end - offset),
                },
            )
            response.raise_for_status()
            results = response.json()["results"]
            if not results:
                break
            for item in results:
                yield tuple(item[f.name] for f in self.schema.fields)
            offset += len(results)

# 3. Batch reader — mixes shared logic with PySpark interface
class RestApiBatchReader(RestApiReader, DataSourceReader):
    def partitions(self):
        import requests
        total = requests.get(f"{self.url}/count").json()["total"]
        num_pages = (total + self.page_size - 1) // self.page_size  # ceil division
        return [
            ApiPartition(i, i * self.page_size, min((i + 1) * self.page_size, total))
            for i in range(num_pages)
        ]

# 4. Shared writer logic
class RestApiWriter:
    def __init__(self, options, schema=None):
        self.url = options.get("url")
        assert self.url, "url is required"
        self.batch_size = int(options.get("batch_size", "50"))

    def write(self, iterator):
        import requests  # Import here — runs on executors
        msgs = []
        for row in iterator:
            msgs.append(row.asDict())
            if len(msgs) >= self.batch_size:
                requests.post(self.url, json=msgs).raise_for_status()
                msgs = []
        if msgs:
            requests.post(self.url, json=msgs).raise_for_status()
        return WriterCommitMessage()  # Spark hands these to commit()/abort()

# 5. Batch writer
class RestApiBatchWriter(RestApiWriter, DataSourceWriter):
    pass

# 6. Stream reader — adds offset tracking
class RestApiStreamReader(RestApiReader, DataSourceStreamReader):
    def initialOffset(self):
        return {"offset": "0"}

    def latestOffset(self):
        import requests
        total = requests.get(f"{self.url}/count").json()["total"]
        return {"offset": str(total)}

    def partitions(self, start, end):
        s, e = int(start["offset"]), int(end["offset"])
        return [ApiPartition(0, s, e)]

    def commit(self, end):
        pass
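
The stream reader's offset handshake can be traced with plain Python before wiring it into Spark. Here `plan_batch` is a hypothetical helper mirroring what `partitions(start, end)` does with the two offset dicts Spark passes in:

```python
def plan_batch(start_offset, end_offset):
    """Row range a micro-batch should read, given two offset dicts."""
    return int(start_offset["offset"]), int(end_offset["offset"])

# First micro-batch: nothing committed yet, the API reports 250 rows
first = plan_batch({"offset": "0"}, {"offset": "250"})
# Next micro-batch: the API grew to 400 rows, resume where we stopped
second = plan_batch({"offset": "250"}, {"offset": "400"})
print(first, second)  # (0, 250) (250, 400)
```

Each micro-batch covers exactly the rows that appeared between the previous `latestOffset()` and the current one, so no row is read twice.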

Key decisions your AI coding assistant made:

  • Flat inheritance — RestApiBatchReader inherits from both RestApiReader (shared logic) and DataSourceReader (PySpark interface). PySpark serializes reader instances to ship to executors, so deep class hierarchies break serialization.
  • Executor-side imports — import requests lives inside read() and write(), not at the top of the file. These methods run on executor processes that don’t share the driver’s module state.
  • partitions() drives parallelism — the batch reader’s partitions() method returns one InputPartition per chunk. Spark sends each partition to a separate executor, so more partitions means more parallel API calls.
  • read() yields tuples — each tuple maps positionally to the schema fields. Spark assembles them into rows on the executor side.
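
The partition arithmetic is worth checking by hand. A standalone sketch, pure Python with no Spark; ceil division avoids an empty trailing partition when the total is an exact multiple of the page size:

```python
def plan_partitions(total, page_size):
    """One (start, end) offset range per page, via ceil division."""
    num_pages = (total + page_size - 1) // page_size
    return [
        (i * page_size, min((i + 1) * page_size, total))
        for i in range(num_pages)
    ]

print(plan_partitions(250, 100))  # [(0, 100), (100, 200), (200, 250)]
print(plan_partitions(200, 100))  # [(0, 100), (100, 200)]
```

Each tuple becomes one InputPartition, so 250 rows at a page size of 100 fan out to three parallel API calls.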

“Show me how to register my custom data source and use it for batch reads, batch writes, and streaming reads.”

spark.dataSource.register(RestApiDataSource)

# Batch read
df = spark.read.format("rest-api").option("url", "https://api.example.com/events").load()

# Batch write (custom sources need an explicit append or overwrite mode)
df.write.format("rest-api").option("url", "https://api.example.com/events").mode("append").save()

# Streaming read
stream_df = (
    spark.readStream
    .format("rest-api")
    .option("url", "https://api.example.com/events")
    .load()
)

The name() classmethod on your DataSource is what you pass to .format(). Once registered, Spark treats your custom source like any built-in format.

Add a streaming writer with commit semantics


“Extend my REST API data source with a streaming writer that supports micro-batch commit and abort.”

class RestApiStreamWriter(RestApiWriter, DataSourceStreamWriter):
    def commit(self, messages, batchId):
        # Called after all partitions for this batch succeed
        pass

    def abort(self, messages, batchId):
        # Called if any partition in this batch fails
        pass

# Wire it up by adding a streamWriter() method to RestApiDataSource:
#     def streamWriter(self, schema, overwrite):
#         return RestApiStreamWriter(self.options, schema)

The commit() and abort() methods give you hooks for transactional semantics. If your target system supports two-phase commit, use write() to stage data and commit() to finalize it. For fire-and-forget APIs, these can be no-ops.
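
An in-memory sketch of that stage-then-finalize pattern; `FakeApi` and its methods are stand-ins for a hypothetical target system, not part of the PySpark API:

```python
class FakeApi:
    """Stand-in for a target system with two-phase commit."""
    def __init__(self):
        self.staged = {}    # task_id -> rows awaiting finalization
        self.committed = []

    def stage(self, task_id, rows):
        self.staged[task_id] = rows

    def finalize(self, task_ids):
        for tid in task_ids:
            self.committed.extend(self.staged.pop(tid))

    def discard(self, task_ids):
        for tid in task_ids:
            self.staged.pop(tid, None)

api = FakeApi()
# write() on each executor stages its partition and returns a "message"
messages = []
for task_id, rows in enumerate([[{"id": 1}], [{"id": 2}]]):
    api.stage(task_id, rows)
    messages.append(task_id)
# commit() on the driver finalizes only after every partition succeeded
api.finalize(messages)
print(api.committed)  # [{'id': 1}, {'id': 2}]
```

In a real writer, write() would POST each partition's rows to a staging endpoint and return a WriterCommitMessage carrying the staging id; commit() finalizes them all at once, and abort() discards them.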

Project structure for a packaged data source


“Show me the project layout for a custom Spark data source I can build as a wheel and install on a Databricks cluster.”

my-rest-connector/
├── pyproject.toml
├── src/
│ └── my_rest_connector/
│ ├── __init__.py
│ └── datasource.py # DataSource, Reader, Writer classes
└── tests/
├── conftest.py # Spark session fixture
├── test_reader.py
└── test_writer.py

Package it as a wheel, install it on the cluster, and your notebooks can run from my_rest_connector import RestApiDataSource before calling spark.dataSource.register().
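
A minimal pyproject.toml for that layout might look like this (the setuptools backend, version pins, and project name are assumptions; any PEP 517 backend works):

```toml
[build-system]
requires = ["setuptools>=61"]
build-backend = "setuptools.build_meta"

[project]
name = "my-rest-connector"
version = "0.1.0"
requires-python = ">=3.9"
dependencies = ["requests"]

[tool.setuptools.packages.find]
where = ["src"]
```

Note that pyspark is deliberately absent from dependencies: the cluster already provides it, and pinning it in the wheel risks a version conflict.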

  • Top-level imports of third-party libraries — your read() and write() methods run on executors that don’t share the driver’s Python environment. Import requests, database drivers, and SDK clients inside those methods, not at module level.
  • Deep inheritance hierarchies — PySpark serializes reader/writer instances with pickle to ship them to executors. Abstract base classes, multiple inheritance chains, and closures over unpicklable objects cause PicklingError at runtime. Stick to flat, single-level inheritance.
  • Forgetting partitions() — if your batch reader doesn’t override partitions(), Spark creates a single partition and reads everything sequentially on one executor. For any source with more than trivial data volume, implement partitions() to enable parallel reads.
  • Yielding dicts instead of tuples from read() — the read() method must yield tuples where each element maps positionally to the schema. Yielding dicts or Row objects causes schema mismatch errors.
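
That positional mapping is easy to verify in isolation. In this sketch the Field namedtuple stands in for a real StructField, and the field names are illustrative:

```python
from collections import namedtuple

# Stand-in for schema.fields; real code gets these from a StructType
Field = namedtuple("Field", ["name"])
fields = [Field("id"), Field("name")]

item = {"name": "alpha", "id": 1, "extra": "ignored by the schema"}

# Correct: a tuple whose elements follow the schema's field order
row = tuple(item[f.name] for f in fields)
print(row)  # (1, 'alpha')
```

The tuple's order comes from the schema, not from the dict, so extra or reordered keys in the API response are harmless.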