#!/usr/bin/python3
# JUNE 2020	chrisw

# references on accuracy metrics for imbalanced classes

# https://scikit-learn.org/stable/modules/generated/sklearn.metrics.confusion_matrix.html
# https://scikit-learn.org/stable/modules/generated/sklearn.metrics.multilabel_confusion_matrix.html#sklearn.metrics.multilabel_confusion_matrix
# https://scikit-learn.org/stable/modules/model_evaluation.html#confusion-matrix
# https://scikit-learn.org/stable/modules/model_evaluation.html#balanced-accuracy-score
# https://scikit-learn.org/stable/modules/generated/sklearn.metrics.balanced_accuracy_score.html
# https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html

# https://www.datascienceblog.net/post/machine-learning/performance-measures-multi-class-problems
# https://rushdishams.blogspot.com/2011/08/micro-and-macro-average-of-precision.html
# https://datascience.stackexchange.com/questions/15989/micro-average-vs-macro-average-performance-in-a-multiclass-classification-settin
# https://kaybrodersen.github.io/publications/Carrillo_2014_ROBOT2013.pdf


# imports
import logging
from optparse import OptionParser
import sys
import csv
import json
import os
import re

import datetime

import sklearn.metrics

# global vars
# NOTE(review): this module-level flag is not read anywhere in the visible
# code; log verbosity is driven by the -v option (options.verbose) in main().
verbose = True

# methods and functions


def getOptions():
    """
    Parse the command line (sys.argv) with optparse.

    Returns a 3-tuple (options, args, parser): the parsed option values
    (options.verbose is True when -v was given), the list of positional
    arguments, and the OptionParser instance itself.
    """
    description = "\n".join([
        "Read in classification model results in the TMP format and compute classification performance metrics using sklearn.metrics.",
        "Reads input from file passed in as argument.",
        "Outputs a tsv file with prediction performance metrics for both training set and test set.",
    ])

    parser = OptionParser(usage="%prog [options]", description=description)
    parser.add_option("-v",
                      dest="verbose",
                      action="store_true",
                      default=False,
                      help="Switch for verbose mode.")

    options, args = parser.parse_args()
    return (options, args, parser)


def getNow():
    """
    Return the current UTC time as a naive datetime.datetime (tzinfo is None).

    Use .isoformat() on the result for an ISO 8601 string, or pass it to
    getTimeDelta() later to measure elapsed time.
    """
    return datetime.datetime.utcnow()


def getTimeDelta(startDatetime):
    """
    Return the datetime.timedelta elapsed since startDatetime.

    startDatetime must be a naive UTC datetime (such as one from getNow());
    call .total_seconds() on the result for the elapsed seconds.
    """
    return datetime.datetime.utcnow() - startDatetime


def get_sample_predictions_dict_reader(input):
    """
    Split a tab-separated predictions file into comment lines and a DictReader.

    input is an open text file (any iterable of text lines works). Lines that
    start with "#" are collected, with trailing "\\r\\n" stripped, as comment
    lines. All other lines are fed to csv.DictReader with a tab delimiter; the
    first such line supplies the field names.

    Returns (comment_lines, reader). The reader iterates over an in-memory
    list of lines, so it remains usable after input has been closed.
    """
    comment_lines = []
    data_lines = []
    for line in input:
        line = line.rstrip("\r\n")
        if line.startswith("#"):
            comment_lines.append(line)
        else:
            data_lines.append(line)
    reader = csv.DictReader(data_lines, delimiter="\t")

    # reader.fieldnames is None when there are no data lines at all (empty or
    # comment-only input); log zero fields instead of crashing on len(None).
    fieldnames = reader.fieldnames or []
    logging.debug("number of comment lines: %d", len(comment_lines))
    logging.debug("number of data lines: %d", len(data_lines))
    logging.debug("\tnumber of fields:%d\n\tfield names: %s",
                  len(fieldnames), fieldnames)
    return (comment_lines, reader)


def correct_label(input_label):
    """
    Return input_label with only its first underscore replaced by a colon.

    For example "a_b_c" becomes "a:b_c"; labels without an underscore are
    returned unchanged.
    """
    # Literal equivalent of re.sub('_', ':', input_label, 1); passing re.sub's
    # count positionally is deprecated as of Python 3.13.
    return input_label.replace("_", ":", 1)


