Production Patterns

Skill: spark-python-data-source

A data source that works in a notebook needs hardening before it runs unattended in production. Your AI coding assistant can generate the observability, security validation, resource lifecycle, and configuration validation patterns that separate a prototype from a connector you trust to run overnight.

“Harden my custom data source for production with rate limiting, monitoring, and graceful degradation. Track rows processed, failures, and throughput per partition. Use Python.”

import time
import json
import logging

from pyspark import TaskContext

logger = logging.getLogger(__name__)


class ProductionWriter:
    def __init__(self, options, schema=None):
        self.url = options.get("url")
        self.batch_size = int(options.get("batch_size", "50"))
        self.continue_on_error = options.get("continue_on_error", "false") == "true"

    def write(self, iterator):
        import requests

        context = TaskContext.get()
        partition_id = context.partitionId()
        metrics = {
            "partition_id": partition_id,
            "rows_processed": 0,
            "rows_failed": 0,
            "start_time": time.time(),
        }
        try:
            for row in iterator:
                try:
                    # Always bound the request; a hung server shouldn't hang the task
                    requests.post(
                        self.url, json=row.asDict(), timeout=10
                    ).raise_for_status()
                    metrics["rows_processed"] += 1
                except Exception as e:
                    metrics["rows_failed"] += 1
                    if not self.continue_on_error:
                        raise
                    logger.warning(
                        "Row failed",
                        extra={"partition": partition_id, "error": str(e)},
                    )
        finally:
            metrics["duration_s"] = time.time() - metrics["start_time"]
            if metrics["duration_s"] > 0:
                metrics["rows_per_sec"] = (
                    metrics["rows_processed"] / metrics["duration_s"]
                )
            logger.info("Partition complete", extra={"metrics": json.dumps(metrics)})

Key decisions your AI coding assistant made:

  • Metrics in finally — the metrics block runs whether the partition succeeds or fails, so you always get throughput data for debugging.
  • continue_on_error is opt-in — defaulting to false means failures surface immediately. You explicitly enable partial-failure tolerance when the downstream use case allows it.
  • Structured logging — extra={"metrics": ...} produces parseable log lines for log aggregation systems. Avoid print() in production; it interleaves with Spark’s own output and can’t be filtered.
  • Per-partition granularity — each executor reports its own metrics. Aggregating across partitions on the driver gives you the full picture.
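The driver-side aggregation in the last bullet can be sketched in a few lines. This helper is an assumption, not part of the writer above: it takes the JSON metric strings the writer logs and rolls them into one job-level summary.

```python
import json


def aggregate_partition_metrics(metric_lines):
    """Roll per-partition metric lines (JSON strings, as logged by the
    writer above) into one job-level summary on the driver."""
    totals = {"rows_processed": 0, "rows_failed": 0, "slowest_partition_s": 0.0}
    for line in metric_lines:
        m = json.loads(line)
        totals["rows_processed"] += m["rows_processed"]
        totals["rows_failed"] += m["rows_failed"]
        # Partitions run in parallel, so wall-clock cost tracks the slowest one
        totals["slowest_partition_s"] = max(
            totals["slowest_partition_s"], m["duration_s"]
        )
    return totals
```

Summing rows across partitions but taking the max of durations reflects how Spark actually schedules the work: throughput adds up, time does not.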

“Generate input validation for my data source that checks required options, validates numeric ranges, and rejects invalid identifiers before Spark starts any tasks.”

import re


class ValidatedDataSource:
    def __init__(self, options):
        self._validate(options)
        self.options = options

    def _validate(self, options):
        errors = []
        # Required options
        for key in ("host", "database", "table"):
            if key not in options:
                errors.append(f"Missing required option: {key}")
        # Numeric ranges — collect a clear error instead of crashing on bad input
        if "timeout" in options:
            try:
                t = int(options["timeout"])
                if not 0 <= t <= 300:
                    errors.append(f"timeout must be 0-300, got {t}")
            except ValueError:
                errors.append(f"timeout must be an integer, got {options['timeout']!r}")
        if "batch_size" in options:
            try:
                b = int(options["batch_size"])
                if not 1 <= b <= 10000:
                    errors.append(f"batch_size must be 1-10000, got {b}")
            except ValueError:
                errors.append(
                    f"batch_size must be an integer, got {options['batch_size']!r}"
                )
        # SQL identifier validation — prevent injection
        if "table" in options:
            if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", options["table"]):
                errors.append(
                    f"Invalid table name: {options['table']}. "
                    f"Must match [a-zA-Z_][a-zA-Z0-9_]*"
                )
        if errors:
            raise ValueError(
                "Configuration errors:\n" + "\n".join(f"  - {e}" for e in errors)
            )

Validation runs on the driver during __init__, before Spark distributes any work. Catching bad config here means fast, clear errors instead of cryptic failures on executors 10 minutes into the job.
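To make the identifier check concrete, here is what that regex accepts and rejects (the table names are made up for illustration):

```python
import re

# Same pattern as the validator above: a SQL identifier is a letter or
# underscore followed by letters, digits, or underscores — nothing else.
SQL_IDENTIFIER = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")

# Plain snake_case names pass
assert SQL_IDENTIFIER.match("events_2024")
assert SQL_IDENTIFIER.match("_staging")

# Injection attempts and malformed names are rejected
assert not SQL_IDENTIFIER.match("users; DROP TABLE users--")
assert not SQL_IDENTIFIER.match("2024_events")   # can't start with a digit
assert not SQL_IDENTIFIER.match("schema.table")  # dots need separate handling
```

If you need qualified names like schema.table, validate each dot-separated part against the pattern rather than loosening the regex.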

Resource cleanup with guaranteed connection closing


“Add resource lifecycle management to my writer so database connections are always closed, even if the write fails halfway through.”

class ManagedWriter:
    def __init__(self, options, schema=None):
        self.options = options
        self._connection = None

    def write(self, iterator):
        try:
            self._connection = self._create_connection()
            for row in iterator:
                self._send(self._connection, row)
        finally:
            if self._connection:
                try:
                    self._connection.close()
                except Exception as e:
                    logger.warning(f"Error closing connection: {e}")
                finally:
                    self._connection = None

    def __del__(self):
        """Safety net — close connection if write() wasn't called."""
        if self._connection:
            self._connection.close()

Leaked connections accumulate across partitions and micro-batches. On a long-running streaming job, that’s a connection-pool exhaustion issue that appears hours after deployment.
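The same guarantee can be factored into a reusable context manager, which keeps the try/finally in one place instead of repeating it in every writer. This is a sketch, not part of the class above; the factory argument stands in for your own connection constructor:

```python
from contextlib import contextmanager


@contextmanager
def managed_connection(factory):
    """Open a connection via factory() and always close it on exit,
    even if the body raises partway through."""
    conn = factory()
    try:
        yield conn
    finally:
        try:
            conn.close()
        except Exception:
            pass  # a close error shouldn't mask the original failure
```

Inside write() this becomes `with managed_connection(self._create_connection) as conn: ...`, and the connection is closed on both the success and failure paths.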

“Add a health check to my data source that validates connectivity and authentication before Spark starts distributing tasks.”

class HealthCheckedDataSource:
    def reader(self, schema):
        self._check_health()
        return MyReader(self.options, schema)

    def _check_health(self):
        import requests

        checks = {}
        # Connectivity
        try:
            response = requests.get(
                f"{self.options['url']}/health", timeout=5
            )
            checks["connectivity"] = response.status_code == 200
        except Exception:
            checks["connectivity"] = False
        # Authentication
        try:
            response = requests.get(
                self.options["url"],
                headers=self._get_auth_headers(),
                timeout=5,
            )
            checks["authentication"] = response.status_code != 401
        except Exception:
            checks["authentication"] = False
        failed = [k for k, v in checks.items() if not v]
        if failed:
            raise RuntimeError(
                f"Health check failed: {', '.join(failed)}. "
                f"Fix before running the job."
            )

Health checks run on the driver before any executor work starts. A 5-second check here saves you from a 10-minute job that fails on every partition because the API is down.

“Add structured logging to my data source so I can correlate log lines across driver and executor processes for a single job.”

import logging
import json

logger = logging.getLogger(__name__)


class CorrelatedWriter:
    def __init__(self, options, schema=None):
        self.url = options.get("url")
        self.job_id = options.get("job_id", "unknown")

    def write(self, iterator):
        import requests
        from pyspark import TaskContext

        context = TaskContext.get()
        log_context = {
            "job_id": self.job_id,
            "partition_id": context.partitionId(),
            "stage_id": context.stageId(),
        }
        logger.info(json.dumps({"event": "partition_start", **log_context}))
        count = 0
        for row in iterator:
            requests.post(self.url, json=row.asDict(), timeout=10).raise_for_status()
            count += 1
        logger.info(json.dumps({
            "event": "partition_complete",
            "rows": count,
            **log_context,
        }))

Pass a job_id option when you start the job. Every log line across every executor includes it, so you can filter logs for a single run across hundreds of tasks.
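On the analysis side, the same correlation works with a few lines of Python over collected log lines. This helper is hypothetical; substitute your log aggregator's query language where you have one:

```python
import json


def lines_for_job(log_lines, job_id):
    """Filter structured log lines down to the records for a single run."""
    out = []
    for line in log_lines:
        try:
            record = json.loads(line)
        except ValueError:
            continue  # skip non-JSON lines emitted by Spark itself
        if record.get("job_id") == job_id:
            out.append(record)
    return out
```

Because every record carries the job_id, one filter recovers a run's full story across hundreds of tasks, with Spark's own unstructured output silently skipped.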

  • Rate limiter per partition instead of global — if you create a rate limiter inside write(), each partition has its own. With 20 partitions, your “100 requests/second” limiter actually allows 2,000 requests/second. For global rate limiting, use the target API’s Retry-After headers or coordinate through a shared resource.
  • Logging credentials — logger.info(f"Connecting with options: {self.options}") dumps your API key into the log file. Sanitize options before logging (see the authentication patterns page for masking).
  • No connection timeout — requests.get(url) without a timeout parameter waits indefinitely if the server accepts the TCP connection but never responds. Always set timeout= on every HTTP call.
  • Skipping validation because “it worked in the notebook” — production runs with different options, different cluster sizes, and different data volumes. Validate config in __init__, run health checks in reader()/writer(), and log enough context to debug failures without SSH access to the cluster.
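For the global rate-limiting case in the first bullet, a minimal sketch of honoring Retry-After looks like this. The post callable and the retry budget are assumptions: post stands in for a wrapper around requests.post with the URL and timeout already fixed.

```python
import time


def retry_wait_seconds(headers, attempt):
    """Prefer the server's Retry-After header (seconds form);
    fall back to exponential backoff when it's absent."""
    retry_after = headers.get("Retry-After")
    if retry_after is not None:
        return float(retry_after)
    return float(2 ** attempt)


def post_with_backoff(post, payload, max_retries=5):
    """post is any callable returning a response with .status_code and
    .headers — e.g. a small wrapper around requests.post with the URL
    and timeout baked in."""
    for attempt in range(max_retries):
        response = post(payload)
        if response.status_code != 429:
            return response  # success, or a non-rate-limit error to handle
        time.sleep(retry_wait_seconds(response.headers, attempt))
    raise RuntimeError(f"Still rate limited after {max_retries} attempts")
```

Because the server decides the wait, this adapts correctly no matter how many partitions are hitting it at once — the coordination the per-partition limiter can't provide.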