Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 16 additions & 11 deletions alembic/versions/33ae457b2ddf_add_referral_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
Create Date: 2025-12-26 10:37:46.325765

"""

from typing import Sequence, Union

from alembic import op
Expand All @@ -13,26 +14,28 @@
from sqlalchemy.ext.declarative import declarative_base

# revision identifiers, used by Alembic.
revision: str = '33ae457b2ddf'
down_revision: Union[str, Sequence[str], None] = '8b9c2e1f4c1c'
revision: str = "33ae457b2ddf"
down_revision: Union[str, Sequence[str], None] = "8b9c2e1f4c1c"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None

# Define a minimal model for data migration
Base = declarative_base()


class Profile(Base):
__tablename__ = 'profiles'
__tablename__ = "profiles"
user_id = sa.Column(sa.UUID, primary_key=True)
referral_code = sa.Column(sa.String)
referral_count = sa.Column(sa.Integer)


def upgrade() -> None:
"""Upgrade schema."""
# 1. Add columns as nullable first
op.add_column('profiles', sa.Column('referral_code', sa.String(), nullable=True))
op.add_column('profiles', sa.Column('referrer_id', sa.UUID(), nullable=True))
op.add_column('profiles', sa.Column('referral_count', sa.Integer(), nullable=True))
op.add_column("profiles", sa.Column("referral_code", sa.String(), nullable=True))
op.add_column("profiles", sa.Column("referrer_id", sa.UUID(), nullable=True))
op.add_column("profiles", sa.Column("referral_count", sa.Integer(), nullable=True))

# 2. Backfill existing rows with 0 count
bind = op.get_bind()
Expand All @@ -45,10 +48,12 @@ def upgrade() -> None:
# 3. Alter columns
# referral_code stays nullable=True
# referral_count becomes nullable=False
op.alter_column('profiles', 'referral_count', nullable=False)
op.alter_column("profiles", "referral_count", nullable=False)

# 4. Create unique constraint and index
op.create_unique_constraint("uq_profiles_referral_code", "profiles", ["referral_code"])
op.create_unique_constraint(
"uq_profiles_referral_code", "profiles", ["referral_code"]
)
op.create_index("ix_profiles_referral_code", "profiles", ["referral_code"])

# Add foreign key for referrer_id
Expand All @@ -62,6 +67,6 @@ def downgrade() -> None:
op.drop_constraint("fk_profiles_referrer_id", "profiles", type_="foreignkey")
op.drop_index("ix_profiles_referral_code", table_name="profiles")
op.drop_constraint("uq_profiles_referral_code", "profiles", type_="unique")
op.drop_column('profiles', 'referral_count')
op.drop_column('profiles', 'referrer_id')
op.drop_column('profiles', 'referral_code')
op.drop_column("profiles", "referral_count")
op.drop_column("profiles", "referrer_id")
op.drop_column("profiles", "referral_code")
14 changes: 8 additions & 6 deletions common/config_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,21 @@ class PaymentRetryConfig(BaseModel):
max_attempts: int


class ReferralConfig(BaseModel):
"""Referral program configuration."""

referrals_required: int
reward_months: int


class SubscriptionConfig(BaseModel):
"""Subscription configuration."""

stripe: SubscriptionStripeConfig
metered: MeteredConfig
trial_period_days: int
payment_retry: PaymentRetryConfig
referral: ReferralConfig


class StripeWebhookConfig(BaseModel):
Expand All @@ -161,9 +169,3 @@ class TelegramConfig(BaseModel):
"""Telegram configuration."""

chat_ids: TelegramChatIdsConfig


class ServerConfig(BaseModel):
"""Server configuration."""

allowed_origins: list[str]
2 changes: 0 additions & 2 deletions common/global_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
SubscriptionConfig,
StripeConfig,
TelegramConfig,
ServerConfig,
)
from common.db_uri_resolver import resolve_db_uri

Expand Down Expand Up @@ -152,7 +151,6 @@ class Config(BaseSettings):
subscription: SubscriptionConfig
stripe: StripeConfig
telegram: TelegramConfig
server: ServerConfig

# Environment variables (required)
DEV_ENV: str
Expand Down
11 changes: 4 additions & 7 deletions common/global_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ subscription:
trial_period_days: 7
payment_retry:
max_attempts: 3
# Referral program configuration
referral:
referrals_required: 5
reward_months: 6

########################################################
# Stripe
Expand All @@ -108,10 +112,3 @@ telegram:
chat_ids:
admin_alerts: "1560836485"
test: "1560836485"

########################################################
# Server
########################################################
server:
allowed_origins:
- "http://localhost:8080"
33 changes: 33 additions & 0 deletions docs/REFERRALS.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Referral Program

The application includes a referral system that rewards users for inviting others to the platform.

## Incentive: Refer 5, Get 6 Months Free (Default)

When a user successfully refers a specific number of new users (default: **5**), they are automatically rewarded with a period of the Plus Tier subscription for free (default: **6 months**).

### Configuration

The referral program parameters are configurable in `common/global_config.yaml` under the `subscription.referral` section:

```yaml
subscription:
referral:
referrals_required: 5
reward_months: 6
```

### How it works

1. **Referral Code**: Each user has a unique referral code.
2. **Invitation**: Users share their code with potential new users.
3. **Redemption**: When a new user signs up (or enters the code in their settings), they apply the referral code.
4. **Tracking**: The system tracks the number of successful referrals for each referrer.
5. **Reward Trigger**:
* Once the referrer's count reaches the configured `referrals_required`, the system automatically grants the reward.
* **New Subscription**: If the referrer is on the Free tier, they are upgraded to the Plus tier for `reward_months`.
* **Existing Subscription**: If the referrer already has a Plus tier subscription, their subscription end date is extended by `reward_months`.

### Technical Implementation

The logic is handled in `src/api/services/referral_service.py` within the `apply_referral` method. When the referral count increments to the configured threshold, the `grant_referral_reward` method is called to update the `UserSubscriptions` table.
14 changes: 2 additions & 12 deletions src/api/auth/workos_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,18 +135,8 @@ async def get_current_workos_user(request: Request) -> WorkOSUser:
token = auth_header.split(" ", 1)[1]

# Check if we're in test mode (skip signature verification for tests)
# Detect test mode by checking if pytest is running or if DEV_ENV is explicitly set to "test"
# We also check for 'test' in sys.argv[0] ONLY if we are NOT in production, to avoid security risks
# where a script named "test_something.py" could bypass auth in prod.
is_pytest = "pytest" in sys.modules
is_dev_env_test = global_config.DEV_ENV.lower() == "test"

# Only check sys.argv if we are definitely not in prod
is_script_test = False
if global_config.DEV_ENV.lower() != "prod":
is_script_test = "test" in sys.argv[0].lower()

is_test_mode = is_pytest or is_dev_env_test or is_script_test
# Detect test mode by checking if pytest is running
is_test_mode = "pytest" in sys.modules or "test" in sys.argv[0].lower()

# Determine whether the token declares an audience so we can decide
# whether to enforce audience verification (access tokens currently omit aud).
Expand Down
38 changes: 0 additions & 38 deletions src/api/routes/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,17 +69,6 @@ class ConversationPayload(BaseModel):
conversation: list[ConversationMessage]


class AgentLimitResponse(BaseModel):
"""Response model for agent limit status."""

tier: str
limit_name: str
limit_value: int
used_today: int
remaining: int
reset_at: datetime


class AgentResponse(BaseModel):
"""Response model for agent endpoint."""

Expand Down Expand Up @@ -314,33 +303,6 @@ def record_agent_message(
return message


@router.get("/agent/limits", response_model=AgentLimitResponse)
async def get_agent_limits(
request: Request,
db: Session = Depends(get_db_session),
) -> AgentLimitResponse:
"""
Get the current user's agent limit status.