def _parse_predicted_class(cell):
    """Return the predicted class encoded in one model's prediction cell."""
    try:
        # When a model outputs scores for the possible classes, the cell is a
        # JSON object whose "classification" entry maps class -> score.
        prediction_dict = json.loads(cell)
        class_scores = prediction_dict["classification"]
        # Highest score wins; max() keeps the first class on ties.
        return max(class_scores, key=class_scores.get)
    except (ValueError, TypeError, KeyError, AttributeError):
        # Not a JSON object with a usable score map: non-JSON text
        # (ValueError), a JSON scalar/list or None cell (TypeError), no
        # "classification" key (KeyError), a non-dict score entry
        # (AttributeError) or an empty map (ValueError). In those cases the
        # model output a plain class name, so use the cell verbatim.
        return cell


def wrangle_dict_reader(data_dict_reader):
    """
    Group per-sample predictions into label vectors by CV split and model.

    The first five columns of the reader are expected to be, in order:
    'Sample_ID', 'Repeat', 'Fold', 'Test', 'Label' (they are accessed by
    position); every further column holds one model's predictions. A
    prediction cell is either a JSON object with a "classification"
    class -> score map (the top-scoring class is used) or a plain class name.

    Returns (label_vector_dict, label_counts_dict):
      label_vector_dict["R<repeat>:F<fold>"]["train" | "test"][model_name]
          == {"true_labels": [...], "predicted_labels": [...]}
          (both lists are appended together, so they stay the same length);
      label_counts_dict[true_label] is the *set* of sample IDs carrying that
          label.

    Rows whose 'Test' value is neither "1" nor "0" are logged as warnings and
    their labels are not recorded.
    """
    label_vector_dict = {}
    label_counts_dict = {}
    predicted_labels_set = set()

    # fieldnames is None for an empty reader; then there are no rows either.
    fieldnames = data_dict_reader.fieldnames or []
    for row in data_dict_reader:
        sampleID = row[fieldnames[0]]
        repeat = row[fieldnames[1]]
        fold = row[fieldnames[2]]
        test = row[fieldnames[3]]
        label = row[fieldnames[4]]

        label_counts_dict.setdefault(label, set()).add(sampleID)

        repeat_fold = "R%s:F%s" % (repeat, fold)

        for model_name in fieldnames[5:]:
            predicted_class = _parse_predicted_class(row[model_name])
            predicted_labels_set.add(predicted_class)

            split_dict = label_vector_dict.setdefault(
                repeat_fold, {"train": {}, "test": {}})
            for split in ("train", "test"):
                split_dict[split].setdefault(
                    model_name, {"true_labels": [], "predicted_labels": []})

            if test == "1":
                vectors = split_dict["test"][model_name]
            elif test == "0":
                vectors = split_dict["train"][model_name]
            else:
                logging.warning("invalid value for 'test' field: %s", test)
                continue

            vectors["true_labels"].append(label)
            vectors["predicted_labels"].append(predicted_class)

    logging.debug("labels encountered: %s", set(label_counts_dict))

    # Predictions of classes that never occur as a true label usually signal
    # a label-naming mismatch between the model output and the input labels.
    predicted_labels_only_set = predicted_labels_set.difference(label_counts_dict)
    if predicted_labels_only_set:
        logging.warning("predicted label was not found in the true label set: %s",
                        predicted_labels_only_set)

    return (label_vector_dict, label_counts_dict)


def compute_performance_metrics(label_vector_dict, label_counts_dict):
    """
    Compute classification performance metrics with scikit-learn.

    label_vector_dict[repeat_fold][train_test][model_name] must be a dict with
    "true_labels" and "predicted_labels" lists; train_test is "train" or
    "test" (the layout built by wrangle_dict_reader).
    label_counts_dict maps each true label to a sized collection of samples
    (a set of sample IDs from wrangle_dict_reader); every key is scored
    one-vs-rest and its size is embedded in that label's column names.

    Returns a list with one flat dict per (repeat_fold, split, model) holding
    identifier columns, "overall_*" cohort metrics, per-label columns
    ("<label>_<n>_samples_..."; the label's first "_" becomes ":") and
    "arith_mean_of_*" macro averages of the per-label values.
    """
    logging.info("computing performance")
    result_dict_list = []
    for repeat_fold, splits in label_vector_dict.items():
        for train_test in ["train", "test"]:
            for model_name, vectors in splits[train_test].items():
                true_labels = vectors["true_labels"]
                predicted_labels = vectors["predicted_labels"]

                result_dict = {
                    "repeat_fold": repeat_fold,
                    "test": train_test,
                    "model": model_name,
                }
                result_dict.update(
                    _compute_overall_metrics(true_labels, predicted_labels))
                result_dict.update(
                    _compute_subtype_metrics(true_labels, predicted_labels,
                                             label_counts_dict))
                result_dict_list.append(result_dict)

    return result_dict_list


