Skip to content
This repository was archived by the owner on Sep 3, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 25 additions & 4 deletions src/dispatch/auth/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import secrets
from datetime import datetime, timedelta
from uuid import uuid4
from typing import Optional

import bcrypt
from jose import jwt
Expand Down Expand Up @@ -52,6 +53,7 @@ def hash_password(password: str):

class DispatchUser(Base, TimeStampMixin):
"""SQLAlchemy model for a Dispatch user."""

__table_args__ = {"schema": "dispatch_core"}

id = Column(Integer, primary_key=True)
Expand Down Expand Up @@ -104,6 +106,7 @@ def get_organization_role(self, organization_slug: OrganizationSlug):

class DispatchUserOrganization(Base, TimeStampMixin):
"""SQLAlchemy model for the relationship between users and organizations."""

__table_args__ = {"schema": "dispatch_core"}
dispatch_user_id = Column(Integer, ForeignKey(DispatchUser.id), primary_key=True)
dispatch_user = relationship(DispatchUser, backref="organizations")
Expand All @@ -116,6 +119,7 @@ class DispatchUserOrganization(Base, TimeStampMixin):

class DispatchUserProject(Base, TimeStampMixin):
"""SQLAlchemy model for the relationship between users and projects."""

dispatch_user_id = Column(Integer, ForeignKey(DispatchUser.id), primary_key=True)
dispatch_user = relationship(DispatchUser, backref="projects")

Expand All @@ -129,20 +133,23 @@ class DispatchUserProject(Base, TimeStampMixin):

class UserProject(DispatchBase):
"""Pydantic model for a user's project membership."""

project: ProjectRead
default: bool | None = False
role: str | None = None


class UserOrganization(DispatchBase):
"""Pydantic model for a user's organization membership."""

organization: OrganizationRead
default: bool | None = False
role: str | None = None


class UserBase(DispatchBase):
"""Base Pydantic model for user data."""

email: EmailStr
projects: list[UserProject] | None = []
organizations: list[UserOrganization] | None = []
Expand All @@ -158,6 +165,7 @@ def email_required(cls, v):

class UserLogin(UserBase):
"""Pydantic model for user login data."""

password: str

@field_validator("password")
Expand All @@ -171,6 +179,7 @@ def password_required(cls, v):

class UserRegister(UserLogin):
"""Pydantic model for user registration data."""

password: str = None

@field_validator("password", mode="before")
Expand All @@ -183,28 +192,32 @@ def password_required(cls, v):

class UserLoginResponse(DispatchBase):
"""Pydantic model for the response after user login."""

projects: list[UserProject] | None
token: str | None = None


class UserRead(UserBase):
"""Pydantic model for reading user data."""

id: PrimaryKey
role: str | None = None
experimental_features: bool | None


class UserUpdate(DispatchBase):
"""Pydantic model for updating user data."""

id: PrimaryKey
projects: list[UserProject] | None
organizations: list[UserOrganization] | None
experimental_features: bool | None
role: str | None = None
projects: Optional[list[UserProject]] = None
organizations: Optional[list[UserOrganization]]
experimental_features: Optional[bool] = None
role: Optional[str] = None


class UserPasswordUpdate(DispatchBase):
"""Pydantic model for password updates only."""

current_password: str
new_password: str

Expand All @@ -231,6 +244,7 @@ def password_required(cls, v):

class AdminPasswordReset(DispatchBase):
"""Pydantic model for admin password resets."""

new_password: str

@field_validator("new_password")
Expand All @@ -248,6 +262,7 @@ def validate_password(cls, v):

class UserCreate(DispatchBase):
"""Pydantic model for creating a new user."""

email: EmailStr
password: str | None = None
projects: list[UserProject] | None
Expand All @@ -263,16 +278,19 @@ def hash(cls, v):

class UserRegisterResponse(DispatchBase):
"""Pydantic model for the response after user registration."""

token: str | None = None


class UserPagination(Pagination):
"""Pydantic model for paginated user results."""

items: list[UserRead] = []


class MfaChallengeStatus(DispatchEnum):
"""Enumeration of possible MFA challenge statuses."""

APPROVED = "approved"
DENIED = "denied"
EXPIRED = "expired"
Expand All @@ -281,6 +299,7 @@ class MfaChallengeStatus(DispatchEnum):

class MfaChallenge(Base, TimeStampMixin):
"""SQLAlchemy model for an MFA challenge event."""

id = Column(Integer, primary_key=True, autoincrement=True)
valid = Column(Boolean, default=False)
reason = Column(String, nullable=True)
Expand All @@ -293,11 +312,13 @@ class MfaChallenge(Base, TimeStampMixin):

