# By: Riasat Ullah
# This file contains all Amazon services integration related views.

from data_syncers import syncer_services, syncer_task_instances
from dbqueries import db_integrations, db_services
from exceptions.user_exceptions import InvalidRequest, UnauthorizedRequest
from integrations import amazon
from modules.router import Router
from objects.events import ResolveEvent
from objects.task_payload import TaskPayload
from rest_framework.decorators import api_view, parser_classes
from rest_framework.response import Response
from translators import label_translator as _lt
from utils import constants, errors, helpers, info, key_manager, logging, times, var_names
from utils.custom_parsers import PlainTextParser
from utils.db_connection import CACHE_CLIENT, CONN_POOL
from validations import request_validator
import configuration
import json
import jwt


def _get_stored_alarm_arn(trig_info):
    '''
    Gets the alarm ARN from the trigger info saved on an existing open instance.
    The Message saved by this view is either the raw JSON string from an SNS envelope or an already
    parsed dictionary (when the alarm was posted without the envelope), so both forms are handled.
    NOTE(review): assumes the trigger info passed to TaskPayload is what is stored under
    var_names.source_payload - confirm against the db layer.
    :param trig_info: trigger info of the open instance (may be None)
    :return: (str) alarm ARN; None if it cannot be determined
    '''
    if not isinstance(trig_info, dict):
        return None

    source_payload = trig_info.get(var_names.source_payload)
    if not isinstance(source_payload, dict):
        return None

    stored_message = source_payload.get(amazon.var_message)
    if isinstance(stored_message, str):
        try:
            stored_message = json.loads(stored_message)
        except ValueError:
            # Malformed stored payload; it cannot be matched against, but must not fail the request.
            return None

    if not isinstance(stored_message, dict):
        return None
    return stored_message.get(amazon.var_alarm_arn)


@api_view(['POST'])
@parser_classes((PlainTextParser,))
def process_sns_message(request, integration_key, conn=None, cache=None):
    '''
    Processes an SNS message received from Amazon.
    - Notification: resolves matching open instances when the alarm state is OK, groups the alert
      under matching open instances otherwise, and creates a new alert when nothing matches.
    - SubscriptionConfirmation: confirms the subscription and saves the subscription ARN.
    - UnsubscribeConfirmation: acknowledged without further action.
    :param request: Http request
    :param integration_key: integration key passed in the url
    :param conn: db connection
    :param cache: cache client
    :return: Http response
    '''
    if request.method == 'POST':
        lang = request_validator.get_user_language(request)

        try:
            conn = CONN_POOL.get_db_conn() if conn is None else conn
            cache = CACHE_CLIENT if cache is None else cache

            unmasked_integ_key = key_manager.unmask_reference_key(integration_key)
            current_time = times.get_current_timestamp()
            integ_id, serv_id, org_id = syncer_services.get_integration_key_details(
                conn, cache, current_time, unmasked_integ_key)

            # Payloads posted without an SNS envelope carry no message type; treat them as notifications.
            if amazon.var_message_type in request.data:
                msg_type = request.data[amazon.var_message_type]
            else:
                msg_type = amazon.msg_type_notification

            if msg_type == amazon.msg_type_notification:
                if amazon.var_message in request.data:
                    # The value of Message is sent as a string. Read it into a dictionary.
                    message = json.loads(request.data[amazon.var_message])
                    new_trig_info = request.data
                else:
                    message = request.data
                    new_trig_info = {amazon.var_message: message}

                title = message[amazon.var_alarm_name]
                notification_state = message[amazon.var_new_state_value]
                alarm_arn = message[amazon.var_alarm_arn]
                state_reason = message[amazon.var_new_state_reason] + '\n\n' + str(message)

                integ_insts = db_integrations.get_integration_open_instances_trigger_info(
                    conn, current_time, org_id, serv_id, integ_id)

                match_count = 0
                for inst_id, task_id, trig_info in integ_insts:
                    # Instances whose stored payload has no recognisable alarm ARN are skipped
                    # rather than failing the whole request.
                    stored_alarm_arn = _get_stored_alarm_arn(trig_info)
                    if stored_alarm_arn is None or stored_alarm_arn != alarm_arn:
                        continue

                    match_count += 1
                    if notification_state == amazon.msg_state_ok:
                        event = ResolveEvent(inst_id, current_time, constants.integrations_api)
                        syncer_task_instances.resolve(conn, cache, event, org_id, is_sys_action=True)
                    else:
                        # Group the alert if one already exists. Store the same normalised trigger info
                        # as new alerts so later lookups always find the Message key.
                        payload = TaskPayload(
                            current_time, org_id, current_time.date(), title, configuration.standard_timezone,
                            current_time.time(), text_msg=state_reason, urgency_level=constants.high_urgency,
                            trigger_method=constants.integrations_api, trigger_info=new_trig_info,
                            integration_id=integ_id, integration_key=integration_key, service_id=serv_id,
                            instantiate=False, alert=False, related_task_id=task_id,
                            task_status=constants.grouped_state
                        )
                        Router(conn, cache, payload).start()

                # No open instance for this alarm; raise a new alert.
                if match_count == 0:
                    payload = TaskPayload(
                        current_time, org_id, current_time.date(), title, configuration.standard_timezone,
                        current_time.time(), text_msg=state_reason, urgency_level=constants.high_urgency,
                        trigger_method=constants.integrations_api, trigger_info=new_trig_info,
                        integration_id=integ_id, integration_key=integration_key, service_id=serv_id
                    )
                    Router(conn, cache, payload).start()

                return Response(info.msg_internal_success)

            elif msg_type == amazon.msg_type_subscription_confirmation:
                # NOTE(review): SubscribeURL comes from the untrusted request body and is requested without
                # verifying the SNS message signature or checking that the host belongs to Amazon SNS.
                # Consider validating both to avoid server-side request forgery.
                subscription_url = request.data[amazon.var_subscription_url]
                status, output = helpers.post_api_request(subscription_url, dict())
                if status == 200:
                    sub_arn = output[amazon.var_confirm_subscription_response][
                        amazon.var_confirm_subscription_result][amazon.var_subscription_arn]
                    db_services.edit_service_integration_vendor_details(conn, current_time, org_id, unmasked_integ_key,
                                                                        vendor_endpoint=sub_arn)
                    return Response('Subscription has been confirmed')
                else:
                    logging.error(output)
                    return Response('Subscription could not be confirmed', status=400)

            elif msg_type == amazon.msg_type_unsubscribe_confirmation_msg:
                return Response('Integration has been un-subscribed')

            else:
                return Response('Unknown message type provided', status=400)
        except InvalidRequest as e:
            logging.exception(str(e))
            return Response(_lt.get_label(str(e), lang), status=400)
        except (UnauthorizedRequest, jwt.ExpiredSignatureError, jwt.InvalidSignatureError) as e:
            logging.exception(str(e))
            return Response(_lt.get_label(errors.err_authorization, lang), status=401)
        except Exception as e:
            logging.exception(str(e))
            return Response(_lt.get_label(errors.err_system_error, lang), status=500)
        finally:
            CONN_POOL.put_db_conn(conn)