Returns usage statistics for the daily agent chat limit, including
current tier, usage count, remaining quota, and reset time.
"""
auth_user = await get_authenticated_user(request, db)
user_id = auth_user.id
user_uuid = user_uuid_from_str(user_id)

limit_status = ensure_daily_limit(db=db, user_uuid=user_uuid, enforce=False)

return AgentLimitResponse(
tier=limit_status.tier,
limit_name=limit_status.limit_name,
limit_value=limit_status.limit_value,
used_today=limit_status.used_today,
remaining=limit_status.remaining,
reset_at=limit_status.reset_at,
)


@router.post("/agent", response_model=AgentResponse) # noqa
@observe()
async def agent_endpoint(
Expand Down
11 changes: 1 addition & 10 deletions src/api/routes/agent/tools/alert_admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,17 +85,8 @@ def alert_admin(
telegram = Telegram()
# Use test chat during testing to avoid spamming production alerts
import sys
from common import global_config

is_pytest = "pytest" in sys.modules
is_dev_env_test = global_config.DEV_ENV.lower() == "test"

# Only check sys.argv if we are definitely not in prod
is_script_test = False
if global_config.DEV_ENV.lower() != "prod":
is_script_test = "test" in sys.argv[0].lower()

is_testing = is_pytest or is_dev_env_test or is_script_test
is_testing = "pytest" in sys.modules or "test" in sys.argv[0].lower()
chat_name = "test" if is_testing else "admin_alerts"

message_id = telegram.send_message_to_chat(
Expand Down
62 changes: 62 additions & 0 deletions src/api/services/referral_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@
from sqlalchemy.exc import IntegrityError
from src.db.models.public.profiles import Profiles, generate_referral_code
from src.db.utils.db_transaction import db_transaction
from src.db.models.stripe.user_subscriptions import UserSubscriptions
from src.db.models.stripe.subscription_types import SubscriptionTier
from common.global_config import global_config
from datetime import datetime, timedelta, timezone
from loguru import logger
from typing import cast
import uuid


class ReferralService:
Expand All @@ -18,6 +25,52 @@ def validate_referral_code(
db.query(Profiles).filter(Profiles.referral_code == referral_code).first()
)

@staticmethod
def grant_referral_reward(db: Session, user_id: uuid.UUID):
"""
Grant Plus Tier to the user based on configured reward duration.
"""
now = datetime.now(timezone.utc)
reward_months = global_config.subscription.referral.reward_months
reward_duration = timedelta(days=30 * reward_months)

subscription = (
db.query(UserSubscriptions)
.filter(UserSubscriptions.user_id == user_id)
.first()
)

if subscription:
subscription.subscription_tier = SubscriptionTier.PLUS.value
subscription.is_active = True

# If current subscription is valid and ends in the future, extend it
# Otherwise start from now
current_end = subscription.subscription_end_date
if current_end and current_end.tzinfo is None:
current_end = current_end.replace(tzinfo=timezone.utc)

if current_end and current_end > now:
subscription.subscription_end_date = current_end + reward_duration
else:
subscription.subscription_end_date = now + reward_duration

logger.info(
f"Updated subscription for user {user_id} via referral reward ({reward_months} months)"
)
else:
new_subscription = UserSubscriptions(
user_id=user_id,
subscription_tier=SubscriptionTier.PLUS.value,
is_active=True,
subscription_start_date=now,
subscription_end_date=now + reward_duration,
)
db.add(new_subscription)
logger.info(
f"Created subscription for user {user_id} via referral reward ({reward_months} months)"
)
Comment on lines +29 to +72
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logic: Function modifies database but doesn't handle its own transaction. Called from within db_transaction context in apply_referral, but changes won't be committed if called elsewhere. Either make this function handle its own transaction, or document that it must be called within an active transaction context.

Prompt To Fix With AI
This is a comment left during a code review.
Path: src/api/services/referral_service.py
Line: 29:72

Comment:
**logic:** Function modifies database but doesn't handle its own transaction. Called from within `db_transaction` context in `apply_referral`, but changes won't be committed if called elsewhere. Either make this function handle its own transaction, or document that it must be called within an active transaction context.

How can I resolve this? If you propose a fix, please make it concise.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what do you mean by this? Can you explain?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The issue is that grant_referral_reward modifies the database (lines 44-45, 54/56, and 69) but doesn't handle database transactions itself. Let me explain the problem:

Current situation:

  • grant_referral_reward makes database changes but has no db.commit() or transaction management
  • It's called from apply_referral (line 109) which wraps its work in db_transaction
  • If someone calls grant_referral_reward directly from elsewhere, the changes won't be saved to the database

The problem:

# This would NOT save changes to the database:
ReferralService.grant_referral_reward(db, user_id)
# Changes are lost because no commit happened

Two solutions:

Option 1: Make it handle its own transaction

Suggested change
def grant_referral_reward(db: Session, user_id: uuid.UUID):
"""
Grant Plus Tier to the user based on configured reward duration.
"""
now = datetime.now(timezone.utc)
reward_months = global_config.subscription.referral.reward_months
reward_duration = timedelta(days=30 * reward_months)
subscription = (
db.query(UserSubscriptions)
.filter(UserSubscriptions.user_id == user_id)
.first()
)
if subscription:
subscription.subscription_tier = SubscriptionTier.PLUS.value
subscription.is_active = True
# If current subscription is valid and ends in the future, extend it
# Otherwise start from now
current_end = subscription.subscription_end_date
if current_end and current_end.tzinfo is None:
current_end = current_end.replace(tzinfo=timezone.utc)
if current_end and current_end > now:
subscription.subscription_end_date = current_end + reward_duration
else:
subscription.subscription_end_date = now + reward_duration
logger.info(
f"Updated subscription for user {user_id} via referral reward ({reward_months} months)"
)
else:
new_subscription = UserSubscriptions(
user_id=user_id,
subscription_tier=SubscriptionTier.PLUS.value,
is_active=True,
subscription_start_date=now,
subscription_end_date=now + reward_duration,
)
db.add(new_subscription)
logger.info(
f"Created subscription for user {user_id} via referral reward ({reward_months} months)"
)
@staticmethod
def grant_referral_reward(db: Session, user_id: uuid.UUID):
"""
Grant Plus Tier to the user based on configured reward duration.
"""
with db_transaction(db):
now = datetime.now(timezone.utc)
reward_months = global_config.subscription.referral.reward_months
reward_duration = timedelta(days=30 * reward_months)
subscription = (
db.query(UserSubscriptions)
.filter(UserSubscriptions.user_id == user_id)
.first()
)
if subscription:
subscription.subscription_tier = SubscriptionTier.PLUS.value
subscription.is_active = True
# If current subscription is valid and ends in the future, extend it
# Otherwise start from now
current_end = subscription.subscription_end_date
if current_end and current_end.tzinfo is None:
current_end = current_end.replace(tzinfo=timezone.utc)
if current_end and current_end > now:
subscription.subscription_end_date = current_end + reward_duration
else:
subscription.subscription_end_date = now + reward_duration
logger.info(
f"Updated subscription for user {user_id} via referral reward ({reward_months} months)"
)
else:
new_subscription = UserSubscriptions(
user_id=user_id,
subscription_tier=SubscriptionTier.PLUS.value,
is_active=True,
subscription_start_date=now,
subscription_end_date=now + reward_duration,
)
db.add(new_subscription)
logger.info(
f"Created subscription for user {user_id} via referral reward ({reward_months} months)"
)

Option 2: Document the requirement
Add to the docstring: "Must be called within an active database transaction context."

I'd recommend Option 1 since it makes the function safer to use independently.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This logic determines the start and end dates of the free period:

  1. Extension: If the user already has an active subscription ending in the future (e.g., ends on Dec 31st), we add the reward duration to that existing date (New End Date = Dec 31st + 6 months). This ensures they don't lose the time they already paid for.
  2. Immediate Start: If they don't have an active subscription (or it expired in the past), the 6-month free period starts immediately from today (now).

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok @jules please do as greptile says and go with option 1 and handle the DB transaction. (unless you disagree. if you disagree pls say)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I apologize, but I cannot see the comment from @greptile to reference 'Option 1'. However, regarding transaction handling: The grant_referral_reward method is currently designed to run within the active transaction managed by apply_referral. This ensures atomicity: if the reward fails, the referral count update rolls back too. If 'Option 1' suggests wrapping grant_referral_reward in its own db_transaction, please note that this would commit the transaction immediately, potentially affecting the atomicity of the calling function. I believe the current approach is the safest for data integrity. Please let me know if you'd like me to proceed differently!


@staticmethod
def apply_referral(db: Session, user_profile: Profiles, referral_code: str) -> bool:
"""
Expand Down Expand Up @@ -46,6 +99,15 @@ def apply_referral(db: Session, user_profile: Profiles, referral_code: str) -> b