class MfaPayloadResponse(DispatchBase):
"""Pydantic model for the response to an MFA challenge payload."""

status: str


class MfaPayload(DispatchBase):
"""Pydantic model for an MFA challenge payload."""

action: str
project_id: int
challenge_id: str
8 changes: 6 additions & 2 deletions src/dispatch/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,23 +285,27 @@ def prompt_for_confirmation(command: str) -> bool:
f"Warning: You are about to {command} a remote database.",
fg="yellow",
)

database_name = click.prompt(f"Please enter the database name (env = {DATABASE_NAME})")
if database_name != DATABASE_NAME:
click.secho(
f"ERROR: You cannot {command} a database with a different name.",
fg="red",
)
return False
sqlalchemy_database_uri = f"postgresql+psycopg2://{config._DATABASE_CREDENTIAL_USER}:{config._QUOTED_DATABASE_PASSWORD}@{database_hostname}:{config.DATABASE_PORT}/{database_name}"

if command != "drop":
return True

sqlalchemy_database_uri = f"postgresql+psycopg2://{config._DATABASE_CREDENTIAL_USER}:{config._QUOTED_DATABASE_PASSWORD}@{database_hostname}:{config.DATABASE_PORT}/{database_name}"
if database_exists(str(sqlalchemy_database_uri)):
if click.confirm(
f"Are you sure you want to {command} database: '{database_hostname}:{database_name}'?"
):
return True
else:
click.secho(f"Database '{database_hostname}:{database_name}' does not exist!!!", fg="red")
return False
return False


@dispatch_database.command("init")
Expand Down
65 changes: 33 additions & 32 deletions src/dispatch/signal/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,14 @@ def get_signal_engagement_by_name_or_raise(
)

if not signal_engagement:
raise ValidationError([
{
"msg": "Signal engagement not found.",
"loc": "signalEngagement",
}
])
raise ValidationError(
[
{
"msg": "Signal engagement not found.",
"loc": "signalEngagement",
}
]
)
return signal_engagement


Expand Down Expand Up @@ -254,12 +256,14 @@ def get_signal_filter_by_name_or_raise(
)

if not signal_filter:
raise ValidationError([
{
"msg": "Signal Filter not found.",
"loc": "signalFilter",
}
])
raise ValidationError(
[
{
"msg": "Signal Filter not found.",
"loc": "signalFilter",
}
]
)
return signal_filter


Expand Down Expand Up @@ -303,9 +307,7 @@ def get_default(*, db_session: Session, project_id: int) -> Signal | None:
)


def get_by_primary_or_external_id(
*, db_session: Session, signal_id: str | int
) -> Signal | None:
def get_by_primary_or_external_id(*, db_session: Session, signal_id: str | int) -> Signal | None:
"""Gets a signal by id or external_id."""
if is_valid_uuid(signal_id):
signal = db_session.query(Signal).filter(Signal.external_id == signal_id).one_or_none()
Expand Down Expand Up @@ -475,6 +477,7 @@ def update(
signal: Signal,
signal_in: SignalUpdate,
user: DispatchUser | None = None,
update_filters: bool = False,
) -> Signal:
"""Updates a signal."""
signal_data = signal.dict()
Expand Down Expand Up @@ -533,23 +536,21 @@ def update(
updates["engagements-removed"].append(se.name)
signal.engagements = engagements

is_filters_updated = {filter.id for filter in signal.filters} != {
filter.id for filter in signal_in.filters
}

if is_filters_updated:
filters = []
for f in signal_in.filters:
signal_filter = get_signal_filter_by_name_or_raise(
db_session=db_session, project_id=signal.project.id, signal_filter_in=f
)
if signal_filter not in signal.filters:
updates["filters-added"].append(signal_filter.name)
filters.append(signal_filter)
for f in signal.filters:
if f not in filters:
updates["filters-removed"].append(f.name)
signal.filters = filters
# if update_filters, use only the filters from the signal_in, otherwise use the existing filters and add new filters
filter_set = set() if update_filters else set(signal.filters)
for f in signal_in.filters:
signal_filter = get_signal_filter_by_name_or_raise(
db_session=db_session, project_id=signal.project.id, signal_filter_in=f
)
if signal_filter not in signal.filters:
updates["filters-added"].append(signal_filter.name)
filter_set.add(signal_filter)
elif update_filters:
filter_set.add(signal_filter)
for f in signal.filters:
if f not in filter_set:
updates["filters-removed"].append(f.name)
signal.filters = list(filter_set)

if signal_in.workflows:
workflows = []
Expand Down
Loading
Loading