@@ -62,7 +62,11 @@ async def insert(self, data: Union[dict, list[dict]], return_keys: List[str] = N
6262 table = table if table is not None else self .table
6363 return_keys = return_keys or []
6464 try :
65- insert_stmt = insert (table ).values (data ).returning (* (getattr (table .c , key ) for key in return_keys ))
65+ insert_stmt = (
66+ insert (table )
67+ .values (data )
68+ .returning (* (getattr (table .c if isinstance (table , Table ) else table , key ) for key in return_keys ))
69+ )
6670 return_values = await self .session .execute (insert_stmt )
6771 await self .session .commit ()
6872 if return_keys :
@@ -93,7 +97,7 @@ async def update_with_where(
9397 update (table )
9498 .values (data )
9599 .where (* where_conditions )
96- .returning (* (getattr (table .c , key ) for key in return_keys ))
100+ .returning (* (getattr (table .c if isinstance ( table , Table ) else table , key ) for key in return_keys ))
97101 )
98102 return_values = await self .session .execute (update_stmt )
99103 await self .session .commit ()
@@ -115,7 +119,10 @@ async def update(self, data: Union[dict, list[dict]], return_keys: List[str] = N
115119 return_keys = return_keys or []
116120 try :
117121 return_values = await self .session .execute (
118- update (table ).returning (* (getattr (table .c , key ) for key in return_keys )), data
122+ update (table ).returning (
123+ * (getattr (table .c if isinstance (table , Table ) else table , key ) for key in return_keys )
124+ ),
125+ data ,
119126 )
120127 await self .session .commit ()
121128 if return_keys :
@@ -124,24 +131,34 @@ async def update(self, data: Union[dict, list[dict]], return_keys: List[str] = N
124131 logger .error (f"Error occurred while updating: { e } " , exc_info = True )
125132 raise e
126133
127- async def upsert (self , insert_json : dict , primary_keys : List [str ] = None , table : TableType = None ):
134+ async def upsert (
135+ self , insert_json : dict , primary_keys : List [str ] = None , return_keys : List [str ] = None , table : TableType = None
136+ ):
128137 """
129138 Inserts or updates a row in the database.
130139
131140 Args:
132141 insert_json (dict): A dictionary containing the data to be inserted or updated.
133142 primary_keys (List[str], optional): A list of primary key column names. Defaults to None.
143+ return_keys (List[str], optional): A list of column names to return after the upsert. Defaults to None.
134144 table (TableType, optional): The SQLAlchemy declarative base object. Defaults to None.
145+
146+ Returns:
147+ A list of dictionaries containing the upserted data if return_keys is provided.
135148 """
136149 table = table if table is not None else self .table
150+ return_keys = return_keys or []
137151 try :
138152 insert_statement = (
139153 postgres_insert (table )
140154 .values (** insert_json )
141155 .on_conflict_do_update (index_elements = primary_keys , set_ = insert_json )
156+ .returning (* (getattr (table .c if isinstance (table , Table ) else table , key ) for key in return_keys ))
142157 )
143- await self .session .execute (insert_statement )
158+ return_values = await self .session .execute (insert_statement )
144159 await self .session .commit ()
160+ if return_keys :
161+ return jsonable_encoder (return_values .mappings ().all ())
145162 except Exception as e :
146163 logger .error (f"Error while upserting the record { e } " , exc_info = True )
147164 raise e
0 commit comments