5 changes: 5 additions & 0 deletions .sampo/changesets/ardent-knight-vainamoinen.md
@@ -0,0 +1,5 @@
---
pypi/posthog: minor
---

feat: add Celery integration and improve PostHog client fork safety
185 changes: 185 additions & 0 deletions examples/celery_integration.py
@@ -0,0 +1,185 @@
"""
Celery integration example for the PostHog Python SDK.

Demonstrates how to use ``PosthogCeleryIntegration`` with:
- producer-side and worker-side instrumentation
- context propagation (distinct ID, session ID, tags) from producer to worker
- task lifecycle events (published, started, success, failure, retry)
- exception capture from failed tasks
- ``task_filter`` customization hook

Setup:
1. Set ``POSTHOG_PROJECT_API_KEY`` and ``POSTHOG_HOST`` in your environment
2. Install dependencies: ``pip install posthog celery redis``
3. Start Redis: ``redis-server``
4. Start the worker: ``celery -A examples.celery_integration worker --loglevel=info``
5. Run the producer: ``python -m examples.celery_integration``
"""

import os
import time
from typing import Any, Optional

from celery import Celery
from celery.signals import worker_process_init, worker_process_shutdown

import posthog
from posthog.integrations.celery import PosthogCeleryIntegration


# --- Configuration ---

POSTHOG_PROJECT_API_KEY = os.getenv("POSTHOG_PROJECT_API_KEY", "phc_...")
POSTHOG_HOST = os.getenv("POSTHOG_HOST", "http://localhost:8000")

app = Celery(
"examples.celery_integration",
broker="redis://localhost:6379/0",
)


# --- Integration wiring ---

def configure_posthog() -> None:
    posthog.api_key = POSTHOG_PROJECT_API_KEY
    posthog.host = POSTHOG_HOST
    posthog.enable_local_evaluation = False  # avoid requiring a personal_api_key for this example
    posthog.setup()


def task_filter(task_name: Optional[str], task_properties: dict[str, Any]) -> bool:
    if task_name is not None and task_name.endswith(".health_check"):
        return False
    return True


def create_integration() -> PosthogCeleryIntegration:
    return PosthogCeleryIntegration(
        capture_exceptions=True,
        capture_task_lifecycle_events=True,
        propagate_context=True,
        task_filter=task_filter,
    )

configure_posthog()
integration = create_integration()
integration.instrument()


# --- Worker process setup ---
# Celery's default prefork pool runs tasks in child processes. This example
# runs on a single host, so the inherited PostHog client and Celery
# integration are fork-safe and do not need to be recreated in each child.
# If workers run across multiple hosts, configure PostHog and instrument a
# worker-local integration in worker_process_init.
@worker_process_init.connect
def on_worker_process_init(**kwargs) -> None:
    # global integration

    # configure_posthog()
    # integration = create_integration()
    # integration.instrument()
    return
Comment on lines +69 to +82
Contributor: Is this still necessary? It looks like it's a no-op

Contributor (Author): Yeah, this shows how to set it up if the workers are running on a different host. I'll comment out this handler entirely and make that clearer.



# Use this signal to shut down the integration and the PostHog client.
# Calling shutdown() is important: it flushes any pending events.
@worker_process_shutdown.connect
def on_worker_process_shutdown(**kwargs) -> None:
    integration.shutdown()
    posthog.shutdown()
Comment on lines +85 to +90
Contributor: Is this required setup configuration, or can it be omitted? We're also doing this at the bottom of the example, but it's not clear why or if it's required.

Contributor (Author): This is needed for the Celery workers. The one at the bottom of the example does it for the producer process. I'll add comments explaining it.



# --- Example tasks ---

@app.task
def health_check() -> dict[str, str]:
    return {"status": "ok"}


@app.task(max_retries=3)
def process_order(order_id: str) -> dict:
    """A task that processes an order successfully."""

    # simulate work
    time.sleep(0.1)

    # Custom event inside the task - context tags propagated from the
    # producer (e.g. "source", "release") should appear on this event
    # and this should be attributed to the correct distinct ID and session.
    posthog.capture(
        "celery example order processed",
        properties={"order_id": order_id, "amount": 99.99},
    )

    return {"order_id": order_id, "status": "completed"}


@app.task(bind=True, max_retries=3)
def send_notification(self, user_id: str, message: str) -> None:
    """A task that may fail and retry."""
    if self.request.retries < 2:
        raise self.retry(
            exc=ConnectionError("notification service unavailable"),
            countdown=120,
        )
    return None


@app.task
def failing_task() -> None:
    """A task that always fails."""
    raise ValueError("something went wrong")


# --- Producer code ---

if __name__ == "__main__":
print("PostHog Celery Integration Example")
print("=" * 40)
print()

# Set up PostHog context before dispatching tasks.
# The integration propagates this context to workers via task headers.
with posthog.new_context(fresh=True):
posthog.identify_context("user-123")
posthog.set_context_session("session-user-123-abc")
posthog.tag("source", "celery_integration_example_script")
posthog.tag("release", "v1.2.3")

print("Dispatching tasks...")

# This task is intentionally filtered and should not emit task events.
result = health_check.delay()
print(f" health_check dispatched (filtered): {result.id}")

# This task will produce events:
# celery task published (sender side)
# celery task started (worker side)
# order processed (custom event, should carry propagated context tags)
# celery task success (worker side, includes duration)
result = process_order.delay("order-456")
print(f" process_order dispatched: {result.id}")

# This task will produce events:
# celery task published
# celery task started
# celery task retry (with reason)
# celery task started (retry attempt)
# celery task success
result = send_notification.delay("user-123", "Hello!")
print(f" send_notification dispatched: {result.id}")

# This task will produce events:
# celery task published
# celery task started
# celery task failure (with error_type and error_message)
result = failing_task.delay()
print(f" failing_task dispatched: {result.id}")

print()
print("Tasks dispatched. Check your Celery worker logs and PostHog for events.")
print()

integration.shutdown()
posthog.shutdown()
71 changes: 71 additions & 0 deletions posthog/client.py
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@
import os
import sys
import warnings
import weakref
from datetime import datetime, timedelta
from typing import Any, Dict, Optional, Union
from uuid import uuid4
@@ -56,6 +57,7 @@
    flags,
    get,
    remote_config,
    reset_sessions,
)
from posthog.types import (
    FeatureFlag,
@@ -219,6 +221,7 @@ def __init__(
        Category:
            Initialization
        """
        self._max_queue_size = max_queue_size
        self.queue = queue.Queue(max_queue_size)

        # api_key: This should be the Team API Key (token), public
@@ -243,6 +246,7 @@
        )
        self.poller = None
        self.distinct_ids_feature_flags_reported = SizeLimitedDict(MAX_DICT_SIZE, set)
        self.flag_fallback_cache_url = flag_fallback_cache_url
        self.flag_cache = self._initialize_flag_cache(flag_fallback_cache_url)
        self.flag_definition_version = 0
        self._flags_etag: Optional[str] = None
@@ -332,6 +336,10 @@ def __init__(
            if send:
                consumer.start()

        if hasattr(os, "register_at_fork"):
            weak_self = weakref.ref(self)
            os.register_at_fork(after_in_child=lambda: Client._reinit_after_fork_weak(weak_self))
    def new_context(self, fresh=False, capture_exceptions=True):
        """
        Create a new context for managing shared state. Learn more about [contexts](/docs/libraries/python#contexts).
@@ -1080,6 +1088,69 @@ def capture_exception(
        except Exception as e:
            self.log.exception(f"Failed to capture exception: {e}")
    @staticmethod
    def _reinit_after_fork_weak(weak_self):
        """
        Reinitialize the client after a fork.

        Holds only a weak reference so the registered hook does not keep the
        client alive; if the client has been garbage collected, this is a no-op.
        """
        self = weak_self()
        if self is None:
            return
        self._reinit_after_fork()

    def _reinit_after_fork(self):
        """Reinitialize fork-unsafe client state in a forked child process.

        Registered via os.register_at_fork(after_in_child=...) so it runs
        exactly once in each child, before any user code, covering all code
        paths (capture, flush, join, etc.).

        Python threads do not survive fork() and queue.Queue internal locks
        may be in an inconsistent state, so the event queue, consumer threads,
        and other state are replaced. Inherited queue items are not retained,
        as they'll be handled by the parent process's consumers.
        """
Comment on lines +1102 to +1113
Contributor: One additional consideration here is that urllib3 sessions contain sockets that will be shared after fork. They're not mutex controlled, so they should also be recreated.
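
(For illustration only, the pattern being described here — recreating a transport object in the forked child so parent and child do not share sockets — might look like the sketch below; requests.Session stands in for the SDK's actual urllib3 pool, which is an assumption, not the SDK's code:)

import os

import requests  # stand-in for the SDK's urllib3-backed transport (assumption)

class Transport:
    def __init__(self):
        self.session = requests.Session()
        if hasattr(os, "register_at_fork"):
            # Give each forked child its own session so it does not reuse
            # sockets opened by the parent. Note: register_at_fork holds a
            # strong reference, so this sketch keeps the Transport alive.
            os.register_at_fork(after_in_child=self._recreate_session)

    def _recreate_session(self):
        self.session = requests.Session()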

Contributor (Author): Oh yeah, it turned out that Poller and RedisFlagCache also needed to be recreated, so I've updated the client to handle them all.

Not sure what to do about flag_definition_cache_provider. It could be problematic due to locks, but it's not in our control. Do we just clarify in the docs that we recommend recreating the client in that case?

Contributor: I'm a little torn on this, because it feels a bit like the Celery integration and flag_definition_cache_provider are at odds with each other, unless we also expand the FlagDefinitionCacheProvider to handle _reinit_after_fork. In practice, Celery is a great use case for caching flag definitions.

Perhaps we could do something like:

integration = PosthogCeleryIntegration(
    client_factory=create_posthog_client,
)

where:

def create_posthog_client():
    cache = MyCacheProvider(...)
    return Posthog(
        "...",
        personal_api_key="...",
        flag_definition_cache_provider=cache,
    )

Then on worker_process_init, the integration could call the factory and store the client for that process, and we avoid the whole fork problem.
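
(A minimal sketch of that per-process wiring, for illustration; client_factory and create_posthog_client are the hypothetical names from this thread, not an existing API:)

from celery.signals import worker_process_init

_worker_client = None

@worker_process_init.connect
def _init_worker_client(**kwargs):
    # Build a fresh client (including its cache provider) in each forked
    # worker process, so no fork-unsafe state is inherited.
    global _worker_client
    _worker_client = create_posthog_client()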

Contributor (Author): I'm more inclined towards updating the docs/example to explicitly mention that users with a custom flag_definition_cache_provider should reinit their client in the workers and reinstrument the integration with that new client instance. That's more in line with what the other providers (Sentry, DD, OTel) recommend.

We could replace client with a factory, but that would be a deviation from standard practice, which users are more likely to be used to.

        if self.consumers:
            self.queue = queue.Queue(self._max_queue_size)

            new_consumers = []
            for old in self.consumers:
                consumer = Consumer(
                    self.queue,
                    old.api_key,
                    flush_at=old.flush_at,
                    host=old.host,
                    on_error=old.on_error,
                    flush_interval=old.flush_interval,
                    gzip=old.gzip,
                    retries=old.retries,
                    timeout=old.timeout,
                    historical_migration=old.historical_migration,
                )
                new_consumers.append(consumer)

                if self.send:
                    consumer.start()

            self.consumers = new_consumers

        if self.enable_local_evaluation:
            self.poller = Poller(
                interval=timedelta(seconds=self.poll_interval),
                execute=self._load_feature_flags,
            )
            self.poller.start()
        else:
            self.poller = None

        # If using Redis cache, we must reinitialize to get a fresh connection (fork-safe).
        # If using Memory cache, we keep it as-is to benefit from the inherited warm cache.
        if isinstance(self.flag_cache, RedisFlagCache):
            self.flag_cache = self._initialize_flag_cache(self.flag_fallback_cache_url)

        reset_sessions()

    def _enqueue(self, msg, disable_geoip):
        # type: (...) -> Optional[str]
        """Push a new `msg` onto the queue, return `(success, msg)`"""