Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions server/api/views/ai_promptStorage/views.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from rest_framework import status
from rest_framework.decorators import api_view
from rest_framework.response import Response
from drf_spectacular.utils import extend_schema
from .models import AI_PromptStorage
from .serializers import AI_PromptStorageSerializer


@extend_schema(request=AI_PromptStorageSerializer, responses={201: AI_PromptStorageSerializer})
@api_view(['POST'])
# @permission_classes([IsAuthenticated])
def store_prompt(request):
Expand All @@ -21,6 +23,7 @@ def store_prompt(request):
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)


@extend_schema(responses={200: AI_PromptStorageSerializer(many=True)})
@api_view(['GET'])
def get_all_prompts(request):
"""
Expand Down
2 changes: 2 additions & 0 deletions server/api/views/ai_settings/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
from rest_framework.decorators import api_view, permission_classes
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
from drf_spectacular.utils import extend_schema
from .models import AI_Settings
from .serializers import AISettingsSerializer


@extend_schema(request=AISettingsSerializer, responses={200: AISettingsSerializer(many=True), 201: AISettingsSerializer})
@api_view(['GET', 'POST'])
@permission_classes([IsAuthenticated])
def settings_view(request):
Expand Down
17 changes: 17 additions & 0 deletions server/api/views/assistant/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from rest_framework.permissions import AllowAny
from django.utils.decorators import method_decorator
from django.views.decorators.csrf import csrf_exempt
from drf_spectacular.utils import extend_schema, inline_serializer
from rest_framework import serializers as drf_serializers

from openai import OpenAI

Expand Down Expand Up @@ -113,6 +115,21 @@ def invoke_functions_from_response(
class Assistant(APIView):
permission_classes = [AllowAny]

@extend_schema(
request=inline_serializer(name='AssistantRequest', fields={
'message': drf_serializers.CharField(help_text='User message to send to the assistant'),
'previous_response_id': drf_serializers.CharField(required=False, allow_null=True, help_text='ID of previous response for conversation continuity'),
}),
responses={
200: inline_serializer(name='AssistantResponse', fields={
'response_output_text': drf_serializers.CharField(),
'final_response_id': drf_serializers.CharField(),
}),
500: inline_serializer(name='AssistantError', fields={
'error': drf_serializers.CharField(),
}),
}
)
def post(self, request):
try:
user = request.user
Expand Down
31 changes: 31 additions & 0 deletions server/api/views/conversations/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from .models import Conversation, Message
from .serializers import ConversationSerializer
from ...services.tools.tools import tools, execute_tool
from drf_spectacular.utils import extend_schema, inline_serializer
from rest_framework import serializers as drf_serializers


@csrf_exempt
Expand Down Expand Up @@ -95,6 +97,21 @@ def destroy(self, request, *args, **kwargs):
self.perform_destroy(instance)
return Response(status=status.HTTP_204_NO_CONTENT)

@extend_schema(
request=inline_serializer(name='ContinueConversationRequest', fields={
'message': drf_serializers.CharField(help_text='User message to continue the conversation'),
'page_context': drf_serializers.CharField(required=False, help_text='Optional page context'),
}),
responses={
200: inline_serializer(name='ContinueConversationResponse', fields={
'response': drf_serializers.CharField(),
'title': drf_serializers.CharField(),
}),
400: inline_serializer(name='ContinueConversationBadRequest', fields={
'error': drf_serializers.CharField(),
}),
}
)
@action(detail=True, methods=['post'])
def continue_conversation(self, request, pk=None):
conversation = self.get_object()
Expand Down Expand Up @@ -123,6 +140,20 @@ def continue_conversation(self, request, pk=None):

return Response({"response": chatgpt_response, "title": conversation.title})

@extend_schema(
request=inline_serializer(name='UpdateTitleRequest', fields={
'title': drf_serializers.CharField(help_text='New conversation title'),
}),
responses={
200: inline_serializer(name='UpdateTitleResponse', fields={
'status': drf_serializers.CharField(),
'title': drf_serializers.CharField(),
}),
400: inline_serializer(name='UpdateTitleBadRequest', fields={
'error': drf_serializers.CharField(),
}),
}
)
@action(detail=True, methods=['patch'])
def update_title(self, request, pk=None):
conversation = self.get_object()
Expand Down
23 changes: 22 additions & 1 deletion server/api/views/embeddings/embeddingsView.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from rest_framework.views import APIView
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
from rest_framework import status
from rest_framework import status, serializers as drf_serializers
from django.http import StreamingHttpResponse
from drf_spectacular.utils import extend_schema, inline_serializer, OpenApiParameter
from ...services.embedding_services import get_closest_embeddings
from ...services.conversions_services import convert_uuids
from ...services.openai_services import openAIServices
Expand All @@ -15,6 +16,26 @@
class AskEmbeddingsAPIView(APIView):
permission_classes = [IsAuthenticated]

@extend_schema(
parameters=[
OpenApiParameter(name='guid', type=str, location=OpenApiParameter.QUERY, required=False, description='Optional file GUID to filter embeddings'),
OpenApiParameter(name='stream', type=bool, location=OpenApiParameter.QUERY, required=False, description='Enable streaming response'),
],
request=inline_serializer(name='AskEmbeddingsRequest', fields={
'message': drf_serializers.CharField(help_text='Question to ask against embedded documents'),
}),
responses={
200: inline_serializer(name='AskEmbeddingsResponse', fields={
'question': drf_serializers.CharField(),
'llm_response': drf_serializers.CharField(),
'embeddings_info': drf_serializers.CharField(),
'sent_to_llm': drf_serializers.CharField(),
}),
400: inline_serializer(name='AskEmbeddingsBadRequest', fields={
'error': drf_serializers.CharField(),
}),
}
)
def post(self, request, *args, **kwargs):
try:
user = request.user
Expand Down
1 change: 1 addition & 0 deletions server/api/views/feedback/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

class FeedbackView(APIView):
permission_classes = [AllowAny]
serializer_class = FeedbackSerializer

def post(self, request, *args, **kwargs):
serializer = FeedbackSerializer(data=request.data)
Expand Down
48 changes: 47 additions & 1 deletion server/api/views/listMeds/views.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from rest_framework import status
from rest_framework import status, serializers as drf_serializers
from rest_framework.permissions import AllowAny
from rest_framework.response import Response
from rest_framework.views import APIView
from drf_spectacular.utils import extend_schema, inline_serializer

from .models import Diagnosis, Medication, Suggestion
from .serializers import MedicationSerializer
Expand All @@ -24,6 +25,33 @@
class GetMedication(APIView):
permission_classes = [AllowAny]

@extend_schema(
request=inline_serializer(
name='GetMedicationRequest',
fields={
'state': drf_serializers.CharField(help_text='Diagnosis state, e.g. "depressed", "manic"'),
'suicideHistory': drf_serializers.BooleanField(default=False),
'kidneyHistory': drf_serializers.BooleanField(default=False),
'liverHistory': drf_serializers.BooleanField(default=False),
'bloodPressureHistory': drf_serializers.BooleanField(default=False),
'weightGainConcern': drf_serializers.BooleanField(default=False),
'priorMedications': drf_serializers.CharField(required=False, default='', help_text='Comma-separated medication names'),
}
),
responses={
200: inline_serializer(
name='GetMedicationResponse',
fields={
'first': drf_serializers.ListField(child=drf_serializers.DictField()),
'second': drf_serializers.ListField(child=drf_serializers.DictField()),
'third': drf_serializers.ListField(child=drf_serializers.DictField()),
}
),
404: inline_serializer(name='GetMedicationNotFound', fields={
'error': drf_serializers.CharField(),
}),
}
)
def post(self, request):
data = request.data
state_query = data.get('state', '')
Expand Down Expand Up @@ -75,6 +103,7 @@ def post(self, request):

class ListOrDetailMedication(APIView):
permission_classes = [AllowAny]
serializer_class = MedicationSerializer

def get(self, request):
name_query = request.query_params.get('name', None)
Expand All @@ -98,6 +127,7 @@ class AddMedication(APIView):
"""
API endpoint to add a medication to the database with its risks and benefits.
"""
serializer_class = MedicationSerializer

def post(self, request):
data = request.data
Expand Down Expand Up @@ -129,6 +159,22 @@ class DeleteMedication(APIView):
API endpoint to delete medication if medication in database.
"""