def _compute_overall_metrics(true_labels, predicted_labels):
    """Return the cohort-level metric columns computed over all labels at once."""
    bacc = sklearn.metrics.balanced_accuracy_score(
        true_labels, predicted_labels)
    acc = sklearn.metrics.accuracy_score(true_labels, predicted_labels)
    # Support-weighted averages; support is None when average is set.
    (precision, recall, f1, _support) = sklearn.metrics.precision_recall_fscore_support(
        true_labels, predicted_labels, average='weighted', zero_division=0)
    return {
        "overall_bacc": bacc,
        "overall_acc": acc,
        "overall_weighted_f1": f1,
        "overall_precision": precision,
        "overall_recall": recall,
    }


def _compute_subtype_metrics(true_labels, predicted_labels, label_counts_dict):
    """Return per-label (one-vs-rest) metric columns plus their arithmetic means."""
    columns = {}
    bacc_list = []
    acc_list = []
    hit_miss_list = []
    f1_list = []
    precision_list = []
    recall_list = []

    for label_key, label_samples in label_counts_dict.items():
        # One-vs-rest encoding: 1 where the label equals label_key, else 0.
        binarized_true_labels = [binarize_label(label, label_key)
                                 for label in true_labels]
        binarized_predicted_labels = [binarize_label(label, label_key)
                                      for label in predicted_labels]

        label_name_size = "%s_%d_samples" % (
            correct_label(label_key), len(label_samples))

        # bacc -- this seems to produce results similar to caret's.
        bacc = sklearn.metrics.balanced_accuracy_score(
            binarized_true_labels, binarized_predicted_labels)
        columns["%s_subtype_bacc" % label_name_size] = bacc
        bacc_list.append(bacc)

        acc = sklearn.metrics.accuracy_score(
            binarized_true_labels, binarized_predicted_labels)
        columns["%s_subtype_acc" % label_name_size] = acc
        acc_list.append(acc)

        # hit/miss calculation ... this is actually recall for label_key.
        hit_miss_score = compute_hit_miss(
            true_labels, predicted_labels, label_key)
        columns["%s_hit_miss" % label_name_size] = hit_miss_score
        hit_miss_list.append(hit_miss_score)

        # Precision/recall/F1/support restricted to label_key; each is a
        # length-1 array because labels=[label_key] and average=None.
        (precision, recall, f1, support) = sklearn.metrics.precision_recall_fscore_support(
            true_labels, predicted_labels, labels=[label_key], average=None,
            zero_division=0)
        columns["%s_subtype_precision" % label_name_size] = precision[0]
        columns["%s_subtype_recall" % label_name_size] = recall[0]
        columns["%s_subtype_f1" % label_name_size] = f1[0]
        columns["%s_subtype_support" % label_name_size] = support[0]
        f1_list.append(f1[0])
        precision_list.append(precision[0])
        recall_list.append(recall[0])

    columns["arith_mean_of_subtype_bacc"] = compute_arith_mean_of_list(bacc_list)
    columns["arith_mean_of_subtype_acc"] = compute_arith_mean_of_list(acc_list)
    columns["arith_mean_of_hit_miss"] = compute_arith_mean_of_list(hit_miss_list)
    columns["arith_mean_of_subtype_f1"] = compute_arith_mean_of_list(f1_list)
    columns["arith_mean_of_subtype_precision"] = compute_arith_mean_of_list(
        precision_list)
    columns["arith_mean_of_subtype_recall"] = compute_arith_mean_of_list(
        recall_list)
    return columns


def compute_arith_mean_of_list(number_list):
    """
    Return the arithmetic mean of number_list as a Python float.

    number_list must support len(); an empty list raises ZeroDivisionError.
    """
    total = sum(number_list)
    count = len(number_list)
    mean = total / count
    return float(mean)


def binarize_label(label, pos_label):
    """One-vs-rest encode a single label: 1 if label == pos_label, else 0."""
    return 1 if label == pos_label else 0


