import string
from abc import ABC
from collections import Counter, defaultdict
from itertools import chain, combinations, product

import graphviz
import numpy as np
import pandas as pd
from matplotlib import cm

try:
    # ``matplotlib.cm.get_cmap`` was deprecated in matplotlib 3.7 and removed
    # in 3.9; the ``matplotlib.colormaps`` registry (3.5+) is the supported
    # way to look a colormap up by name.
    from matplotlib import colormaps as _COLORMAPS

    CMAP = _COLORMAPS["coolwarm"]
except ImportError:  # matplotlib < 3.5 has no ``colormaps`` registry
    CMAP = cm.get_cmap("coolwarm")


class Election(ABC):
    """Common interface shared by every voting method of this module.

    Each hook below is a do-nothing default returning ``None``; concrete
    voting methods override them.
    """

    def get_candidates(self):
        """Return an iterable over the candidates of this election."""
        return None

    def vote(self, ballot):
        """Record one ballot.

        Args:
            ballot: a single vote; its shape depends on the voting method.
        """
        return None

    def plot(self):
        """Return a plot summarising the results."""
        return None

    def get_winner(self) -> str:
        """Return the name of the winner."""
        return None


class MajorityJudgement(Election):
    """Majority judgement ("jugement majoritaire") election.

    Every ballot grades every candidate with one of ``notes``, which are
    ordered from the best grade to the worst. ``self.results`` is a
    DataFrame with one row per candidate and one column per grade; each cell
    counts the ballots giving that grade to that candidate.
    """

    notes = ("très bien", "bien", "assez bien", "passable", "insuffisant", "à rejeter")

    def __init__(self, candidates=None):
        """Create an election for ``candidates`` (duplicates are ignored)."""
        if candidates is None:
            candidates = []
        # pandas rejects a ``set`` as index ("index cannot be a set") and set
        # order is arbitrary anyway: deduplicate while keeping caller order.
        unique_candidates = list(dict.fromkeys(candidates))
        # Build integer zeros directly instead of NaN + fillna, which left
        # object-dtype columns.
        self.results = pd.DataFrame(0, index=unique_candidates, columns=list(self.notes))

    def add_candidate(self, name):
        """Add a candidate to the possibilities.

        This quick and dirty implementation requires all candidates to be
        known before the vote: ``vote`` demands a grade for every candidate.
        """
        if name in self.get_candidates():
            return
        # reindex with fill_value keeps the integer dtype of the tallies.
        self.results = self.results.reindex([*self.results.index, name], fill_value=0)

    def get_candidates(self):
        """Return the candidate names as a list, in insertion order."""
        return list(self.results.index)

    def vote(self, ballot):
        """Add a ballot.

        Args:
            ballot (dict): {candidate: note}, with exactly one note per candidate.

        Raises:
            ValueError: on an unknown note, an unknown candidate, or a ballot
                that does not grade every candidate.
        """
        candidates = set(self.get_candidates())  # hoisted out of the checks
        if not all(note in self.notes for note in ballot.values()):
            raise ValueError("unknown note")
        if not all(candidate in candidates for candidate in ballot):
            raise ValueError("unknown candidate")
        if set(ballot) != candidates:
            raise ValueError("there must be one and only one vote per candidate")

        for candidate, note in ballot.items():
            self.results.loc[candidate, note] += 1

    def _voter_count(self):
        """Return the number of ballots cast (0 when there is no candidate)."""
        if self.results.empty:
            return 0
        # vote() requires one note per candidate, so any row sums to the
        # number of ballots.
        return int(self.results.iloc[0].sum())

    def plot(self, *args, **kwargs):
        """Return a stacked bar chart of the grades of every candidate."""
        total = self._voter_count()
        ax = self.results.plot(kind="bar", stacked=True, cmap=CMAP, **kwargs)
        ax.legend(loc="upper left", bbox_to_anchor=(1, 1))
        # Half of the ballots: the segment crossing this line is the
        # candidate's majority mention.
        ax.axhline(total / 2, color="black", ls="--")
        return ax

    def get_mentions(self, candidat):
        """Get the cumulative result for one candidate.

        Returns:
            (dict) {mention: number of ballots giving at least this mention}
        """
        return self.results.loc[candidat].cumsum().to_dict()

    def get_winner(self):
        """Return the list of winners (several on a tie, empty if no candidate).

        Among candidates whose majority mention is the best one, the winners
        are those with the most ballots giving at least that mention.
        """
        best = self.get_best_candidates()
        if not best:
            return []
        best_score = max(best.values())
        return [candidat for candidat, value in best.items() if value == best_score]

    def get_best_candidates(self):
        """Return candidates whose majority mention is the best realized.

        Returns:
            (dict) {candidate: number of ballots giving at least the best mention}
        """
        best_note = self.get_best_note()
        return {
            candidat: mention[best_note]
            for candidat, mention in self.get_majority_mentions().items()
            if mention and best_note in mention
        }

    def get_majority_mention(self, candidat):
        """Get the majority mention for the candidate.

        Returns:
            (dict) {mention: number of ballots giving at least this mention},
            for the first mention reached by at least half of the ballots.
        """
        # NOTE(review): with an even number of ballots ``>=`` selects the
        # upper median; Balinski–Laraki's method uses the lower median
        # (strict majority). Kept as-is — confirm which is intended.
        threshold = self._voter_count() / 2
        mentions = self.get_mentions(candidat)
        for note in self.notes:
            if mentions.get(note, 0) >= threshold:
                return {note: mentions[note]}
        return None

    def get_majority_mentions(self) -> dict:
        """Get the majority mention of every candidate."""
        return {
            candidat: self.get_majority_mention(candidat)
            for candidat in self.get_candidates()
        }

    def get_best_note(self):
        """Return the best majority mention obtained by any candidate (or None)."""
        realised_notes = {
            note
            for mention in self.get_majority_mentions().values()
            if mention
            for note in mention
        }
        return next((note for note in self.notes if note in realised_notes), None)


