# By: Riasat Ullah
# This class manages the database connections.

from psycopg2 import pool, extras
from taskcallrest import settings
from utils import constants, s3
import configuration as configs
import psycopg2
import redis


# Keys used throughout this module when reading credential records.
str_bucket = 'bucket'
str_db_name = 'db_name'
str_host = 'host'
str_key = 'key'
str_password = 'password'
str_port = 'port'
str_user = 'user'

# Every allowed database/cache currently stores its credentials in the same
# S3 object; map each allowed name to its own copy of that location record.
_creds_location = {str_bucket: 'taskcall-prod-data', str_key: 'credentials/db_creds.json'}
allowed_dbs = {db: dict(_creds_location) for db in (constants.prod_db_europe,
                                                    constants.redis_cache_europe,
                                                    constants.test_db_europe,
                                                    constants.prod_db_us,
                                                    constants.redis_cache_us,
                                                    constants.test_db_us)}

# Teach psycopg2 to adapt Python uuid.UUID values passed as query parameters.
psycopg2.extras.register_uuid()


class DBConn(object):
    '''
    Wraps a single psycopg2 connection and exposes helpers for executing
    statements and fetching rows, reconnecting automatically if the
    underlying connection has been closed.
    '''

    def __init__(self, db_name=None, conn=None):
        # Default to the European production database when no name is given.
        self.db_name = constants.prod_db_europe if db_name is None else db_name
        # Reuse a caller-supplied connection (e.g. one taken from a pool);
        # otherwise open a fresh connection for self.db_name.
        self.conn = self.connect(self.db_name) if conn is None else conn

    @staticmethod
    def connect(db_name):
        '''
        Creates a database connection using credentials stored in S3.
        :param db_name: name of the database (must be a key of allowed_dbs)
        :return: a psycopg2 connection
        :raises KeyError: if the database name or a credential field is unknown
        :raises OSError: if the credentials file could not be read
        '''
        requested_db = db_name  # preserved so error messages report the name the caller asked for
        try:
            db_location = allowed_dbs[db_name]
            data = s3.read_json(db_location[str_bucket], db_location[str_key])

            db_details = data[db_name]
            # Pass credentials as keyword arguments rather than interpolating
            # them into a DSN string: a password/user containing quotes,
            # spaces or backslashes would otherwise corrupt the DSN.
            conn = psycopg2.connect(
                dbname=db_details[str_db_name],
                user=db_details[str_user],
                host=db_details[str_host],
                password=db_details[str_password],
                port=db_details[str_port],
            )
            return conn
        except KeyError as e:
            err = 'Unknown database - ' + requested_db + '\n' + str(e)
            raise KeyError(err)
        except (OSError, IOError) as e:
            err = 'Could not read db credentials file' + '\n' + str(e)
            raise OSError(err)
        except psycopg2.DatabaseError:
            # Database-level failures propagate unchanged to the caller.
            raise

    def check_connection_alive(self):
        '''
        Checks if the connection is alive. If it is not, then reconnects to the db.
        '''
        # conn.closed is non-zero once the connection has been closed/broken.
        if self.conn.closed != 0:
            # NOTE(review): reconnects to the environment's database rather
            # than self.db_name -- presumably because pooled DBConn objects
            # are built with the default db_name; confirm before changing.
            self.conn = self.connect(get_environment_db_name())

    def execute(self, query, params=None):
        '''
        Executes queries that make changes to the database - insert, update, etc
        :param query: the query to execute
        :param params: parameters to pass in as literals to avoid sql injection
        '''
        self.check_connection_alive()
        cur = self.conn.cursor()
        cur.execute(query, params)
        self.conn.commit()
        cur.close()

    def fetch(self, query, params=None):
        '''
        For fetching rows from the database only --> only select queries
        :param query: query to execute
        :param params: parameters to pass in as literals to avoid sql injection
        :return: fetched rows
        '''
        self.check_connection_alive()
        cur = self.conn.cursor()
        cur.execute(query, params)
        # Commit ends the implicit transaction so the connection does not
        # hold an idle-in-transaction state between reads.
        self.conn.commit()
        results = cur.fetchall()
        cur.close()
        return results

    def disconnect(self):
        '''
        Closes the database connection
        '''
        self.conn.close()

    def rollback(self):
        '''
        Rollback the changes made in an uncommitted transaction.
        '''
        self.conn.rollback()

    def execute_batch(self, query, params_list):
        '''
        Execute the same statement with a batch of different values.
        :param query: (str) query to execute
        :param params_list: (list) of tuples of values
        '''
        self.check_connection_alive()
        cur = self.conn.cursor()
        extras.execute_batch(cur, query, params_list)
        self.conn.commit()
        cur.close()

    def get_cursor(self):
        '''
        :return: a new cursor on this connection (caller must close it)
        '''
        return self.conn.cursor()

    @staticmethod
    def close_cursor(cur):
        '''
        Closes the given cursor.
        :param cur: cursor to close
        '''
        cur.close()

    @staticmethod
    def execute_with_cursor(cur, query, params=None):
        '''
        Executes a query on an externally-managed cursor without committing.
        :param cur: cursor to execute on
        :param query: query to execute
        :param params: parameters to pass in as literals to avoid sql injection
        '''
        cur.execute(query, params)

    def commit(self):
        '''
        Commits the current transaction on this connection.
        '''
        self.conn.commit()