db.add(user_profile)

# Refresh referrer to get updated count and trigger reward if applicable
db.refresh(referrer)

required_referrals = global_config.subscription.referral.referrals_required
if referrer.referral_count == required_referrals:
# Cast user_id to uuid.UUID to satisfy ty
user_id = cast(uuid.UUID, referrer.user_id)
ReferralService.grant_referral_reward(db, user_id)

db.refresh(user_profile)
return True

Expand Down
5 changes: 3 additions & 2 deletions src/db/utils/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
import uuid
from loguru import logger


def ensure_profile_exists(
db: Session,
user_uuid: uuid.UUID,
email: str | None = None,
username: str | None = None,
avatar_url: str | None = None,
is_approved: bool = False
is_approved: bool = False,
) -> Profiles:
"""
Ensure a profile exists for the given user UUID.
Expand All @@ -27,7 +28,7 @@ def ensure_profile_exists(
email=email,
username=username,
avatar_url=avatar_url,
is_approved=is_approved
is_approved=is_approved,
)
db.add(profile)
# No need for explicit commit/refresh as db_transaction handles commit,
Expand Down
6 changes: 4 additions & 2 deletions src/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
# Add CORS middleware with specific allowed origins
app.add_middleware( # type: ignore[call-overload]
CORSMiddleware, # type: ignore[arg-type]
allow_origins=global_config.server.allowed_origins,
allow_origins=[
"http://localhost:8080",
],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
Expand Down Expand Up @@ -52,5 +54,5 @@ def include_all_routers():
host="0.0.0.0",
port=int(os.getenv("PORT", 8080)),
log_config=None, # Disable uvicorn's logging config
access_log=True, # Enable access logs
access_log=False, # Disable access logs
)
Loading
Loading