Last active
January 30, 2023 14:24
-
-
Save OzTamir/a45aafc9a0a53d3a085b to your computer and use it in GitHub Desktop.
A simple MySQL wrapper (and UI utils!) in Python
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from __future__ import print_function | |
| import mysql.connector | |
| from mysql.connector import errorcode | |
| from mysql.connector.errors import * | |
| import sys | |
class Database(object):
    ''' Thin wrapper around a mysql.connector connection and one shared cursor '''

    def __init__(self, dbhost, dbuser, dbpass, dbname, debug=False):
        ''' Initialize a Database object and connect to the database

        dbhost, dbuser, dbpass, dbname -- passed to mysql.connector.connect
        debug -- when True, progress and error messages are printed

        If connecting raises mysql.connector.Error, the error is printed and
        the process exits with status 1.
        '''
        self.host = dbhost
        self.name = dbname
        self.debug = debug
        self.conn = None
        # Try to connect to the database, and if there were any errors report and quit
        try:
            self.conn = self.get_connector(dbuser, dbpass)
        except mysql.connector.Error as err:
            print(str(err))
            print('Error connecting to the host.')
            sys.exit(1)
        if self.debug:
            print('Connected to host.')
        # A single cursor is shared by every query method below
        self.cursor = self.conn.cursor()

    def get_connector(self, username, dbpass):
        ''' Return a live connection to the database, connecting if needed

        An existing connection held by this object is reused: it is returned
        as-is while still connected, otherwise it is reconnected once in
        place. With no existing connection, a new one is opened with the
        given credentials. Failures surface as mysql.connector.Error.
        '''
        if self.debug:
            print('Connecting to database...')
        if self.conn is not None:
            # is_connected() works for both the pure-Python MySQLConnection and
            # the C-extension CMySQLConnection; the former isinstance() check
            # only recognised MySQLConnection
            if self.conn.is_connected():
                if self.debug:
                    print('Already connected!')
                return self.conn
            # The connection was dropped: try to re-establish it once
            self.conn.reconnect(attempts=1, delay=0)
            return self.conn
        return mysql.connector.connect(user=username, password=dbpass,
                                       host=self.host, database=self.name)

    def __get_results(self, query, err_msg='get_results', data=None):
        ''' Execute query (binding data if given) and return all rows

        Returns an empty list if the driver raises InternalError.
        '''
        try:
            if data:
                self.cursor.execute(query, data)
            else:
                self.cursor.execute(query)
        except InternalError:
            if self.debug:
                # Report the query; the old message used an undefined `table`
                print('Error in %s: No results for query %s' % (str(err_msg), str(query)))
            return []
        return self.cursor.fetchall()

    def __iter_results(self, query, err_msg='__iter_results', data=None):
        ''' Return a generator that yields rows of query (binding data if given)

        Yields nothing if the driver raises InternalError.
        '''
        try:
            if data:
                self.cursor.execute(query, data)
            else:
                self.cursor.execute(query)
        except InternalError:
            if self.debug:
                print('Error in %s: No results for query %s' % (str(err_msg), str(query)))
            # Plain return ends a generator; raising StopIteration here is
            # turned into RuntimeError by PEP 479 (Python 3.7+)
            return
        for entry in self.cursor:
            yield entry

    def get_entries(self, table):
        ''' Return all the entries from a given table in the database

        table is interpolated into the SQL as-is (identifiers cannot be
        bound as parameters), so it must not come from untrusted input.
        '''
        return self.__get_results('SELECT * FROM %s' % str(table), 'get_entries')

    def iter_entries(self, table):
        ''' Return a generator that yields entries from a given table

        table is interpolated into the SQL as-is; do not pass untrusted names.
        '''
        return self.__iter_results('SELECT * FROM %s' % str(table), 'iter_entries')

    def get_column_names(self, table):
        ''' Wrapper for get_columns, only return the names and not types '''
        return [col[0] for col in self.get_columns(table)]

    def get_columns(self, table):
        ''' Return a list of (name, type_code) tuples for the columns of table '''
        # LIMIT 0 still populates cursor.description without transferring rows
        self.cursor.execute('SELECT * FROM %s LIMIT 0' % str(table))
        columns = [(desc[0], desc[1]) for desc in self.cursor.description]
        # Consume the (empty) result set to avoid 'Unread result'
        self.clear_cursor()
        return columns

    def close_connection(self):
        ''' Close the connection; does nothing if it is already closed '''
        if self.conn is None:
            return
        self.conn.close()
        # Drop the handle so a later call (e.g. from __del__) is a no-op
        self.conn = None
        if self.debug:
            print('Connection closed.')

    def clear_cursor(self):
        ''' Clear the cursor if we don't need results (used in get_columns) '''
        if self.debug:
            print('Clearing cursor...')
        self.cursor.fetchall()

    def commit(self):
        ''' Commit changes to the remote DB '''
        if self.debug:
            print('Commiting changes...')
        self.conn.commit()

    def rollback(self):
        ''' Rollback changes in case of errors of any kind '''
        if self.debug:
            print('Rolling back...')
        self.conn.rollback()

    @staticmethod
    def _build_insert_query(table, columns, n_values):
        ''' Return a parameterized INSERT statement with n_values placeholders '''
        # Joining explicitly avoids str(tuple) artefacts such as "(a,)"
        column_list = '(%s)' % ', '.join(str(col) for col in columns)
        placeholders = ', '.join(['%s'] * n_values)
        return 'INSERT INTO %s %s VALUES (%s)' % (table, column_list, placeholders)

    def insert(self, table, columns, values):
        ''' Insert a new entry into a table with given values and commit

        The values are bound as query parameters. On any failure the
        transaction is rolled back and ValueError is raised.
        '''
        query_stmt = self._build_insert_query(table, columns, len(values))
        try:
            self.cursor.execute(query_stmt, values)
            # Commit the changes to the remote DB
            self.commit()
        except Exception as e:
            if self.debug:
                print('Error: %s' % str(e))
            # Rollback the changes from the current transaction
            self.rollback()
            raise ValueError("Can't add entry, please try again (maybe with different values?)")

    @staticmethod
    def _build_search_query(table, column, value, partial=False, case_sensetive=True):
        ''' Return (query, params) for search(); value is bound, never interpolated '''
        if partial:
            sql_function = 'LIKE'
            value = '%%%s%%' % str(value)
        else:
            sql_function = '='
        # Identifiers cannot be bound as parameters: backtick-quote the column
        # and double any embedded backtick (MySQL's identifier escaping)
        quoted_column = '`%s`' % str(column).replace('`', '``')
        if case_sensetive:
            condition = '%s %s %%s' % (quoted_column, sql_function)
        else:
            condition = 'LOWER(%s) %s LOWER(%%s)' % (quoted_column, sql_function)
        query = 'SELECT * FROM %s WHERE %s' % (str(table), condition)
        return query, (str(value),)

    def search(self, table, column, value, partial=False, case_sensetive=True):
        ''' Return a generator over rows of table whose column matches value

        partial -- match value as a substring (SQL LIKE '%value%')
        case_sensetive -- when False, compare LOWER() of both sides
        '''
        query, params = self._build_search_query(table, column, value,
                                                 partial, case_sensetive)
        return self.__iter_results(query, 'search', params)

    def __del__(self):
        ''' Called upon object deletion, make sure the connection to the DB is closed '''
        # getattr: __init__ may not have run far enough to set conn
        if getattr(self, 'conn', None) is not None:
            self.close_connection()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import datetime | |