class TestTasksDBConn(DBConn):
    '''
    DBConn variant for testing against a local Postgres instance.
    execute/fetch deliberately do NOT commit so a test transaction can be
    rolled back; the *_and_commit variants delegate to DBConn's committing
    implementations.
    '''
    def __init__(self, conn=None):
        # Hard-coded credentials for the local test server (port 5433).
        # Keyword arguments avoid DSN-quoting issues with special characters.
        conn = psycopg2.connect(dbname='postgres', user='postgres', host='localhost',
                                password='', port='5433') if conn is None else conn
        DBConn.__init__(self, conn=conn)

    def execute(self, query, params=None):
        '''
        Uncommitted execute.
        :param query: the query to execute
        :param params: parameters to pass in as literals to avoid sql injection
        '''
        cur = self.conn.cursor()
        cur.execute(query, params)
        cur.close()

    def fetch(self, query, params=None):
        '''
        Uncommitted fetch.
        :param query: query to execute
        :param params: parameters to pass in as literals to avoid sql injection
        :return: fetched rows
        '''
        cur = self.conn.cursor()
        cur.execute(query, params)
        results = cur.fetchall()
        cur.close()
        return results

    def execute_batch(self, query, params_list):
        '''
        Execute the same statement with a batch of different values (uncommitted).
        :param query: (str) query to execute
        :param params_list: (list) of tuples of values
        '''
        cur = self.conn.cursor()
        extras.execute_batch(cur, query, params_list)
        cur.close()

    def execute_and_commit(self, query, params=None):
        '''
        Committing execute (delegates to DBConn.execute).
        :param query: the query to execute
        :param params: parameters to pass in as literals to avoid sql injection
        '''
        DBConn.execute(self, query, params)

    def fetch_and_commit(self, query, params=None):
        '''
        Committing fetch (delegates to DBConn.fetch).
        :param query: query to execute
        :param params: parameters to pass in as literals to avoid sql injection
        :return: fetched rows
        '''
        # Bug fix: the result of DBConn.fetch was previously discarded,
        # so callers always received None instead of the fetched rows.
        return DBConn.fetch(self, query, params)

    def disconnect(self):
        '''
        Closes the database connection
        '''
        self.conn.close()


class ConnPool(object):
    '''
    Thin wrapper around a psycopg2 ThreadedConnectionPool that hands out
    connections wrapped in DBConn objects and takes them back.
    '''

    def __init__(self, db_name=None, conn_pool=None, max_conn=configs.pool_max_connections):
        self.db_name = db_name if db_name is not None else constants.prod_db_europe
        if conn_pool is not None:
            self.conn_pool = conn_pool
        else:
            self.conn_pool = self.create_pool(self.db_name, configs.pool_min_connections, max_conn)

    @staticmethod
    def create_pool(db_name, min_connections, max_connections):
        '''
        Creates a connection pool.
        :param db_name: name of the database
        :param min_connections: minimum number of connections to start off the pool with
        :param max_connections: maximum number of connections allowed in the pool
        :return: a psycopg2.pool.ThreadedConnection object
        '''
        try:
            if db_name == 'local':
                # Local test server credentials (port 5433).
                db_name, user, host = 'postgres', 'postgres', 'localhost'
                password, port = '', '5433'
            else:
                # Resolve the credentials file location, then pull this
                # database's record out of it.
                location = allowed_dbs[db_name]
                record = s3.read_json(location[str_bucket], location[str_key])[db_name]
                db_name = record[str_db_name]
                user = record[str_user]
                host = record[str_host]
                password = record[str_password]
                port = record[str_port]

            return pool.ThreadedConnectionPool(
                min_connections, max_connections,
                user=user, password=password, host=host, port=port, database=db_name
            )
        except KeyError as e:
            raise KeyError('Unknown database - ' + db_name + '\n' + str(e))
        except (OSError, IOError) as e:
            raise OSError('Could not read db credentials file' + '\n' + str(e))
        except psycopg2.DatabaseError:
            # Database-level failures propagate unchanged.
            raise

    def get_db_conn(self):
        '''
        Get a DBConn object from the connection pool.
        '''
        return DBConn(conn=self.conn_pool.getconn())

    def put_db_conn(self, db_conn: DBConn):
        '''
        Puts the open connection of the DBConn object back in the pool and deletes the DBConn object
        :param db_conn: DBConn object
        '''
        self.conn_pool.putconn(db_conn.conn)
        del db_conn


