@@ -29,12 +29,12 @@ def name(self) -> str:
2929 """
3030 return "codegate-context-retriever"
3131
32- async def get_objects_from_search (
33- self , search : str , ecosystem , packages : list [str ] = None
32+ async def get_objects_from_db (
33+ self , ecosystem , packages : list [str ] = None
3434 ) -> list [object ]:
3535 storage_engine = StorageEngine ()
3636 objects = await storage_engine .search (
37- search , distance = 0.8 , ecosystem = ecosystem , packages = packages
37+ distance = 0.8 , ecosystem = ecosystem , packages = packages
3838 )
3939 return objects
4040
@@ -103,39 +103,25 @@ async def process(
103103 # Extract packages from the user message
104104 ecosystem = await self .__lookup_ecosystem (user_messages , context )
105105 packages = await self .__lookup_packages (user_messages , context )
106- packages = [pkg .lower () for pkg in packages ]
107106
108- # If user message does not reference any packages, then just return
109- if len (packages ) == 0 :
110- return PipelineResult (request = request )
111-
112- # Look for matches in vector DB using list of packages as filter
113- searched_objects = await self .get_objects_from_search (user_messages , ecosystem , packages )
107+ context_str = "CodeGate did not find any malicious or archived packages."
114108
115- logger .info (
116- f"Found { len (searched_objects )} matches in the database" ,
117- searched_objects = searched_objects ,
118- )
109+ if len (packages ) > 0 :
110+ # Look for matches in DB using packages and ecosystem
111+ searched_objects = await self .get_objects_from_db (ecosystem , packages )
119112
120- # Remove searched objects that are not in packages. This is needed
121- # since Weaviate performs substring match in the filter.
122- updated_searched_objects = []
123- for searched_object in searched_objects :
124- if searched_object .properties ["name" ].lower () in packages :
125- updated_searched_objects .append (searched_object )
126- searched_objects = updated_searched_objects
113+ logger .info (
114+ f"Found { len (searched_objects )} matches in the database" ,
115+ searched_objects = searched_objects ,
116+ )
127117
128- # Generate context string using the searched objects
129- logger .info (f"Adding { len (searched_objects )} packages to the context" )
118+ # Generate context string using the searched objects
119+ logger .info (f"Adding { len (searched_objects )} packages to the context" )
130120
131- if len (searched_objects ) > 0 :
132- context_str = self .generate_context_str (searched_objects , context )
133- else :
134- context_str = "CodeGate did not find any malicious or archived packages."
121+ if len (searched_objects ) > 0 :
122+ context_str = self .generate_context_str (searched_objects , context )
135123
136124 last_user_idx = self .get_last_user_message_idx (request )
137- if last_user_idx == - 1 :
138- return PipelineResult (request = request , context = context )
139125
140126 # Make a copy of the request
141127 new_request = request .copy ()
0 commit comments