Connect to a PostgreSQL database using PySpark

To connect to a PostgreSQL database from PySpark, you need to make the PostgreSQL JDBC driver available to Spark and then use the DataFrameReader API to load data from the database.

Prerequisites and Setup

JDBC Driver Installation

The PostgreSQL JDBC driver is essential for establishing connectivity between Spark and PostgreSQL. Here are several ways to include it in your project:

Option 1: Download the JAR file

Download the PostgreSQL JDBC driver (postgresql-<version>.jar) from the official website: https://jdbc.postgresql.org/download/

For PostgreSQL 12 and above, use version 42.5.0 or later. Include it when starting PySpark:

pyspark --jars /path/to/postgresql-42.5.0.jar

Option 2: Use Maven coordinates (recommended)

pyspark --packages org.postgresql:postgresql:42.5.0

Option 3: Configure in SparkSession

from pyspark.sql import SparkSession

spark = SparkSession.builder \
    .appName("PostgreSQL Connection") \
    .config("spark.jars.packages", "org.postgresql:postgresql:42.5.0") \
    .getOrCreate()

Version Compatibility

  • PostgreSQL 9.x - 15.x: Use JDBC driver 42.5.0+
  • Java 8+: Required for Spark and JDBC driver
  • PySpark 3.x: Recommended for latest features
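
If you want to confirm what you are actually running before wiring up a connection, a quick check like the sketch below prints the PySpark and JVM versions (it relies on the internal _jvm handle, so treat it as a convenience check rather than a stable API):

import pyspark
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# PySpark version (3.x recommended)
print("PySpark:", pyspark.__version__)

# Java version of the JVM backing this Spark session (Java 8+ required)
print("Java:", spark.sparkContext._jvm.java.lang.System.getProperty("java.version"))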

Basic Connection Example

With the driver in place, you can connect to the PostgreSQL database. Here's a basic example:

from pyspark.sql import SparkSession

# Initialize the Spark session
spark = SparkSession.builder \
    .appName("PostgreSQL Connection Example") \
    .getOrCreate()

# Database connection properties
url = "jdbc:postgresql://localhost:5432/your_database_name"
properties = {
    "user": "your_user",
    "password": "your_password",
    "driver": "org.postgresql.Driver"
}

# Read data from the PostgreSQL table
table_name = "your_table_name"
df = spark.read \
    .jdbc(url, table_name, properties=properties)

# Show the DataFrame
df.show()

Make sure to replace your_database_name, your_user, your_password, and your_table_name with the appropriate values for your PostgreSQL database.

This example initializes a Spark session, sets the database connection properties, reads data from the specified PostgreSQL table, and shows the resulting DataFrame.

Advanced Connection Options

Connection Tuning

Spark's JDBC source does not pool connections itself (each partition opens its own connection), but for production workloads you can tune parallelism and batching to improve throughput:

properties = {
    "user": "postgres",
    "password": "password",
    "driver": "org.postgresql.Driver",
    "numPartitions": "10",  # Maximum number of concurrent JDBC connections
    "fetchsize": "1000",    # Rows fetched per round trip when reading
    "batchsize": "1000"     # Rows sent per INSERT batch when writing
}

SSL/TLS Configuration

Secure your database connections with SSL:

url = "jdbc:postgresql://localhost:5432/mydb?ssl=true&sslmode=require"
properties = {
    "user": "postgres",
    "password": "password",
    "driver": "org.postgresql.Driver",
    "sslmode": "require",
    "sslcert": "/path/to/client-cert.pem",
    "sslkey": "/path/to/client-key.pk8",  # PgJDBC expects the client key in PKCS-8 format
    "sslrootcert": "/path/to/ca-cert.pem"
}

Using a Custom Query

Read specific data by passing a parenthesized, aliased subquery in place of the table name:

query = "(SELECT * FROM employees WHERE salary > 50000) AS filtered_data"
df = spark.read \
    .jdbc(url, query, properties=properties)
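
On Spark 2.4 and later, the JDBC source also accepts a query option that takes a bare SELECT statement, without the subquery-and-alias wrapping. A minimal sketch, assuming the same url and placeholder credentials as the basic example above:

# Same result using the "query" option (Spark 2.4+)
df = spark.read.format("jdbc") \
    .option("url", url) \
    .option("query", "SELECT * FROM employees WHERE salary > 50000") \
    .option("user", "your_user") \
    .option("password", "your_password") \
    .option("driver", "org.postgresql.Driver") \
    .load()

Note that the query option cannot be combined with the partitioning options (partitionColumn, lowerBound, upperBound).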

Writing Data to PostgreSQL

Insert Operations

Write a DataFrame to PostgreSQL:

# Create sample DataFrame
data = [
    (1, "John Doe", 75000),
    (2, "Jane Smith", 85000),
    (3, "Bob Johnson", 65000)
]
columns = ["id", "name", "salary"]
df = spark.createDataFrame(data, columns)

# Write to PostgreSQL
df.write \
    .mode("append") \
    .jdbc(url, "employees", properties=properties)

Update and Upsert Operations

Spark's JDBC writer has no native upsert. The overwrite mode with truncate=true empties the target table and reloads it (keeping the table definition instead of dropping it); a true upsert requires staging the data and merging it with INSERT ... ON CONFLICT on the PostgreSQL side, as sketched after the example below.

# Truncate and reload the whole table (not a row-level upsert)
df.write \
    .mode("overwrite") \
    .option("truncate", "true") \
    .jdbc(url, "employees", properties=properties)
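
Because there is no native upsert, a common workaround is to load the incoming rows into a staging table with Spark and then merge them into the target table with a PostgreSQL client. The sketch below assumes psycopg2 is installed and uses an illustrative staging table named employees_staging plus the placeholder credentials from earlier:

import psycopg2

# 1. Load the new/changed rows into a staging table (replaced on every run)
df.write \
    .mode("overwrite") \
    .option("truncate", "true") \
    .jdbc(url, "employees_staging", properties=properties)

# 2. Merge the staging table into the target with INSERT ... ON CONFLICT
#    (assumes employees.id is the primary key)
conn = psycopg2.connect(
    host="localhost", port=5432, dbname="your_database_name",
    user="your_user", password="your_password"
)
with conn, conn.cursor() as cur:
    cur.execute("""
        INSERT INTO employees (id, name, salary)
        SELECT id, name, salary FROM employees_staging
        ON CONFLICT (id) DO UPDATE
        SET name = EXCLUDED.name,
            salary = EXCLUDED.salary
    """)
conn.close()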

Batch Insert with Transaction Control

# Write in batches
df.write \
    .option("batchsize", "10000") \
    .option("isolationLevel", "READ_COMMITTED") \
    .mode("append") \
    .jdbc(url, "large_table", properties=properties)

Performance Optimization

Partitioning for Parallel Reads

Use column-based partitioning to read data in parallel:

df = spark.read \
    .jdbc(
        url=url,
        table="employees",
        column="id",
        lowerBound=1,
        upperBound=100000,
        numPartitions=10,
        properties=properties
    )

This splits the read into 10 partitions on the id column, so Spark opens up to 10 parallel connections. Note that lowerBound and upperBound only control how the ranges are split; rows outside those bounds still land in the first and last partitions.
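
When there is no convenient numeric column to range-partition on, DataFrameReader.jdbc also accepts a predicates list, creating one partition per WHERE clause. A minimal sketch, assuming employees has a region column:

# One partition (and one connection) per predicate; "region" is an illustrative column
predicates = [
    "region = 'north'",
    "region = 'south'",
    "region = 'east'",
    "region = 'west'"
]

df = spark.read.jdbc(url, "employees", predicates=predicates, properties=properties)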

Predicate Pushdown

Spark automatically pushes filters to PostgreSQL when possible:

# Filter is executed on PostgreSQL side
df = spark.read \
    .jdbc(url, "employees", properties=properties) \
    .filter("salary > 50000") \
    .filter("department = 'Engineering'")

Fetch Size Tuning

Optimize memory usage with fetch size:

properties = {
    "user": "postgres",
    "password": "password",
    "driver": "org.postgresql.Driver",
    "fetchsize": "10000"  # Rows fetched per round trip
}

Query Optimization

SQL Pushdown

Use SQL pushdown for complex queries:

# Complex query executed on database
query = """
(SELECT
    e.id,
    e.name,
    e.salary,
    d.department_name,
    AVG(e.salary) OVER (PARTITION BY e.department_id) as avg_dept_salary
FROM employees e
JOIN departments d ON e.department_id = d.id
WHERE e.hire_date > '2020-01-01') AS optimized_query
"""

df = spark.read.jdbc(url, query, properties=properties)

Column Pruning

Select only required columns:

# Only specified columns are retrieved
df = spark.read \
    .jdbc(url, "employees", properties=properties) \
    .select("id", "name", "salary")

Handling Large Datasets

Batch Processing Strategy

# Process a large table in fixed-size id ranges
chunk_size = 100_000
max_id = 1_000_000  # upper bound of the id range; adjust to the actual table

for lower in range(0, max_id, chunk_size):
    query = f"""
    (SELECT * FROM large_table
     WHERE id >= {lower} AND id < {lower + chunk_size}) AS batch
    """

    df_batch = spark.read.jdbc(url, query, properties=properties)

    # Transform and append each chunk to the target table
    df_batch.write \
        .mode("append") \
        .jdbc(url, "processed_table", properties=properties)

Memory Management

Configure Spark memory for large datasets:

spark = SparkSession.builder \
    .appName("Large Dataset Processing") \
    .config("spark.driver.memory", "4g") \
    .config("spark.executor.memory", "8g") \
    .config("spark.memory.fraction", "0.8") \
    .getOrCreate()

Parallel Write Operations

# Repartition before writing for parallel inserts
df.repartition(20) \
    .write \
    .mode("append") \
    .jdbc(url, "target_table", properties=properties)

Error Handling and Reliability

Connection Retry Logic

import time

def read_with_retry(url, table, properties, max_retries=3):
    for attempt in range(max_retries):
        try:
            df = spark.read.jdbc(url, table, properties=properties)
            return df
        except Exception as e:
            if attempt < max_retries - 1:
                wait_time = 2 ** attempt
                print(f"Attempt {attempt + 1} failed. Retrying in {wait_time}s...")
                time.sleep(wait_time)
            else:
                raise e

df = read_with_retry(url, "employees", properties)

Transaction Management

properties = {
    "user": "postgres",
    "password": "password",
    "driver": "org.postgresql.Driver",
    "isolationLevel": "READ_COMMITTED",            # Transaction isolation level applied to writes
    "sessionInitStatement": "SET TIME ZONE 'UTC'"  # Runs once per session before data is read
}

Timeout Configuration

properties = {
    "user": "postgres",
    "password": "password",
    "driver": "org.postgresql.Driver",
    "connectTimeout": "30",  # 30 seconds
    "socketTimeout": "60"     # 60 seconds
}

Security Best Practices

Using Environment Variables

Never hardcode credentials. Use environment variables:

import os

url = f"jdbc:postgresql://{os.getenv('DB_HOST')}:5432/{os.getenv('DB_NAME')}"
properties = {
    "user": os.getenv("DB_USER"),
    "password": os.getenv("DB_PASSWORD"),
    "driver": "org.postgresql.Driver"
}

Integration with Secret Managers

# Example with AWS Secrets Manager
import boto3
import json

def get_db_credentials(secret_name):
    client = boto3.client('secretsmanager', region_name='us-east-1')
    response = client.get_secret_value(SecretId=secret_name)
    return json.loads(response['SecretString'])

creds = get_db_credentials('prod/postgresql/credentials')

properties = {
    "user": creds['username'],
    "password": creds['password'],
    "driver": "org.postgresql.Driver"
}

Read-Only Connections

For read-only operations, use a read-only user:

properties = {
    "user": "readonly_user",
    "password": "readonly_password",
    "driver": "org.postgresql.Driver",
    "readOnly": "true"
}
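
The readonly_user role itself has to be created on the PostgreSQL side. The following is an illustrative, one-time setup sketch run with psycopg2 as an administrative user; the role name, password, database, and schema are all placeholders:

import psycopg2

# Illustrative one-time setup of a read-only role (run as a superuser/admin);
# role name, password, database, and schema are placeholders
conn = psycopg2.connect(
    host="localhost", port=5432, dbname="your_database_name",
    user="postgres", password="admin_password"
)
with conn, conn.cursor() as cur:
    cur.execute("CREATE ROLE readonly_user LOGIN PASSWORD 'readonly_password'")
    cur.execute("GRANT CONNECT ON DATABASE your_database_name TO readonly_user")
    cur.execute("GRANT USAGE ON SCHEMA public TO readonly_user")
    cur.execute("GRANT SELECT ON ALL TABLES IN SCHEMA public TO readonly_user")
conn.close()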

Real-World ETL Pipeline Example

Here's a complete ETL pipeline that reads from one PostgreSQL database, transforms data, and writes to another:

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, avg, current_timestamp, sum as spark_sum
import os

# Initialize Spark
spark = SparkSession.builder \
    .appName("ETL Pipeline - PostgreSQL") \
    .config("spark.jars.packages", "org.postgresql:postgresql:42.5.0") \
    .getOrCreate()

# Source database configuration
source_url = "jdbc:postgresql://source-db:5432/sales_db"
source_props = {
    "user": os.getenv("SOURCE_DB_USER"),
    "password": os.getenv("SOURCE_DB_PASSWORD"),
    "driver": "org.postgresql.Driver",
    "fetchsize": "10000"
}

# Target database configuration
target_url = "jdbc:postgresql://target-db:5432/analytics_db"
target_props = {
    "user": os.getenv("TARGET_DB_USER"),
    "password": os.getenv("TARGET_DB_PASSWORD"),
    "driver": "org.postgresql.Driver",
    "batchsize": "10000"
}

# Extract: Read sales data
sales_df = spark.read \
    .jdbc(source_url, "sales_transactions", properties=source_props)

# Read customer data with partitioning
customer_df = spark.read \
    .jdbc(
        url=source_url,
        table="customers",
        column="customer_id",
        lowerBound=1,
        upperBound=1000000,
        numPartitions=10,
        properties=source_props
    )

# Transform: Join and aggregate
result_df = sales_df \
    .join(customer_df, "customer_id") \
    .filter(col("transaction_date") >= "2024-01-01") \
    .groupBy("customer_id", "customer_name", "region") \
    .agg(
        avg("amount").alias("avg_transaction_amount"),
        spark_sum("amount").alias("total_sales")
    ) \
    .withColumn("processed_at", current_timestamp())

# Load: Write to target database
result_df.cache()  # keep the result so the count below does not recompute the whole pipeline
row_count = result_df.count()

result_df.write \
    .mode("overwrite") \
    .option("truncate", "true") \
    .jdbc(target_url, "customer_analytics", properties=target_props)

print(f"ETL completed. Processed {row_count} records.")

spark.stop()

Docker Setup for Local Development

Create a docker-compose.yml file for a complete local testing environment:

version: '3.8'

services:
  postgresql:
    image: postgres:15
    environment:
      POSTGRES_DB: testdb
      POSTGRES_USER: spark_user
      POSTGRES_PASSWORD: spark_password
    ports:
      - "5432:5432"
    volumes:
      - postgres_data:/var/lib/postgresql/data
      - ./init.sql:/docker-entrypoint-initdb.d/init.sql

  spark:
    image: bitnami/spark:3.5
    environment:
      - SPARK_MODE=master
      - SPARK_RPC_AUTHENTICATION_ENABLED=no
      - SPARK_RPC_ENCRYPTION_ENABLED=no
    ports:
      - "8080:8080"
      - "7077:7077"
    volumes:
      - ./scripts:/opt/spark-scripts

  spark-worker:
    image: bitnami/spark:3.5
    environment:
      - SPARK_MODE=worker
      - SPARK_MASTER_URL=spark://spark:7077
      - SPARK_WORKER_MEMORY=2G
      - SPARK_WORKER_CORES=2
    depends_on:
      - spark

volumes:
  postgres_data:

Create an init.sql file to initialize test data:

CREATE TABLE employees (
    id SERIAL PRIMARY KEY,
    name VARCHAR(100),
    email VARCHAR(100),
    department VARCHAR(50),
    salary NUMERIC(10, 2),
    hire_date DATE
);

INSERT INTO employees (name, email, department, salary, hire_date) VALUES
('John Doe', 'john.doe@example.com', 'Engineering', 75000, '2023-01-15'),
('Jane Smith', 'jane.smith@example.com', 'Marketing', 65000, '2023-02-20'),
('Bob Johnson', 'bob.johnson@example.com', 'Engineering', 85000, '2023-03-10');

Run the environment:

docker-compose up -d

Test the connection:

from pyspark.sql import SparkSession

spark = SparkSession.builder \
    .appName("Docker PostgreSQL Test") \
    .config("spark.jars.packages", "org.postgresql:postgresql:42.5.0") \
    .getOrCreate()

url = "jdbc:postgresql://localhost:5432/testdb"  # use "postgresql" as the host when running inside the Compose network
properties = {
    "user": "spark_user",
    "password": "spark_password",
    "driver": "org.postgresql.Driver"
}

df = spark.read.jdbc(url, "employees", properties=properties)
df.show()

Troubleshooting Common Issues

Issue 1: ClassNotFoundException for PostgreSQL Driver

Error:

java.lang.ClassNotFoundException: org.postgresql.Driver

Solution: Ensure the JDBC driver is included in your classpath:

# Use packages option
pyspark --packages org.postgresql:postgresql:42.5.0

# Or specify JAR directly
pyspark --jars /path/to/postgresql-42.5.0.jar

Issue 2: Connection Refused

Error:

java.net.ConnectException: Connection refused

Solution:

  • Verify PostgreSQL is running: pg_isready -h localhost -p 5432
  • Check firewall settings
  • Ensure pg_hba.conf allows connections from your IP
  • Verify postgresql.conf has listen_addresses = '*'

Issue 3: Too Many Connections

Error:

FATAL: too many connections for role "user"

Solution: Reduce the number of partitions or increase PostgreSQL max_connections:

# Reduce partitions
properties = {
    "user": "postgres",
    "password": "password",
    "driver": "org.postgresql.Driver",
    "numPartitions": "5"  # Reduce from higher value
}

Or modify postgresql.conf:

max_connections = 200

Issue 4: Out of Memory Errors

Error:

java.lang.OutOfMemoryError: Java heap space

Solution: Adjust fetch size and Spark memory settings:

properties = {
    "user": "postgres",
    "password": "password",
    "driver": "org.postgresql.Driver",
    "fetchsize": "1000"  # Reduce from default
}

# Increase Spark memory
spark = SparkSession.builder \
    .config("spark.driver.memory", "4g") \
    .config("spark.executor.memory", "8g") \
    .getOrCreate()

Issue 5: Slow Query Performance

Symptoms: Queries taking longer than expected

Solutions:

  1. Verify that filters are pushed down:

df = spark.read.jdbc(url, "large_table", properties=properties)
df.filter("id > 1000").explain(True)  # Check if the filter is pushed down

  2. Use partitioned reads:

df = spark.read \
    .jdbc(
        url=url,
        table="large_table",
        column="id",
        lowerBound=1,
        upperBound=1000000,
        numPartitions=20,
        properties=properties
    )

  3. Create indexes on partition columns in PostgreSQL:

CREATE INDEX idx_large_table_id ON large_table(id);

Issue 6: SSL/TLS Connection Failures

Error:

SSL error: certificate verify failed

Solution: Configure SSL properly or disable for development:

# For development only: disable SSL entirely
url = "jdbc:postgresql://localhost:5432/mydb?sslmode=disable"

# For production with proper SSL
url = "jdbc:postgresql://localhost:5432/mydb?ssl=true&sslmode=verify-full"
properties = {
    "user": "postgres",
    "password": "password",
    "driver": "org.postgresql.Driver",
    "sslrootcert": "/path/to/root.crt"
}