class RedisClient(object):
    '''
    Factory for Redis clients backed by a blocking connection pool.
    '''

    @staticmethod
    def create_client(cache_name=constants.redis_cache_europe, max_conn=configs.cache_max_connections):
        '''
        Creates a cache connection.
        :param cache_name: name of the cache
        :param max_conn: the maximum number of connections to create in the connection pool
        :return: a Redis client
        :raises KeyError: if the cache name or a credential field is unknown
        :raises OSError: if the credentials file could not be read
        '''
        try:
            if cache_name == 'local':
                # Local Redis on the default port.
                host = 'localhost'
                port = '6379'
            else:
                cache_location = allowed_dbs[cache_name]
                data = s3.read_json(cache_location[str_bucket], cache_location[str_key])

                cache_details = data[cache_name]
                host = cache_details[str_host]
                port = cache_details[str_port]

            # decode_responses=True so reads come back as str instead of bytes.
            conn_pool = redis.BlockingConnectionPool(max_connections=max_conn, host=host, port=port,
                                                     decode_responses=True)
            client = redis.Redis(connection_pool=conn_pool)
            return client
        except KeyError as e:
            err = 'Unknown cache - ' + cache_name + '\n' + str(e)
            raise KeyError(err)
        except (OSError, IOError) as e:
            err = 'Could not read db credentials file' + '\n' + str(e)
            raise OSError(err)
        # Removed a leftover `except psycopg2.DatabaseError: raise` handler
        # copy-pasted from the Postgres helpers: redis never raises psycopg2
        # exceptions, and the handler only re-raised, so dropping it does not
        # change behavior.


def get_environment_db_name():
    '''
    Get the name of the db name for the given environment.
    :return: (str) name of the database, or None for an unrecognized region
    '''
    on_test_server = settings.TEST_SERVER
    if settings.REGION == constants.aws_europe_paris:
        return constants.test_db_europe if on_test_server else constants.prod_db_europe
    if settings.REGION == constants.aws_us_ohio:
        return constants.test_db_us if on_test_server else constants.prod_db_us


# Create a global ConnPool object.
# This will be used across the project to handle DB connections.
#
# Three start-up modes:
#   1. normal server (globals on, TEST_MODE off) -> region-specific pools
#   2. TEST_MODE on                              -> localhost services
#   3. globals off                               -> 'mock' placeholders

if settings.INITIALIZE_GLOBAL_VARIABLES and not settings.TEST_MODE:
    # Mail servers use their own (separately configured) pool sizes.
    pool_mx = configs.mail_pool_max_connections if settings.MAIL_SERVER else configs.pool_max_connections
    cache_mx = configs.mail_cache_max_connections if settings.MAIL_SERVER else configs.cache_max_connections

    if settings.REGION == constants.aws_europe_paris:
        if settings.TEST_SERVER:
            CONN_POOL = ConnPool(db_name=get_environment_db_name(), max_conn=pool_mx)
            # Central pool is capped at a single connection on test servers.
            CONN_CENTRAL_POOL = ConnPool(db_name=get_environment_db_name(), max_conn=1)
            CACHE_CLIENT = RedisClient.create_client(max_conn=cache_mx)
        else:
            # db_name omitted: ConnPool defaults to the European prod database.
            CONN_POOL = ConnPool(max_conn=pool_mx)
            CONN_CENTRAL_POOL = ConnPool(max_conn=pool_mx)
            CACHE_CLIENT = RedisClient.create_client(max_conn=cache_mx)

    elif settings.REGION == constants.aws_us_ohio:
        if settings.TEST_SERVER:
            CONN_POOL = ConnPool(db_name=get_environment_db_name(), max_conn=pool_mx)
            CONN_CENTRAL_POOL = ConnPool(db_name=get_environment_db_name(), max_conn=1)
            # US region uses the US Redis cache.
            CACHE_CLIENT = RedisClient.create_client(constants.redis_cache_us, max_conn=cache_mx)
        else:
            CONN_POOL = ConnPool(db_name=get_environment_db_name(), max_conn=pool_mx)
            # NOTE(review): db_name omitted here, so this central pool defaults
            # to the European prod database even in the US region -- confirm
            # this is intentional (i.e. "central" data lives in Europe).
            CONN_CENTRAL_POOL = ConnPool(max_conn=pool_mx)
            CACHE_CLIENT = RedisClient.create_client(constants.redis_cache_us, max_conn=cache_mx)

elif settings.INITIALIZE_GLOBAL_VARIABLES and settings.TEST_MODE:
    # Local test mode: everything points at localhost services.
    CONN_POOL = ConnPool(db_name='local')
    CONN_CENTRAL_POOL = ConnPool(db_name='local')
    CACHE_CLIENT = RedisClient.create_client('local')
else:
    # Globals disabled: 'mock' is passed as both db_name and conn_pool so no
    # real connection is ever opened during import.
    CONN_POOL = ConnPool('mock', 'mock')
    CONN_CENTRAL_POOL = ConnPool('mock', 'mock')
    CACHE_CLIENT = None
