# By: Riasat Ullah
# This class represents a routine.

from objects.routine_layer import RoutineLayer
from utils import helpers, times, var_names
import datetime


class Routine(object):
    '''
    An on-call routine: a named collection of RoutineLayer objects tied to a timezone.
    Some layers may be exception layers, which take precedence over the regular layers.
    '''

    def __init__(self, routine_id, organization_id, routine_name, routine_timezone, routine_layers, reference_id=None,
                 associated_policies=None):
        self.routine_id = routine_id
        self.organization_id = organization_id
        self.routine_name = routine_name
        self.routine_timezone = routine_timezone
        self.routine_layers = routine_layers
        self.reference_id = reference_id
        self.associated_policies = associated_policies

    @staticmethod
    def create_routine(details, for_display=False):
        '''
        Creates a new Routine object.
        :param details: (dict) dict of routine info; must contain the routine id, timezone, organization id,
                        routine name and routine layers; may contain the routine reference id and associated policies
        :param for_display: (boolean) True if the Policy(s)/Routine(s) should be mapped on to their reference IDs;
                            the routine reference id is then used as the routine id
        :return: Routine object
        '''
        routine_id = details[var_names.routine_id]
        routine_timezone = details[var_names.timezone]
        if for_display:
            routine_id = details[var_names.routine_ref_id]

        layers = [RoutineLayer.create_layer(item, for_display) for item in details[var_names.routine_layers]]

        return Routine(routine_id, details[var_names.organization_id], details[var_names.routine_name],
                       routine_timezone, layers,
                       details.get(var_names.routine_ref_id),
                       details.get(var_names.associated_policies))

    def get_on_call(self, check_datetime=None):
        '''
        Gets the list of users from the routine who are on-call on given date and time.
        Exception layers take precedence over the regular layers. If any exception layer has someone on call at
        the given time, only those assignees are returned; otherwise the regular layers are consulted.
        :param check_datetime: (datetime.datetime) date on which to check for; the function assumes it is in UTC.
                               Defaults to the current timestamp (computed at call time) when None.
        :return: (list of tuples) -> [(user_id, display name, assignee policy id), ...] of users who are on call
        '''
        # The routine timezone is handed to every layer because the rotation start and end times
        # are not converted to UTC before storing them in the database.
        if check_datetime is None:
            check_datetime = times.get_current_timestamp()

        regular_layers, exception_layers = self.separate_routines_and_exceptions()

        on_call = []
        for exception_layer in exception_layers:
            on_call += exception_layer.get_on_call(check_datetime, self.routine_timezone)
        if on_call:
            return on_call

        # The exception layers were already consulted above; only the regular layers remain.
        for layer in regular_layers:
            on_call += layer.get_on_call(check_datetime, self.routine_timezone)
        return on_call

    def is_on_call(self, assignee_pol_id, check_datetime=None):
        '''
        Checks if an assignee is on-call at a certain point in time.
        :param assignee_pol_id: policy id of the assignee to check for
        :param check_datetime: (timestamp) the timestamp to check on; when None the current timestamp is used
        :return: (boolean) -> True if assignee is on-call; False otherwise
        '''
        # The default is None rather than a call to times.get_current_timestamp(). A call in the signature would
        # be evaluated only once, when the class is defined; None lets get_on_call compute "now" on every call.
        return assignee_pol_id in [entry[2] for entry in self.get_on_call(check_datetime)]

    def _on_call_before_and_after(self, assignee_pol_id, check_datetime):
        '''
        Checks whether the assignee is on call one minute before and one minute after the given time.
        :param assignee_pol_id: policy ID of the user
        :param check_datetime: (datetime.datetime) the reference time
        :return: (tuple of booleans) (on call one minute before, on call one minute after)
        '''
        one_minute = datetime.timedelta(minutes=1)
        return (self.is_on_call(assignee_pol_id, check_datetime - one_minute),
                self.is_on_call(assignee_pol_id, check_datetime + one_minute))

    def has_hand_off(self, assignee_pol_id, check_datetime):
        '''
        Checks if the on-call role of the assignee changes (starts or ends) around the given time.
        :param assignee_pol_id: policy ID of the user
        :param check_datetime: (datetime.datetime) to check on
        :return: (boolean) True if there is a hand-off; False otherwise
        '''
        was_on_call, will_be_on_call = self._on_call_before_and_after(assignee_pol_id, check_datetime)
        return was_on_call != will_be_on_call

    def is_assignee_going_on_call(self, assignee_pol_id, check_datetime):
        '''
        Checks if a user is about to go on-call around the given time.
        :param assignee_pol_id: policy ID of the user
        :param check_datetime: (datetime.datetime) to check on
        :return: (boolean) True if the user is off call just before and on call just after; False otherwise
        '''
        was_on_call, will_be_on_call = self._on_call_before_and_after(assignee_pol_id, check_datetime)
        return not was_on_call and will_be_on_call

    def is_assignee_coming_off_on_call(self, assignee_pol_id, check_datetime):
        '''
        Checks if a user is about to come off their on-call role around the given time.
        :param assignee_pol_id: policy ID of the user
        :param check_datetime: (datetime.datetime) to check on
        :return: (boolean) True if the user is on call just before and off call just after; False otherwise
        '''
        was_on_call, will_be_on_call = self._on_call_before_and_after(assignee_pol_id, check_datetime)
        return was_on_call and not will_be_on_call

    def get_assignee_on_call_period(self, assignee_pol_id, check_datetime, look_forward=30):
        '''
        Get the on-call schedule entry of a user that covers the check time, and the next one after it.
        :param assignee_pol_id: policy ID of the user
        :param check_datetime: (datetime.datetime) to check on; assumed to be in UTC (it is converted to the
                               routine's regional time before being compared with the schedule)
        :param look_forward: (int) number of days past the check time to look for the upcoming on-call roles
        :return: (tuple) (current period start timestamp, current period end timestamp),
                        (next period start timestamp, next period end timestamp);
                 either element is None if no such period is found
        '''
        # The schedule starts some days before the check time so that a period which began earlier
        # (and is still running) is included.
        look_back_days = 8
        reg_check_dt = times.utc_to_region_time(check_datetime, self.routine_timezone)
        period_start = reg_check_dt - datetime.timedelta(days=look_back_days)
        # The span must cover the look-back days plus the full look-forward window past the check time.
        rou_schedules = self.prepare_schedule(period_start, look_back_days + look_forward)

        curr_period, next_period = None, None
        for sch in rou_schedules:
            if not any(rot.assignee_policy_id == assignee_pol_id for rot in sch[var_names.on_call]):
                continue

            sch_start = sch[var_names.rotation_start]
            sch_end = sch[var_names.rotation_end]
            if curr_period is None and sch_start <= reg_check_dt < sch_end:
                curr_period = (sch_start, sch_end)
            if next_period is None and sch_start > reg_check_dt:
                next_period = (sch_start, sch_end)

        return curr_period, next_period

    def get_all_assignees(self, skip_exceptions=True):
        '''
        Get the user_ids of all the users who are assignees of the routine (from all layers).
        :param skip_exceptions: (boolean) True if exceptions should be ignored; False otherwise
        :return: (list of tuples) -> [(user_id, display name), ...] of users/assignees
        '''
        if skip_exceptions:
            layers, _ = self.separate_routines_and_exceptions()
        else:
            layers = self.routine_layers

        assignees = []
        for layer in layers:
            assignees += layer.get_all_assignees()
        return assignees

    def remove_rotations(self, user_names: list):
        '''
        Remove the rotations associated with the given users from all the layers of the routine.
        Layers left with no rotations are dropped. Each remaining layer (and its rotations' level) is renumbered
        down by the number of lower-numbered layers that were dropped.
        :param user_names: (list) users to remove; passed through unchanged to each layer's remove_rotations
        '''
        kept_layers = []
        dropped = 0
        for layer in sorted(self.routine_layers, key=lambda x: x.layer):
            layer.remove_rotations(user_names)
            if len(layer.rotations) == 0:
                dropped += 1
                continue

            if dropped > 0:
                layer.layer -= dropped
                for rotation in layer.rotations:
                    rotation.level = layer.layer
            kept_layers.append(layer)
        self.routine_layers = kept_layers

    def separate_routines_and_exceptions(self):
        '''
        Get non-exception layers and exception layers in separate lists, preserving their relative order.
        :return: (tuple) -> (non_exceptions, exceptions)
        '''
        exceptions = []
        non_exceptions = []
        for layer in self.routine_layers:
            if layer.is_exception:
                exceptions.append(layer)
            else:
                non_exceptions.append(layer)
        return non_exceptions, exceptions

    def _collect_layer_schedules(self, layers, start_date, period):
        '''
        Concatenates the schedules prepared by each of the given layers.
        Each layer's valid times are localized to the routine timezone only while its schedule is prepared.
        They are always standardized back afterwards, even if preparing the schedule raises.
        :param layers: (list) of RoutineLayer
        :param start_date: (datetime.datetime) start of the period
        :param period: (int) number of days in the period
        :return: (list of dict) concatenated schedule entries
        '''
        schedule = []
        for layer in layers:
            layer.localize_valid_times(self.routine_timezone)
            try:
                schedule += layer.prepare_schedule(start_date, period)
            finally:
                layer.standardize_valid_times_to_utc(self.routine_timezone)
        return schedule

    def prepare_schedule(self, start_date, period):
        '''
        Builds the combined schedule of the routine over a period of time.
        Exception schedules take precedence. When any exist, regular schedule entries are kept only for the
        parts that fall in the gaps left by the exceptions. Otherwise the regular entries are kept as-is.
        :param start_date: (datetime.datetime) start of the period; presumably in the routine's regional time
        :param period: (int) number of days the period spans
        :return: (list of dict) schedule entries (rotation start, rotation end, on-call) sorted by rotation start
        '''
        end_date = start_date + datetime.timedelta(days=period)
        reg_layers, exc_layers = self.separate_routines_and_exceptions()
        regular_schedule = self._collect_layer_schedules(reg_layers, start_date, period)
        exceptions = self._collect_layer_schedules(exc_layers, start_date, period)

        if exceptions:
            # Regular entries may only fill the time the exceptions leave uncovered. If the exceptions cover
            # the whole period there are no gaps, and no regular entry survives.
            gaps = self.find_gaps_between_schedules(exceptions, start_date, end_date)
            final_schedule = []
            for curr_sch in regular_schedule:
                for gap in gaps:
                    adj_sch = self.get_gap_adjusted_schedule(curr_sch, gap)
                    if adj_sch is not None:
                        final_schedule.append(adj_sch)
        else:
            final_schedule = list(regular_schedule)

        return helpers.sorted_list_of_dict(final_schedule + exceptions, var_names.rotation_start)

    @staticmethod
    def get_gap_adjusted_schedule(sch, gap):
        '''
        Clips a schedule entry to a gap.
        :param sch: (dict) schedule entry with rotation start, rotation end and on-call keys
        :param gap: (tuple) (gap start, gap end)
        :return: (dict) new entry covering the overlap of the entry and the gap (other keys of the entry are not
                 copied); None if they do not overlap
        '''
        item_start = sch[var_names.rotation_start]
        item_end = sch[var_names.rotation_end]
        gap_start = gap[0]
        gap_end = gap[1]

        # Intervals that merely touch (one ends exactly where the other starts) do not overlap.
        if item_end <= gap_start or item_start >= gap_end:
            return None

        # The overlap of two intersecting intervals runs from the later start to the earlier end.
        # This covers every case: item inside gap, gap inside item, and partial overlap on either side.
        return {
            var_names.rotation_start: max(item_start, gap_start),
            var_names.rotation_end: min(item_end, gap_end),
            var_names.on_call: sch[var_names.on_call]
        }

    @staticmethod
    def find_gaps_between_schedules(schedules, period_start, period_end):
        '''
        Finds the uncovered spans before, between and after the given schedule entries.
        A leading gap starts at period_start and a trailing gap ends at period_end.
        :param schedules: (list of dict) schedule entries with rotation start and rotation end keys
        :param period_start: start of the period
        :param period_end: end of the period
        :return: (list of tuples) [(gap start, gap end), ...] in the order found; empty if there are no schedules
        '''
        if not schedules:
            return []

        ordered = helpers.sorted_list_of_dict(schedules, var_names.rotation_start)
        gaps = []

        first_start = ordered[0][var_names.rotation_start]
        covered_until = ordered[0][var_names.rotation_end]
        if first_start > period_start:
            gaps.append((period_start, first_start))

        for sch in ordered[1:]:
            sch_start = sch[var_names.rotation_start]
            sch_end = sch[var_names.rotation_end]

            # Entry lies entirely within the span already covered; nothing changes.
            if sch_end <= covered_until:
                continue
            # Entry starts after the covered span ends, so the time in between is a gap.
            if sch_start > covered_until:
                gaps.append((covered_until, sch_start))
            covered_until = sch_end

        if covered_until < period_end:
            gaps.append((covered_until, period_end))

        return gaps

    def to_dict(self, basic_info=False):
        '''
        Gets the dict of the Routine object.
        :param basic_info: True if only rotation name and assignee preferred username are required from the rotations;
                           the organization id is also left out in that case
        :return: dict of Routine object
        '''

        # DO NOT pass basic_info = True directly on to the Routine object from PolicyLevel because the basic info
        # returned by routines is different from what is required by policies.

        layers_dict_list = [item.to_dict(basic_info) for item in self.routine_layers]

        data = {var_names.routine_id: self.routine_id,
                var_names.routine_ref_id: self.reference_id,
                var_names.routine_name: self.routine_name,
                var_names.timezone: self.routine_timezone,
                var_names.routine_layers: layers_dict_list,
                var_names.associated_policies: self.associated_policies}

        if not basic_info:
            data[var_names.organization_id] = self.organization_id

        return data