class Condorcet(Election):
    """Condorcet method: candidates are compared pairwise in "duels".

    ``self.duels[(winner, loser)]`` counts the ballots ranking ``winner``
    above ``loser``.
    """

    def __init__(self):
        self.duels = defaultdict(int)

    def vote(self, ballot):
        """Record a ranked ballot.

        Args:
            ballot (list) : ordered list of candidates. The first one is preferred.
        """
        # combinations() keeps input order, so the first element of every
        # pair is ranked higher on this ballot.
        for preferred, other in combinations(ballot, 2):
            self.duels[(preferred, other)] += 1

    def _ordered_candidates(self):
        """Return the candidates as a deterministically sorted list."""
        # pandas rejects sets as index/columns, and set order is arbitrary.
        return sorted(self.get_candidates(), key=str)

    def get_raw_matrix(self):
        """Get the duel matrix filled with numbers of ballots.

        Cell ``[row, col]`` is the number of ballots ranking ``row`` above ``col``.
        """
        candidates = self._ordered_candidates()
        duel_matrix = pd.DataFrame(0, index=candidates, columns=candidates)
        for (row, col), score in self.duels.items():
            duel_matrix.loc[row, col] = score
        return duel_matrix

    def get_duel_matrix(self):
        """Get the duel outcome matrix.

        If candidate0 wins more duels over candidate1 than the other way
        around, cell [candidate0, candidate1] = 1; if it loses more, -1; on a
        tie, 0. The extra column ``total_win`` is the sum of each row.
        """
        raw = self.get_raw_matrix()
        # Reading through get_raw_matrix() (items()) instead of indexing the
        # defaultdict avoids inserting zero entries into self.duels.
        duel_matrix = np.sign(raw - raw.T).astype(int)
        duel_matrix["total_win"] = duel_matrix.sum(axis=1)
        return duel_matrix

    def get_prefered(self):
        """Return the candidates sorted by decreasing ``total_win``."""
        scores = self.get_duel_matrix()["total_win"].to_dict()
        return [name for name, _ in Counter(scores).most_common()]

    def get_winner(self):
        """Return the candidate with the highest ``total_win``."""
        return self.get_duel_matrix()["total_win"].idxmax()

    def get_candidates(self):
        """Return the set of candidates that appear in at least one duel."""
        return {name for names in self.duels for name in names}

    def get_dominance_graph(self):
        """Return a graphviz digraph with an edge from each duel winner to its loser."""
        dot = graphviz.Digraph()
        candidates = self._ordered_candidates()

        # build nodes: single-letter ids as before, with a numbered fallback
        # so candidates beyond the 52nd are no longer silently dropped.
        node_ids = {
            candidate: (
                string.ascii_letters[i] if i < len(string.ascii_letters) else f"n{i}"
            )
            for i, candidate in enumerate(candidates)
        }
        for candidate, node_id in node_ids.items():
            dot.node(node_id, candidate)

        # build edges
        duel_mat = self.get_duel_matrix()
        for cand0, cand1 in combinations(candidates, 2):
            if duel_mat.loc[cand0, cand1] == 1:
                dot.edge(node_ids[cand0], node_ids[cand1])
            if duel_mat.loc[cand1, cand0] == 1:
                dot.edge(node_ids[cand1], node_ids[cand0])

        return dot

    def plot(self):
        """Return the dominance graph."""
        return self.get_dominance_graph()