def compute_hit_miss(true_labels, predicted_labels, label_to_score, zero_division=0.0):
    """
    Return hits / (hits + misses) for samples whose true label is label_to_score.

    A hit is such a sample predicted as label_to_score and a miss is one
    predicted as anything else; samples with a different true label are
    ignored, so the score equals the per-label recall. true_labels and
    predicted_labels are paired by position (assumed to be the same length).

    If no sample has true label label_to_score (e.g. the class is absent from
    a particular CV split), zero_division is returned instead of raising
    ZeroDivisionError. The default 0.0 mirrors sklearn's zero_division=0
    convention for recall.
    """
    # kyleE used this score in his scoring.pbynb for generating the master list
    # including it here since it was used in previous analysis
    hits = 0
    misses = 0
    for true_label, predicted_label in zip(true_labels, predicted_labels):
        if true_label != label_to_score:
            continue
        if predicted_label == true_label:
            hits += 1
        else:
            misses += 1

    scored = hits + misses
    if scored == 0:
        return zero_division
    return hits / scored


def output_model_performance(outputDicts, comment_lines=None, output_stream=None):
    """
    Write performance result dicts as a tab-separated table with a header row.

    Columns start with a fixed list of identifier and summary-metric fields,
    followed by every other key found in any of the dicts (the per-label
    columns), sorted alphabetically. Keys missing from a dict are written as
    empty cells.

    outputDicts   -- list of flat result dicts; may be empty (header only).
    comment_lines -- accepted for interface compatibility; it is currently
                     not written to the output.
    output_stream -- writable text stream; defaults to sys.stdout, resolved
                     at call time.
    """
    if output_stream is None:
        output_stream = sys.stdout

    fieldNames = ["repeat_fold", "test",
                  "model", "overall_acc", "overall_bacc", "overall_weighted_f1", "overall_precision", "overall_recall", "arith_mean_of_hit_miss", "arith_mean_of_subtype_acc", "arith_mean_of_subtype_bacc", "arith_mean_of_subtype_f1", "arith_mean_of_subtype_precision", "arith_mean_of_subtype_recall"]

    # Collect extra columns from every dict, not just the first, so no column
    # is silently dropped by extrasaction="ignore"; an empty list adds none.
    fixed_fields = set(fieldNames)
    other_fields = sorted(
        {key for result in outputDicts for key in result} - fixed_fields)
    fieldNames.extend(other_fields)

    logging.debug("fieldNames:%s", fieldNames)

    writer = csv.DictWriter(output_stream, fieldnames=fieldNames, delimiter="\t",
                            lineterminator="\n", extrasaction="ignore")
    writer.writeheader()
    writer.writerows(outputDicts)
    return None

#:####################################


def main():
    """
    Command-line entry point.

    Reads the predictions file named by the single positional argument,
    computes performance metrics and writes them as a TSV table to stdout.
    Exits with status 1 when the argument count is wrong or when no results
    could be computed from the input.
    """
    startTime = getNow()
    (options, args, _parser) = getOptions()

    logLevel = logging.DEBUG if options.verbose else logging.INFO
    logFormat = "%(asctime)s %(levelname)s %(funcName)s:%(lineno)d %(message)s"
    logging.basicConfig(level=logLevel, format=logFormat)

    logging.debug('options:\t%s', options)
    logging.debug('args:\t%s', args)

    if len(args) != 1:
        logging.error("Expected exactly one argument, the input file")
        sys.exit(1)

    logging.debug("read data from file: %s", args[0])
    # get_sample_predictions_dict_reader builds its reader from an in-memory
    # list of lines, so the file can be closed (even on error) right away.
    with open(args[0], "r") as input_file:
        (comment_lines, data_dict_reader) = get_sample_predictions_dict_reader(
            input_file)

    # wrangle the data
    (label_vector_dict, label_counts_dict) = wrangle_dict_reader(data_dict_reader)

    # compute metrics
    model_performance_dict_list = compute_performance_metrics(
        label_vector_dict, label_counts_dict)

    if not model_performance_dict_list:
        # No data rows or no model columns: nothing to report.
        logging.error("no performance results computed from input file: %s",
                      args[0])
        sys.exit(1)

    logging.debug("model_performance_dict_list[0]:%s",
                  model_performance_dict_list[0])

    # output to stdout
    output_model_performance(model_performance_dict_list,
                             comment_lines=comment_lines)

    runTime = getTimeDelta(startTime).total_seconds()
    logging.debug("%s ran for %s s.", os.path.basename(__file__), runTime)
    logging.shutdown()
    return None


# main program section
# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
