Implementation Template
Skill: spark-python-data-source
What You Can Build
The PySpark DataSource API lets you teach Spark how to read from and write to any external system — REST APIs, databases, message queues, custom protocols. You define a DataSource class that returns readers and writers, and Spark handles distribution across executors. Your AI coding assistant can generate the full skeleton so you focus on connector logic, not API plumbing.
In Action
“Generate a custom Spark data source in Python that reads from a REST API with pagination. The source should support both batch and streaming reads, plus batch writes. Use `requests` for HTTP calls and include proper executor-side imports.”
```python
from pyspark.sql.datasource import (
    DataSource,
    DataSourceReader,
    DataSourceStreamReader,
    DataSourceStreamWriter,
    DataSourceWriter,
    InputPartition,
    WriterCommitMessage,
)


class ApiPartition(InputPartition):
    def __init__(self, partition_id, start, end):
        self.partition_id = partition_id
        self.start = start
        self.end = end


# 1. Entry point — Spark calls this to get readers and writers
class RestApiDataSource(DataSource):
    @classmethod
    def name(cls):
        return "rest-api"

    def schema(self):
        return self._infer_or_return_schema()  # return a StructType or DDL string

    def reader(self, schema):
        return RestApiBatchReader(self.options, schema)

    def writer(self, schema, overwrite):
        return RestApiBatchWriter(self.options, schema)

    def streamReader(self, schema):
        return RestApiStreamReader(self.options, schema)


# 2. Shared reader logic — plain class, not a PySpark interface
class RestApiReader:
    def __init__(self, options, schema):
        self.url = options.get("url")
        assert self.url, "url is required"
        self.page_size = int(options.get("page_size", "100"))
        self.schema = schema

    def read(self, partition):
        import requests  # Import here — runs on executors

        response = requests.get(
            self.url,
            params={"offset": partition.start, "limit": self.page_size},
        )
        response.raise_for_status()
        for item in response.json()["results"]:
            yield tuple(item[f.name] for f in self.schema.fields)


# 3. Batch reader — mixes shared logic with the PySpark interface
class RestApiBatchReader(RestApiReader, DataSourceReader):
    def partitions(self):
        import requests

        total = requests.get(f"{self.url}/count").json()["total"]
        return [
            ApiPartition(i, i * self.page_size, min((i + 1) * self.page_size, total))
            for i in range(total // self.page_size + 1)
        ]


# 4. Shared writer logic
class RestApiWriter:
    def __init__(self, options, schema=None):
        self.url = options.get("url")
        assert self.url, "url is required"
        self.batch_size = int(options.get("batch_size", "50"))

    def write(self, iterator):
        import requests  # Import here — runs on executors
        from pyspark import TaskContext

        context = TaskContext.get()  # task/partition info if you need it
        msgs = []
        for row in iterator:
            msgs.append(row.asDict())
            if len(msgs) >= self.batch_size:
                requests.post(self.url, json=msgs).raise_for_status()
                msgs = []
        if msgs:
            requests.post(self.url, json=msgs).raise_for_status()
        return WriterCommitMessage()  # collected by Spark and passed to commit()


# 5. Batch writer
class RestApiBatchWriter(RestApiWriter, DataSourceWriter):
    pass


# 6. Stream reader — adds offset tracking
class RestApiStreamReader(RestApiReader, DataSourceStreamReader):
    def initialOffset(self):
        return {"offset": "0"}

    def latestOffset(self):
        import requests

        total = requests.get(f"{self.url}/count").json()["total"]
        return {"offset": str(total)}

    def partitions(self, start, end):
        return [ApiPartition(0, int(start["offset"]), int(end["offset"]))]

    def commit(self, end):
        pass
```

Key decisions your AI coding assistant made:

- Flat inheritance — `RestApiBatchReader` inherits from both `RestApiReader` (shared logic) and `DataSourceReader` (PySpark interface). PySpark serializes reader instances to ship to executors, so deep class hierarchies break serialization.
- Executor-side imports — `import requests` lives inside `read()` and `write()`, not at the top of the file. These methods run on executor processes that don’t share the driver’s module state.
- `partitions()` drives parallelism — the batch reader’s `partitions()` method returns one `InputPartition` per chunk. Spark sends each partition to a separate executor, so more partitions means more parallel API calls.
- `read()` yields tuples — each tuple maps positionally to the schema fields. Spark assembles them into rows on the executor side.
More Patterns
Register and query the data source
“Show me how to register my custom data source and use it for batch reads, batch writes, and streaming reads.”
```python
spark.dataSource.register(RestApiDataSource)

# Batch read
df = spark.read.format("rest-api").option("url", "https://api.example.com/events").load()

# Batch write
df.write.format("rest-api").option("url", "https://api.example.com/events").save()

# Streaming read
stream_df = (
    spark.readStream
    .format("rest-api")
    .option("url", "https://api.example.com/events")
    .load()
)
```

The `name()` classmethod on your `DataSource` is what you pass to `.format()`. Once registered, Spark treats your custom source like any built-in format.
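If the source’s `schema()` has to call the API (as the `_infer_or_return_schema()` placeholder suggests), you can skip that round-trip by supplying a schema from the query side; Spark then hands it straight to `reader(self, schema)`. A sketch wrapped in a hypothetical helper so it stands alone:

```python
def read_events(spark, url):
    # Supplying a schema here means Spark skips the source's own schema()
    # call and passes this schema straight to reader(self, schema)
    return (
        spark.read
        .format("rest-api")
        .schema("id BIGINT, name STRING, created_at STRING")  # DDL string form
        .option("url", url)
        .load()
    )
```

The DDL string is parsed into a `StructType`; passing a `StructType` directly works as well.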
Add a streaming writer with commit semantics
“Extend my REST API data source with a streaming writer that supports micro-batch commit and abort.”
```python
class RestApiStreamWriter(RestApiWriter, DataSourceStreamWriter):
    def commit(self, messages, batchId):
        # Called after all partitions for this batch succeed
        pass

    def abort(self, messages, batchId):
        # Called if any partition in this batch fails
        pass
```

The `commit()` and `abort()` methods give you hooks for transactional semantics. If your target system supports two-phase commit, use `write()` to stage data and `commit()` to finalize it. For fire-and-forget APIs, these can be no-ops. Return this writer from a `streamWriter(self, schema, overwrite)` method on your `DataSource` so Spark uses it for streaming writes.
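With the stream writer in place, starting the sink looks like any other streaming write. A usage sketch — the helper name and checkpoint path are placeholders:

```python
def start_rest_sink(stream_df, url, checkpoint_dir):
    # Each micro-batch flows through write(); commit()/abort() fire per batch
    return (
        stream_df.writeStream
        .format("rest-api")
        .option("url", url)
        .option("checkpointLocation", checkpoint_dir)
        .start()
    )
```

The checkpoint location is required for streaming writes; Spark uses it to track which micro-batches have been committed across restarts.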
Project structure for a packaged data source
“Show me the project layout for a custom Spark data source I can build as a wheel and install on a Databricks cluster.”
```
my-rest-connector/
├── pyproject.toml
├── src/
│   └── my_rest_connector/
│       ├── __init__.py
│       └── datasource.py   # DataSource, Reader, Writer classes
└── tests/
    ├── conftest.py         # Spark session fixture
    ├── test_reader.py
    └── test_writer.py
```

Package it as a wheel, install it on the cluster, and your notebooks can `from my_rest_connector import RestApiDataSource` before calling `spark.dataSource.register()`.
Watch Out For
- Top-level imports of third-party libraries — your `read()` and `write()` methods run on executors that don’t share the driver’s Python environment. Import `requests`, database drivers, and SDK clients inside those methods, not at module level.
- Deep inheritance hierarchies — PySpark serializes reader/writer instances with pickle to ship them to executors. Abstract base classes, multiple inheritance chains, and closures over unpicklable objects cause `PicklingError` at runtime. Stick to flat, single-level inheritance.
- Forgetting `partitions()` — if your batch reader doesn’t override `partitions()`, Spark creates a single partition and reads everything sequentially on one executor. For any source with more than trivial data volume, implement `partitions()` to enable parallel reads.
- Yielding dicts from `read()` — the `read()` method should yield tuples whose elements map positionally to the schema fields (`Row` objects are also accepted). Yielding dicts causes schema mismatch errors.