class Borda(Election):
    """Borda's method.

    The first non-empty ballot fixes the points per rank: with ``n`` names,
    ranks score ``n, n-1, ..., 1``. Shorter later ballots score the top
    ranks; longer ones are rejected rather than silently truncated.
    """

    def __init__(self):
        self.candidates = Counter()  # candidate -> accumulated points
        self.points = None  # points attributed to each rank, set by the first ballot

    def get_candidates(self):
        """Return the candidates that received at least one ranking."""
        return self.candidates.keys()

    def vote(self, ballot: list):
        """Add a ranked ballot (first name is the preferred one).

        Raises:
            ValueError: if the ballot ranks more candidates than the ballot
                that fixed the points scale.
        """
        if len(ballot) == 0:
            # An empty ballot adds nothing; it must not fix an empty scale
            # that would make every later ballot count for nothing.
            return
        if self.points is None:
            self.points = list(range(len(ballot), 0, -1))
        if len(ballot) > len(self.points):
            raise ValueError(
                "ballot ranks {} candidates but the points scale only has {}".format(
                    len(ballot), len(self.points)
                )
            )
        for name, score in zip(ballot, self.points):
            self.candidates[name] += score

    def get_results(self):
        """Return the Counter {candidate: total points}."""
        return self.candidates

    def plot(self, **kwargs):
        """Return a bar chart of the points, highest first."""
        results = pd.Series(self.get_results()).sort_values(ascending=False)
        return results.plot(kind="bar", **kwargs)

    def get_winner(self) -> str:
        """Return the candidate with the most points (IndexError if no vote)."""
        return self.candidates.most_common(1)[0][0]


class Approval(Election):
    """Approval voting: each voter lists the candidates they approve of."""

    def __init__(self):
        self.candidates = Counter()  # candidate -> number of approvals
        self.ballots = []  # one list of approved candidates per voter
        self.voters = 0  # number of ballots cast

    def get_candidates(self):
        """Return the candidates approved on at least one ballot."""
        return self.candidates.keys()

    def vote(self, ballot: list):
        """Add ballot to vote.

        Args:
            ballot (list): list of approved candidates. A name repeated on
                the same ballot counts once.
        """
        # dict.fromkeys deduplicates while keeping the ballot's order.
        approved = list(dict.fromkeys(ballot))
        self.ballots.append(approved)
        # Keep the tallies current so get_candidates() works before
        # get_results() has been called.
        self.candidates.update(approved)
        self.voters = len(self.ballots)

    def get_results(self):
        """Recount all ballots.

        Returns:
            list of (candidate, approvals) pairs, most approved first.
        """
        self.candidates = Counter(chain.from_iterable(self.ballots))
        self.voters = len(self.ballots)
        return self.candidates.most_common()

    def plot(self, **kwargs):
        """Return a bar chart of the approval rate (in %) of each candidate."""
        results = pd.Series(dict(self.get_results()))
        names = [
            "{} ({})".format(name, vote) for name, vote in results.to_dict().items()
        ]
        results.index = names
        results = results / self.voters * 100
        # setdefault lets callers override the label without a duplicate
        # keyword TypeError.
        kwargs.setdefault("ylabel", "%")
        return results.plot(kind="bar", **kwargs)

    def get_winner(self):
        """Return the most approved candidate (IndexError if no approval)."""
        return self.get_results()[0][0]


def jm_to_order(jm_vote):
    """Convert a majority-judgement ballot into a preference order.

    Args:
        jm_vote (dict): {candidate: note}, notes taken from MajorityJudgement.notes.

    Returns:
        list: candidates from the best grade to the worst; candidates with the
        same grade keep their order from ``jm_vote``.
    """
    rank_of = {note: rank for rank, note in enumerate(MajorityJudgement.notes)}
    ordered = sorted(jm_vote.items(), key=lambda item: rank_of[item[1]])
    return [candidate for candidate, _ in ordered]


def jm_to_approval(jm_vote):
    """Convert a majority-judgement ballot into an approval list.

    A candidate is approved when its note is among the first four of
    MajorityJudgement.notes ("très bien" down to "passable").

    Args:
        jm_vote (dict): {candidate: note}.

    Returns:
        list: approved candidates, in ``jm_vote`` order.
    """
    rank_of = {note: rank for rank, note in enumerate(MajorityJudgement.notes)}
    approved = []
    for candidate, note in jm_vote.items():
        if rank_of[note] < 4:
            approved.append(candidate)
    return approved
