Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 40 additions & 42 deletions igf_data/igfdb/dbconnect.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,22 @@ class DBConnect:
A class for managing dbconnection
'''
def __init__(self, **data):
data.setdefault('dbhost', '')
data.setdefault('dbport', '')
data.setdefault('dbuser', '')
data.setdefault('dbpass', '')
data.setdefault('dbname', '')
data.setdefault('url', '')
data.setdefault('driver', 'sqlite')
data.setdefault('connector', '')
data.setdefault('supported_drivers', ('mysql', 'sqlite'))
data.setdefault('engine_config', {})

self.dbhost = data['dbhost']
self.dbport = data['dbport']
self.dbuser = data['dbuser']
self.dbpass = data['dbpass']
self.dbname = data['dbname']
self.driver = data['driver']
self.connector = data['connector']
self.supported_drivers = data['supported_drivers']
self.engine_config = data['engine_config']
self.dbhost = data.get('dbhost', '')
self.dbport = data.get('dbport', '')
self.dbuser = data.get('dbuser', '')
self.dbpass = data.get('dbpass', '')
self.dbname = data.get('dbname', '')
self.driver = data.get('driver', 'sqlite')
self.connector = data.get('connector', '')
self.supported_drivers = data.get('supported_drivers', ('mysql', 'sqlite'))
self.engine_config = data.get('engine_config', {})
self.connect_args = data.get('connect_args', {})
# create engine and configure session at start up
if data['url'] == '':
data_url = data.get('url', '')
if data_url == '':
self.dburl = self._prepare_db_url() # get dburl for connection
else:
self.dburl = data['url']
self.dburl = data_url
self.engine = self._create_session_engine() # get engine connection
self.session_class = self._configure_session() # get session class

Expand All @@ -48,20 +39,22 @@ def _prepare_db_url(self):
connector = self.connector

if driver not in self.supported_drivers:
raise ValueError('Database driver {0} is not supported yet.'.format(driver)) # check for supported databases
raise ValueError(
f'Database driver {driver} is not supported yet.') # check for supported databases

if driver != 'sqlite' and (not dbuser or not dbhost):
raise ValueError('driver {0} require dbuser and dbhost details, {1},{2}'.format(driver,dbuser,dbhost)) # check for required parameters
raise ValueError(
f'Driver {driver} require dbuser and dbhost details, {dbuser},{dbhost}') # check for required parameters

dburl='{0}'.format(driver)
dburl = driver
if connector:
dburl='{0}+{1}'.format(dburl, connector)
dburl='{0}://'.format(dburl)
dburl = f'{dburl}+{connector}'
dburl = f'{dburl}://'
if dbuser and dbpass and dbhost:
dburl='{0}{1}:{2}@{3}'.format(dburl, dbuser, dbpass, dbhost)
dburl = f'{dburl}{dbuser}:{dbpass}@{dbhost}'
if dbport:
dburl='{0}:{1}'.format(dburl, dbport)
dburl='{0}/{1}'.format(dburl, dbname)
dburl = f'{dburl}:{dbport}'
dburl = f'{dburl}/{dbname}'
return dburl


Expand All @@ -70,12 +63,16 @@ def _create_session_engine(self):
An internal method for creating an database engine required for the session
'''
if not hasattr(self, 'dburl'):
raise AttributeError('Attribute dburl not defined')

raise AttributeError(
'Attribute dburl not defined')
try:
engine = create_engine(self.dburl, **self.engine_config ) # create engine with additional parameter
engine = \
create_engine(
self.dburl,
connect_args=self.connect_args,
**self.engine_config ) # create engine with additional parameter
return engine
except:
except Exception:
raise


Expand All @@ -87,9 +84,10 @@ def _configure_session(self):
raise AttributeError('Attribute engine not defined')

try:
session_class=sessionmaker(bind=self.engine) # create session class
session_class = \
sessionmaker(bind=self.engine) # create session class
return session_class
except:
except Exception:
raise


Expand All @@ -100,7 +98,7 @@ def get_session_class(self):
if not hasattr(self, 'session_class'):
raise AttributeError('Attribute session_class not defined')

session_class=self.session_class
session_class = self.session_class
return session_class


Expand All @@ -117,7 +115,7 @@ def start_session(self):
Session=self.session_class
try:
self.session=Session() # create session
except:
except Exception:
raise


Expand All @@ -130,7 +128,7 @@ def commit_session(self):

try:
self.session.commit() # commit session
except:
except Exception:
raise

def rollback_session(self):
Expand All @@ -139,7 +137,7 @@ def rollback_session(self):
'''
try:
self.session.rollback()
except:
except Exception:
raise

def close_session(self, save_changes=False):
Expand All @@ -151,14 +149,14 @@ def close_session(self, save_changes=False):
if not hasattr(self, 'session'):
raise AttributeError('Attribute session not defined')

session=self.session
session = self.session
try:
if save_changes:
session.commit() # commit session
session.close() # close session
self.session=None # set the session attribute to None
del self.session # delete session attribute
except:
except Exception:
raise
if hasattr(self, 'session'):
raise AttributeError('Attribute session not deleted yet')
Expand Down