# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: MIT

import os
from pathlib import Path

from cli.args import ValidMultiFileSpec, get_sample_range_args
from cli.output_handler import OutputHandler
from cli.timer import timer
from cli.tps import TPSViewGenerator
from cli.writers.csv import CSVWriterFactory
from cli.writers.excel.excel_writer_rust import RustExcelWriter
from mpp import DataParserFactory, ViewAggregationLevel, SymbolTableFactory, MetricDefinitionParserFactory, \
    MetricComputer
from mpp.core.api_args import ApiArgs, SystemInformation
from mpp.core.configuration_path_generator import ConfigurationPathGenerator
from mpp.core.types import EventInfoDataFrameColumns as eidc, ConfigurationPaths


class MppCli:
    """Drive the mpp post-processing pipeline from parsed command-line arguments.

    ``__init__`` only stores the arguments. :meth:`initialize` must be called
    before any other method. It creates the data parser and the output handler,
    resolves the configuration (metric / chart format) file paths, and
    partitions the input data.
    """

    # Unless parallel processing is forced, the input file must be larger than
    # this many (decimal) megabytes before parallel processing is considered.
    _PARALLEL_MIN_INPUT_SIZE_MB = 6
    _BYTES_PER_MB = 1_000_000

    def __init__(self, cli_args):
        """Store *cli_args*; no parsing or file access happens until :meth:`initialize`."""
        self.__cli_args = cli_args
        self.__parser = None
        self.__system_info = None
        self.__output_handler = None
        self.__configuration_file_paths = None
        self.__tps_views = []
        self.__partitions = []
        self.__first_sample = 0
        self.__last_sample = 0

    def initialize(self):
        """Create the parser and output handler, resolve configuration files and partition the input.

        Side effect: ``cli_args.metric_file_path`` (and ``cli_args.chart_format_file_path``
        when it is set) are replaced by per-core-type mappings of resolved paths.
        """
        self.__parser = DataParserFactory.create(self.__cli_args.input_data_file_path, self.__cli_args.frequency)
        self.__system_info = self.__parser.system_info
        self.__output_handler = OutputHandler(self.__cli_args.output_file_specifier)
        self.__get_configuration_file_paths()
        self.__get_partition_info()

    @property
    def parser(self):
        """The data parser created by :meth:`initialize` (``None`` before that)."""
        return self.__parser

    @property
    def num_unique_events(self):
        """Number of distinct event names in the parser's event info."""
        return len(set(self.__parser.event_info[eidc.NAME]))

    @property
    def first_sample(self):
        """First sample of the first partition (0 before :meth:`initialize`)."""
        return self.__first_sample

    @property
    def last_sample(self):
        """Last sample of the last partition (0 before :meth:`initialize`)."""
        return self.__last_sample

    @property
    def partitions(self):
        """Partitions returned by the parser for the requested sample range."""
        return self.__partitions

    @property
    def is_parallel(self):
        """Whether the input should be processed in parallel (re-evaluated on every access)."""
        return self._conditions_met_for_parallel_processing()

    @property
    def num_partitions(self):
        """Number of input partitions."""
        return len(self.__partitions)

    def get_api_args(self):
        """Check the metric files against the system's core types, then build the ``ApiArgs``.

        Returns:
            ApiArgs: system information, requested aggregation levels, one metric
            computer per core type, output location and output writers.
        """
        self.__validate_core_types()

        system_information = SystemInformation(
            processor_features=self.__system_info.processor_features,
            system_features=self.__system_info.system_features,
            uncore_units=self.__system_info.uncore_units,
            ref_tsc=self.__system_info.ref_tsc,
            unique_core_types=self.__system_info.unique_core_types,
            has_modules=self.__system_info.has_modules
        )

        return ApiArgs(
            system_information=system_information,
            retire_latency=self.__cli_args.retire_latency,
            aggregation_levels=self.__get_requested_aggregation_levels(),
            metric_computer_map=self.__get_metric_computer_map(),
            is_parallel=self.is_parallel,
            no_detail_views=self.__cli_args.no_detail_views,
            percentile=self.__cli_args.percentile,
            output_directory=self.__output_handler.output_directory,
            event_info=self.__parser.event_info,
            output_prefix=self.__output_handler.output_file_prefix,
            output_writers=self.__get_output_writers()
        )

    def generate_tps_views(self, view_writer, summary_views):
        """Generate and write the Transactions Per Second (TPS) summary views.

        Does nothing unless ``cli_args.transactions_per_second`` is set. The
        generated views' attributes are kept so that :meth:`write_excel_output`
        can append them to the Excel output.

        Args:
            view_writer: writer whose ``write(views, first_sample, last_sample)`` is called.
            summary_views: summary views the TPS views are derived from.
        """
        if not self.__cli_args.transactions_per_second:
            return
        with timer() as number_of_seconds:
            tps_generator = TPSViewGenerator(self.__cli_args.transactions_per_second)
            # Materialize the generated views once; they are both written and remembered.
            tps_summary_views = list(tps_generator.generate_summaries(summary_views).values())
            view_writer.write(tps_summary_views, self.first_sample, self.last_sample)
            self.__tps_views = [view.attributes for view in tps_summary_views]
            print(f'Generated all Transactions Per Second tables in {number_of_seconds}')

    def write_excel_output(self, view_collection):
        """Write *view_collection* (plus any TPS views) to the Excel output file, if one was requested."""
        if not self.__output_handler.excel_file_name:
            return
        include_details = not self.__cli_args.no_detail_views
        include_charts = bool(self.__cli_args.chart_format_file_path)
        self.__append_tps_views(view_collection)

        print('Creating Excel output file...')
        # The Rust writer is used only when it is enabled and no chart format was
        # given; output that includes charts always goes through the Python writer.
        if self.__cli_args.enable_rust and not include_charts:
            self.__write_excel_rust(include_charts, include_details, view_collection)
        else:
            self.__write_excel_python(view_collection)

    def __append_tps_views(self, view_collection):
        """Append the previously generated TPS views to *view_collection* when TPS was requested."""
        if self.__cli_args.transactions_per_second:
            view_collection.append_views(self.__tps_views)

    def __write_excel_python(self, view_collection):
        """Write the Excel file with the pure-Python writer."""
        # Imported lazily, presumably so the module is only loaded when this path is taken.
        from cli.writers.excel import excel_writer_python
        excel_writer_python.write_csv_data_to_excel(self.__cli_args, view_collection, self.__output_handler)
        print(f'Output written to: {self.__cli_args.output_file_specifier}')

    def __write_excel_rust(self, include_charts, include_details, view_collection):
        """Write the Excel file with the Rust-backed writer."""
        excel_file = self.__output_handler.excel_file_with_path
        rxlsx = RustExcelWriter(self.__output_handler.output_directory, excel_file, include_details, include_charts)
        rxlsx.write_csv_to_excel(view_collection)

    def __validate_core_types(self):
        """Validate the metric file specification against the system's unique core types."""
        ValidMultiFileSpec.validate_multi_file_core_types(self.__cli_args.metric_file_path,
                                                          self.__system_info.unique_core_types)

    def __get_requested_aggregation_levels(self):
        """Return the view aggregation levels to generate; SYSTEM is always first."""
        optional_levels = (self.__cli_args.socket_view, self.__cli_args.core_view,
                           self.__cli_args.thread_view, self.__cli_args.uncore_view)
        # System views are always generated; the others only when requested (truthy).
        return [ViewAggregationLevel.SYSTEM] + [level for level in optional_levels if level]

    def __get_metric_computer_map(self):
        """Build one ``MetricComputer`` per core type from the resolved metric file paths."""
        return {core_type: self.__get_metric_computer(core_type, metric_file)
                for core_type, metric_file in self.__cli_args.metric_file_path.items()}

    def __get_configuration_file_paths(self):
        """Resolve metric and chart format file paths per core type, then store them back on cli_args."""
        configuration_file_finder = ConfigurationPathGenerator(self.__system_info.configuration_file_paths,
                                                               [self.__cli_args.metric_file_path,
                                                                self.__cli_args.chart_format_file_path],
                                                               self.__system_info.unique_core_types)
        self.__configuration_file_paths = configuration_file_finder.generate()
        self.__set_chart_metric_file_paths()

    def __set_chart_metric_file_paths(self):
        """Replace the cli_args file paths with per-core-type mappings of the resolved paths.

        Core types whose resolved path is empty/falsy are left out. The chart
        format mapping is only rebuilt when a chart format path was given.
        """
        config_paths = self.__configuration_file_paths
        if self.__cli_args.chart_format_file_path:
            self.__cli_args.chart_format_file_path = {
                core_type: config_paths[core_type][ConfigurationPaths.CHART_PATH]
                for core_type in config_paths
                if config_paths[core_type][ConfigurationPaths.CHART_PATH]
            }
        self.__cli_args.metric_file_path = {
            core_type: config_paths[core_type][ConfigurationPaths.METRIC_PATH]
            for core_type in config_paths
            if config_paths[core_type][ConfigurationPaths.METRIC_PATH]
        }

    def __get_metric_computer(self, core_type, metric_file):
        """Create a ``MetricComputer`` for *core_type* from the definitions in *metric_file*."""
        symbol_table = SymbolTableFactory.create(self.__system_info, core_type, self.__cli_args.retire_latency)
        metric_definitions = MetricDefinitionParserFactory.create(Path(metric_file)).parse()
        return MetricComputer(metric_definitions, symbol_table)

    def __get_output_writers(self):
        """Return the output writers: a CSV writer, polars-backed when parallel and not disabled."""
        use_polars = self.is_parallel and not self.__cli_args.disable_polars
        csv_writer = CSVWriterFactory.create(use_polars, Path(self.__output_handler.output_directory))
        return [csv_writer]

    def __get_partition_info(self):
        """Partition the input for the requested sample range and record the overall sample bounds.

        Raises:
            ValueError: if the parser returns no partitions (e.g. an empty sample range).
        """
        self.__partitions = self.parser.partition(chunk_size=self.__cli_args.chunk_size,
                                                  **get_sample_range_args(self.__cli_args))
        # len() rather than truthiness: num_partitions already relies on len(), and the
        # container type is not visible here.
        if len(self.__partitions) == 0:
            raise ValueError('No data to process: the input produced no partitions '
                             'for the requested sample range')
        self.__first_sample = self.__partitions[0].first_sample
        self.__last_sample = self.__partitions[-1].last_sample

    def _conditions_met_for_parallel_processing(self):
        """Return True when parallel processing is forced, or when the input is large enough,
        spans more than one partition, and more than one core may be used."""
        if self.__cli_args.force_parallel:
            return True
        input_size_mb = os.path.getsize(self.__cli_args.input_data_file_path) / self._BYTES_PER_MB
        return (input_size_mb > self._PARALLEL_MIN_INPUT_SIZE_MB
                and self.num_partitions > 1
                and self.__cli_args.parallel_cores != 1)