class DatabaseUIBase(object):
    ''' Base class for ex5 UIs

    Subclasses must define mainUI(); it is called at the end of __init__.
    db must provide get_columns, get_column_names, insert and iter_entries.
    '''
    # Column type code treated as a date (mysql.connector FieldType.DATE);
    # values for such columns are read in DD.MM.YYYY form
    _DATE_TYPE = 10

    def __init__(self, db, msg):
        ''' Store db, print the banner msg and call the mainUI '''
        self.db = db
        print(msg)
        print('-' * 10)
        self.mainUI()

    @staticmethod
    def _parse_date(text):
        ''' Parse a DD.MM.YYYY string into a datetime.datetime

        Raises ValueError (or TypeError) when text does not follow the format.
        '''
        day, mnt, year = text.split('.')
        return datetime.datetime(int(year), int(mnt), int(day))

    def __add_entry(self, table):
        ''' Prompt for a value for every column of table and insert the entry '''
        # raw_input only exists on Python 2; input() is the Python 3 equivalent
        try:
            read_line = raw_input
        except NameError:
            read_line = input
        columns = self.db.get_columns(table)
        values = [''] * len(columns)
        print('Please enter the following details:')
        for index, column in enumerate(columns):
            is_date = column[1] == self._DATE_TYPE
            if is_date:
                print('Please enter date in the following format: DD.MM.YYYY')
            values[index] = read_line('%s: ' % str(column[0]))
            if is_date:
                try:
                    values[index] = self._parse_date(values[index])
                except (TypeError, ValueError):
                    # Bad format: report it and fall back to a default date
                    print("Error in date format, inserting '1'.'1'.'1' as date")
                    values[index] = datetime.datetime(1, 1, 1)
        try:
            self.db.insert(table, [col[0] for col in columns], values)
        except ValueError as e:
            print('Error: %s' % str(e))
            return
        print('Inserted the follwoing entry into %s:' % str(table))
        print(' | '.join([str(val) for val in values]))

    def __print_columns(self, table):
        ''' Print a '#' header followed by the column names of table '''
        columns = ['#'] + self.db.get_column_names(table)
        print(' | '.join(columns))

    def __print_results(self, table, enumerator):
        ''' Print (index, entry) pairs as numbered rows, then the row count '''
        self.__print_columns(table)
        cnt = 0
        for index, entry in enumerator:
            cells = [str(val) for val in entry]
            print(('%d | ' % (index + 1)) + ' | '.join(cells))
            cnt += 1
        # Display the amount of entries found
        print('%d results found.' % cnt)

    def __print_table(self, table):
        ''' Print every entry of table '''
        table_enum = enumerate(self.db.iter_entries(table))
        self.__print_results(table, table_enum)
Author
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Also, I might need to change the
`Database.get_connector()` function and make it check that the connection is active (using `MySQLConnection.is_connected()`) and, if not, attempt to reconnect (using `MySQLConnection.reconnect(attempts=1, delay=0)`).