@extend_schema(
request=inline_serializer(name='DeleteMedicationRequest', fields={
'name': drf_serializers.CharField(),
}),
responses={
200: inline_serializer(name='DeleteMedicationSuccess', fields={
'success': drf_serializers.CharField(),
}),
400: inline_serializer(name='DeleteMedicationBadRequest', fields={
'error': drf_serializers.CharField(),
}),
404: inline_serializer(name='DeleteMedicationNotFound', fields={
'error': drf_serializers.CharField(),
}),
}
)
def delete(self, request):
data = request.data
name = data.get('name', '').strip()
Expand Down
2 changes: 2 additions & 0 deletions server/api/views/medRules/serializers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from rest_framework import serializers
from drf_spectacular.utils import extend_schema_field
from ...models.model_medRule import MedRule, MedRuleSource
from ..listMeds.serializers import MedicationSerializer
from ...models.model_embeddings import Embeddings
Expand Down Expand Up @@ -30,6 +31,7 @@ class Meta:
"medication_sources",
]

@extend_schema_field(MedicationWithSourcesSerializer(many=True))
def get_medication_sources(self, obj):
medrule_sources = MedRuleSource.objects.filter(medrule=obj).select_related(
"medication", "embedding"
Expand Down
25 changes: 24 additions & 1 deletion server/api/views/medRules/views.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from rest_framework.views import APIView
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
from rest_framework import status
from rest_framework import status, serializers as drf_serializers
from django.utils.decorators import method_decorator
from django.views.decorators.csrf import csrf_exempt
from drf_spectacular.utils import extend_schema, inline_serializer
from ...models.model_medRule import MedRule
from .serializers import MedRuleSerializer # You'll need to create this
from ..listMeds.models import Medication
Expand All @@ -13,6 +14,7 @@
@method_decorator(csrf_exempt, name='dispatch')
class MedRules(APIView):
permission_classes = [IsAuthenticated]
serializer_class = MedRuleSerializer

def get(self, request, format=None):
# Get all med rules
Expand All @@ -29,6 +31,27 @@ def get(self, request, format=None):

return Response(data, status=status.HTTP_200_OK)

@extend_schema(
request=inline_serializer(name='MedRuleCreateRequest', fields={
'rule_type': drf_serializers.CharField(help_text='INCLUDE or EXCLUDE'),
'history_type': drf_serializers.CharField(help_text='e.g. DIAGNOSIS_DEPRESSED, DIAGNOSIS_MANIC'),
'reason': drf_serializers.CharField(),
'label': drf_serializers.CharField(),
'explanation': drf_serializers.CharField(),
'medication_names': drf_serializers.ListField(child=drf_serializers.CharField()),
'chunk_ids': drf_serializers.ListField(child=drf_serializers.IntegerField()),
'file_guid': drf_serializers.CharField(),
}),
responses={
201: MedRuleSerializer,
400: inline_serializer(name='MedRuleCreateBadRequest', fields={
'error': drf_serializers.CharField(),
}),
404: inline_serializer(name='MedRuleCreateNotFound', fields={
'error': drf_serializers.CharField(),
}),
}
)
def post(self, request):

data = request.data
Expand Down
25 changes: 24 additions & 1 deletion server/api/views/risk/views_riskWithSources.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from rest_framework.views import APIView
from rest_framework.response import Response
from rest_framework import status
from rest_framework import status, serializers as drf_serializers
from rest_framework.permissions import AllowAny
from drf_spectacular.utils import extend_schema, inline_serializer
from api.views.listMeds.models import Medication
from api.models.model_medRule import MedRule, MedRuleSource
import openai
Expand All @@ -11,6 +12,28 @@
class RiskWithSourcesView(APIView):
permission_classes = [AllowAny]

@extend_schema(
request=inline_serializer(name='RiskWithSourcesRequest', fields={
'drug': drf_serializers.CharField(help_text='Medication name'),
'source': drf_serializers.CharField(required=False, help_text='One of: include, diagnosis, diagnosis_depressed, diagnosis_manic, diagnosis_hypomanic, diagnosis_euthymic'),
}),
responses={
200: inline_serializer(name='RiskWithSourcesResponse', fields={
'benefits': drf_serializers.ListField(child=drf_serializers.CharField()),
'risks': drf_serializers.ListField(child=drf_serializers.CharField()),
'sources': drf_serializers.ListField(child=drf_serializers.DictField()),
'medrules_found': drf_serializers.IntegerField(required=False),
'source_type': drf_serializers.CharField(required=False),
'note': drf_serializers.CharField(required=False),
}),
400: inline_serializer(name='RiskWithSourcesBadRequest', fields={
'error': drf_serializers.CharField(),
}),
404: inline_serializer(name='RiskWithSourcesNotFound', fields={
'error': drf_serializers.CharField(),
}),
}
)
def post(self, request):
openai.api_key = os.environ.get("OPENAI_API_KEY")

Expand Down
29 changes: 29 additions & 0 deletions server/api/views/text_extraction/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from django.utils.decorators import method_decorator
from django.views.decorators.csrf import csrf_exempt
import anthropic
from drf_spectacular.utils import extend_schema, inline_serializer, OpenApiParameter
from rest_framework import serializers as drf_serializers

from ...services.openai_services import openAIServices
from api.models.model_embeddings import Embeddings
Expand Down Expand Up @@ -97,6 +99,20 @@ class RuleExtractionAPIView(APIView):

permission_classes = [IsAuthenticated]

@extend_schema(
parameters=[
OpenApiParameter(name='guid', type=str, location=OpenApiParameter.QUERY, required=True, description='File GUID to extract rules from'),
],
responses={
200: inline_serializer(name='RuleExtractionResponse', fields={
'texts': drf_serializers.CharField(),
'cited_texts': drf_serializers.CharField(),
}),
500: inline_serializer(name='RuleExtractionError', fields={
'error': drf_serializers.CharField(),
}),
}
)
def get(self, request):
try:

Expand Down Expand Up @@ -141,6 +157,19 @@ def openai_extraction(content_chunks, user_prompt):
class RuleExtractionAPIOpenAIView(APIView):
permission_classes = [IsAuthenticated]

@extend_schema(
parameters=[
OpenApiParameter(name='guid', type=str, location=OpenApiParameter.QUERY, required=True, description='File GUID to extract rules from'),
],
responses={
200: inline_serializer(name='RuleExtractionOpenAIResponse', fields={
'rules': drf_serializers.ListField(child=drf_serializers.DictField()),
}),
500: inline_serializer(name='RuleExtractionOpenAIError', fields={
'error': drf_serializers.CharField(),
}),
}
)
def get(self, request):
try:
user_prompt = """
Expand Down
Loading
Loading