#!/usr/bin/env python3
# pylint: disable=C0103,C0114,C0116,C0301,R0914,R0912,R0915,W0511,eval-used
######################################################################

import argparse
import collections
import math
import re
import statistics
# from pprint import pprint

# Execution records: Threads[<thread>][<start tick>] = {per-mtask-run info}
Threads = collections.defaultdict(lambda: collections.defaultdict(dict))
# Per-mtask accumulated info: Mtasks[<mtask id>] = {info}
Mtasks = collections.defaultdict(dict)
# Eval and eval-loop intervals, both keyed by their begin tick
Evals = collections.defaultdict(dict)
EvalLoops = collections.defaultdict(dict)
# Everything else read from the profile, plus values derived while reporting
Global = dict(
    args={},
    cpuinfo=collections.defaultdict(dict),
    rdtsc_cycle_time=0,
    stats={},
)

######################################################################


def process(filename):
    """Read the profile data in filename, then print the text report."""
    read_data(filename)
    report()


def read_data(filename):
    """Parse a profile data file into the module-level tables.

    Fills Threads, Mtasks, Evals and EvalLoops from the VLPROFEXEC records,
    and Global['args'], Global['stats'], Global['cpuinfo'] and
    Global['rdtsc_cycle_time'] from the other recognized lines.  Lines that
    are not recognized are skipped, or echoed when Args.debug is set.

    Args:
        filename: path of the profile data file to read.

    Raises:
        AttributeError: if an MTASK_BEGIN/MTASK_END payload does not have
            the expected fields.
        KeyError: if an MTASK_END arrives for an mtask that never began.
    """
    re_thread = re.compile(r'^VLPROFTHREAD (\d+)$')
    re_record = re.compile(r'^VLPROFEXEC (\S+) (\d+)(.*)$')
    re_payload_mtaskBegin = re.compile(
        r'id (\d+) predictStart (\d+) cpu (\d+)')
    re_payload_mtaskEnd = re.compile(r'id (\d+) predictCost (\d+)')

    re_arg1 = re.compile(r'VLPROF arg\s+(\S+)\+([0-9.]*)\s*')
    re_arg2 = re.compile(r'VLPROF arg\s+(\S+)\s+([0-9.]*)\s*$')
    re_stat = re.compile(r'VLPROF stat\s+(\S+)\s+([0-9.]+)')
    re_time = re.compile(r'rdtsc time = (\d+) ticks')
    re_proc_cpu = re.compile(r'VLPROFPROC processor\s*:\s*(\d+)\s*$')
    re_proc_dat = re.compile(r'VLPROFPROC ([a-z_ ]+)\s*:\s*(.*)$')

    cpu = None  # Processor number the following VLPROFPROC lines describe
    thread = None  # Thread the following VLPROFEXEC records belong to

    lastEvalBeginTick = None
    lastEvalLoopBeginTick = None

    with open(filename) as fh:
        for line in fh:
            recordMatch = re_record.match(line)
            if recordMatch:
                kind, tick, payload = recordMatch.groups()
                tick = int(tick)
                payload = payload.strip()
                if kind == "EVAL_BEGIN":
                    Evals[tick]['start'] = tick
                    lastEvalBeginTick = tick
                elif kind == "EVAL_END":
                    # An END without a BEGIN is dropped rather than stored
                    # under a bogus None key
                    if lastEvalBeginTick is not None:
                        Evals[lastEvalBeginTick]['end'] = tick
                    lastEvalBeginTick = None
                elif kind == "EVAL_LOOP_BEGIN":
                    EvalLoops[tick]['start'] = tick
                    lastEvalLoopBeginTick = tick
                elif kind == "EVAL_LOOP_END":
                    if lastEvalLoopBeginTick is not None:
                        EvalLoops[lastEvalLoopBeginTick]['end'] = tick
                    lastEvalLoopBeginTick = None
                elif kind == "MTASK_BEGIN":
                    mtask, predict_start, ecpu = re_payload_mtaskBegin.match(
                        payload).groups()
                    mtask = int(mtask)
                    predict_start = int(predict_start)
                    ecpu = int(ecpu)
                    Threads[thread][tick]['mtask'] = mtask
                    Threads[thread][tick]['predict_start'] = predict_start
                    Threads[thread][tick]['cpu'] = ecpu
                    # First sighting of this mtask: start its accumulators
                    if 'elapsed' not in Mtasks[mtask]:
                        Mtasks[mtask] = {'end': 0, 'elapsed': 0}
                    Mtasks[mtask]['begin'] = tick
                    Mtasks[mtask]['thread'] = thread
                    Mtasks[mtask]['predict_start'] = predict_start
                elif kind == "MTASK_END":
                    mtask, predict_cost = re_payload_mtaskEnd.match(
                        payload).groups()
                    mtask = int(mtask)
                    predict_cost = int(predict_cost)
                    begin = Mtasks[mtask]['begin']
                    Threads[thread][begin]['end'] = tick
                    Threads[thread][begin]['predict_cost'] = predict_cost
                    Mtasks[mtask]['elapsed'] += tick - begin
                    Mtasks[mtask]['predict_cost'] = predict_cost
                    Mtasks[mtask]['end'] = max(Mtasks[mtask]['end'], tick)
                elif Args.debug:
                    print("-Unknown execution trace record: %s" % line)
            elif re_thread.match(line):
                thread = int(re_thread.match(line).group(1))
            elif re.match(r'^VLPROF(THREAD|VERSION)', line):
                pass
            elif re_arg1.match(line):
                match = re_arg1.match(line)
                Global['args'][match.group(1)] = match.group(2)
            elif re_arg2.match(line):
                match = re_arg2.match(line)
                Global['args'][match.group(1)] = match.group(2)
            elif re_stat.match(line):
                match = re_stat.match(line)
                Global['stats'][match.group(1)] = match.group(2)
            elif re_proc_cpu.match(line):
                match = re_proc_cpu.match(line)
                cpu = int(match.group(1))
            # Compare against None: processor 0 is a valid (falsy) number
            elif cpu is not None and re_proc_dat.match(line):
                match = re_proc_dat.match(line)
                term = match.group(1)
                value = match.group(2)
                term = re.sub(r'\s+$', '', term)
                term = re.sub(r'\s+', '_', term)
                value = re.sub(r'\s+$', '', value)
                Global['cpuinfo'][cpu][term] = value
            elif re.match(r'^#', line):
                pass
            elif Args.debug:
                print("-Unk: %s" % line)
            # TODO -- this is parsing text printed by a client.
            # Really, verilator proper should generate this
            # if it's useful...
            timeMatch = re_time.match(line)
            if timeMatch:
                # Kept as an integer so it can be compared and divided later
                Global['rdtsc_cycle_time'] = int(timeMatch.group(1))


def re_match_result(regexp, line, result_to):
    """Match regexp at the start of line; return the match object or None.

    The value passed as result_to is never read and the caller's variable is
    not modified; the parameter only keeps the call signature intact.
    """
    del result_to  # Intentionally unused
    return re.match(regexp, line)


######################################################################


def report():
    """Print the text analysis of the profile data held in the module tables.

    Prints the argument settings, the thread/mtask/cpu/eval totals, measured
    and predicted efficiency, predicted-versus-elapsed mtask statistics, the
    per-cpu table (through report_cpus()) and thread placement warnings.

    Side effects: sets Global['cpus'], Global['measured_last_end'] and
    Global['predict_last_end'], and raises a zero
    Mtasks[<mtask>]['predict_cost'] to 1 for mtasks with elapsed time.

    Raises:
        KeyError: if Global['stats'] has no 'threads' entry.
    """
    print("Verilator Gantt report")

    print("\nArgument settings:")
    for arg in sorted(Global['args'].keys()):
        plus = "+" if re.match(r'^\+', arg) else " "
        print("  %s%s%s" % (arg, plus, Global['args'][arg]))

    nthreads = int(Global['stats']['threads'])

    # Sum the measured busy time of every cpu that ran at least one mtask
    Global['cpus'] = {}
    for records in Threads.values():
        for start, record in records.items():
            if not record:
                continue
            cpu = record['cpu']
            elapsed = record['end'] - start
            if cpu not in Global['cpus']:
                Global['cpus'][cpu] = {'cpu_time': 0}
            Global['cpus'][cpu]['cpu_time'] += elapsed

    measured_mt_mtask_time = 0
    predict_mt_mtask_time = 0
    long_mtask_time = 0
    measured_last_end = 0
    predict_last_end = 0
    for info in Mtasks.values():
        measured_mt_mtask_time += info['elapsed']
        predict_mt_mtask_time += info['predict_cost']
        measured_last_end = max(measured_last_end, info['end'])
        predict_last_end = max(predict_last_end,
                               info['predict_start'] + info['predict_cost'])
        long_mtask_time = max(long_mtask_time, info['elapsed'])
    Global['measured_last_end'] = measured_last_end
    Global['predict_last_end'] = predict_last_end

    # If we know cycle time in the same (rdtsc) units,
    # this will give us an actual utilization number,
    # (how effectively we keep the cores busy.)
    #
    # It also gives us a number we can compare against
    # serial mode, to estimate the overhead of data sharing,
    # which will show up in the total elapsed time. (Overhead
    # of synchronization and scheduling should not.)
    print("\nAnalysis:")
    print("  Total threads             = %d" % nthreads)
    print("  Total mtasks              = %d" % len(Mtasks))
    ncpus = max(len(Global['cpus']), 1)
    print("  Total cpus used           = %d" % ncpus)
    print("  Total yields              = %d" %
          int(Global['stats'].get('yields', 0)))
    print("  Total evals               = %d" % len(Evals))
    print("  Total eval loops          = %d" % len(EvalLoops))
    if Mtasks:
        print("  Total eval time           = %d rdtsc ticks" %
              measured_last_end)
        print("  Longest mtask time        = %d rdtsc ticks" % long_mtask_time)
        print("  All-thread mtask time     = %d rdtsc ticks" %
              measured_mt_mtask_time)
        # "or 1" keeps the divisions defined when nothing was measured
        long_efficiency = long_mtask_time / (measured_last_end or 1)
        print("  Longest-thread efficiency = %0.1f%%" %
              (long_efficiency * 100.0))
        mt_efficiency = measured_mt_mtask_time / (
            measured_last_end * nthreads or 1)
        print("  All-thread efficiency     = %0.1f%%" %
              (mt_efficiency * 100.0))
        print("  All-thread speedup        = %0.1f" %
              (mt_efficiency * nthreads))
        # Accept the cycle time either as a number or as its text form
        rdtsc_cycle_time = int(Global['rdtsc_cycle_time'])
        if rdtsc_cycle_time > 0:
            ut = measured_mt_mtask_time / rdtsc_cycle_time
            # Formatted, not concatenated: the operands are numbers
            print("tot_mtask_cpu=%d cyc=%d ut=%s" %
                  (measured_mt_mtask_time, rdtsc_cycle_time, ut))

        predict_mt_efficiency = predict_mt_mtask_time / (
            predict_last_end * nthreads or 1)
        print("\nPrediction (what Verilator used for scheduling):")
        print("  All-thread efficiency     = %0.1f%%" %
              (predict_mt_efficiency * 100.0))
        print("  All-thread speedup        = %0.1f" %
              (predict_mt_efficiency * nthreads))

        # log(predicted / elapsed) for every mtask that took measurable time
        p2e_ratios = []
        min_p2e = 1000000
        min_mtask = None
        max_p2e = -1000000
        max_mtask = None

        for mtask in sorted(Mtasks.keys()):
            if Mtasks[mtask]['elapsed'] > 0:
                if Mtasks[mtask]['predict_cost'] == 0:
                    Mtasks[mtask]['predict_cost'] = 1  # don't log(0) below
                p2e_ratio = math.log(Mtasks[mtask]['predict_cost'] /
                                     Mtasks[mtask]['elapsed'])
                p2e_ratios.append(p2e_ratio)

                if p2e_ratio > max_p2e:
                    max_p2e = p2e_ratio
                    max_mtask = mtask
                if p2e_ratio < min_p2e:
                    min_p2e = p2e_ratio
                    min_mtask = mtask

        # With no ratios there is no min/max mtask to name and
        # statistics.mean()/pstdev() would raise on the empty list
        if p2e_ratios:
            print("\nMTask statistics:")
            print("  min log(p2e) = %0.3f" % min_p2e, end="")
            print("  from mtask %d (predict %d," %
                  (min_mtask, Mtasks[min_mtask]['predict_cost']),
                  end="")
            print(" elapsed %d)" % Mtasks[min_mtask]['elapsed'])
            print("  max log(p2e) = %0.3f" % max_p2e, end="")
            print("  from mtask %d (predict %d," %
                  (max_mtask, Mtasks[max_mtask]['predict_cost']),
                  end="")
            print(" elapsed %d)" % Mtasks[max_mtask]['elapsed'])

            stddev = statistics.pstdev(p2e_ratios)
            mean = statistics.mean(p2e_ratios)
            print("  mean = %0.3f" % mean)
            print("  stddev = %0.3f" % stddev)
            print("  e ^ stddev = %0.3f" % math.exp(stddev))

    report_cpus()

    if nthreads > ncpus:
        print()
        print("%%Warning: There were fewer CPUs (%d) then threads (%d)." %
              (ncpus, nthreads))
        print("        : See docs on use of numactl.")
    else:
        if 'cpu_socket_cores_warning' in Global:
            print()
            print(
                "%Warning: Multiple threads scheduled on same hyperthreaded core."
            )
            print("        : See docs on use of numactl.")
        if 'cpu_sockets_warning' in Global:
            print()
            print("%Warning: Threads scheduled on multiple sockets.")
            print("        : See docs on use of numactl.")
    print()


def report_cpus():
    """Print one line per cpu in Global['cpus'] and record placement warnings.

    Each line shows the cpu number and its accumulated cpu_time, plus the
    socket, core and model name when Global['cpuinfo'] has them for that cpu.

    Side effects: rebuilds Global['cpu_sockets'] (cpus counted per socket)
    and Global['cpu_socket_cores'] (cpus counted per socket/core pair), sets
    Global['cpu_sockets_warning'] when more than one socket was counted, and
    sets Global['cpu_socket_cores_warning'] when any socket/core pair was
    counted more than once.
    """
    print("\nCPUs:")

    Global['cpu_sockets'] = collections.defaultdict(lambda: 0)
    Global['cpu_socket_cores'] = collections.defaultdict(lambda: 0)

    for cpu in sorted(Global['cpus'].keys()):
        fields = ["  cpu %d: " % cpu]
        fields.append("cpu_time=%d" % Global['cpus'][cpu]['cpu_time'])

        if cpu in Global['cpuinfo']:
            cpuinfo = Global['cpuinfo'][cpu]
            if 'physical_id' in cpuinfo and 'core_id' in cpuinfo:
                socket = int(cpuinfo['physical_id'])
                Global['cpu_sockets'][socket] += 1
                fields.append(" socket=%d" % socket)

                core = int(cpuinfo['core_id'])
                Global['cpu_socket_cores'][str(socket) + "__" + str(core)] += 1
                fields.append(" core=%d" % core)

            if 'model_name' in cpuinfo:
                fields.append("  %s" % cpuinfo['model_name'])
        print("".join(fields))

    if len(Global['cpu_sockets']) > 1:
        Global['cpu_sockets_warning'] = True
    # Checked independently of the socket count: two cpus sharing a core is
    # just as possible on a single socket
    if any(scn > 1 for scn in Global['cpu_socket_cores'].values()):
        Global['cpu_socket_cores_warning'] = True


######################################################################


def write_vcd(filename):
    print("Writing %s" % filename)
    with open(filename, "w") as fh:
        vcd = {
            'values':
            collections.defaultdict(lambda: {}),  # {<time>}{<code>} = value
            'sigs': {
                'predicted_threads': {},
                'measured_threads': {},
                'cpus': {},
                'evals': {},
                'mtasks': {},
                'Stats': {}
            }  # {<module>}{<sig}} = code
        }
        code = 0

        parallelism = {
            'measured': collections.defaultdict(lambda: 0),
            'predicted': collections.defaultdict(lambda: 0)
        }
        parallelism['measured'][0] = 0
        parallelism['predicted'][0] = 0

        # Measured graph
        for thread in sorted(Threads.keys()):
            sig = "thread%d_mtask" % thread
            if sig not in vcd['sigs']['measured_threads']:
                vcd['sigs']['measured_threads'][sig] = code
                code += 1
            mcode = vcd['sigs']['measured_threads'][sig]

            for start in sorted(Threads[thread]):
                mtask = Threads[thread][start]['mtask']
                end = Threads[thread][start]['end']
                cpu = Threads[thread][start]['cpu']
                vcd['values'][start][mcode] = mtask
                vcd['values'][end][mcode] = None
                parallelism['measured'][start] += 1
                parallelism['measured'][end] -= 1

                sig = "cpu%d_thread" % cpu
                if sig not in vcd['sigs']['cpus']:
                    vcd['sigs']['cpus'][sig] = code
                    code += 1
                ccode = vcd['sigs']['cpus'][sig]
                vcd['values'][start][ccode] = thread
                vcd['values'][end][ccode] = None

                sig = "mtask%d_cpu" % mtask
                if sig not in vcd['sigs']['mtasks']:
                    vcd['sigs']['mtasks'][sig] = code
                    code += 1
                ccode = vcd['sigs']['mtasks'][sig]
                vcd['values'][start][ccode] = cpu
                vcd['values'][end][ccode] = None

        # Eval graph
        vcd['sigs']['evals']["eval"] = code
        elcode = code
        code += 1
        n = 0
        for eval_start in Evals:
            eval_end = Evals[eval_start]['end']
            n += 1
            vcd['values'][eval_start][elcode] = n
            vcd['values'][eval_end][elcode] = None

        # Eval_loop graph
        vcd['sigs']['evals']["eval_loop"] = code
        elcode = code
        code += 1
        n = 0
        for eval_start in EvalLoops:
            eval_end = EvalLoops[eval_start]['end']
            n += 1
            vcd['values'][eval_start][elcode] = n
            vcd['values'][eval_end][elcode] = None

        if Mtasks:
            # Predicted graph
            for eval_start in EvalLoops:
                eval_end = EvalLoops[eval_start]['end']
                # Compute scale so predicted graph is of same width as eval
                measured_scaling = (eval_end -
                                    eval_start) / Global['predict_last_end']
                # Predict mtasks that fill the time the eval occupied
                for mtask in Mtasks:
                    thread = Mtasks[mtask]['thread']
                    pred_scaled_start = eval_start + int(
                        Mtasks[mtask]['predict_start'] * measured_scaling)
                    pred_scaled_end = eval_start + int(
                        (Mtasks[mtask]['predict_start'] +
                         Mtasks[mtask]['predict_cost']) * measured_scaling)
                    if pred_scaled_start == pred_scaled_end:
                        continue

                    sig = "predicted_thread%d_mtask" % thread
                    if sig not in vcd['sigs']['predicted_threads']:
                        vcd['sigs']['predicted_threads'][sig] = code
                        code += 1
                    mcode = vcd['sigs']['predicted_threads'][sig]

                    vcd['values'][pred_scaled_start][mcode] = mtask
                    vcd['values'][pred_scaled_end][mcode] = None

                    parallelism['predicted'][pred_scaled_start] += 1
                    parallelism['predicted'][pred_scaled_end] -= 1

            # Parallelism graph
            for measpred in ('measured', 'predicted'):
                vcd['sigs']['Stats']["%s_parallelism" % measpred] = code
                pcode = code
                code += 1
                value = 0
                for time in sorted(parallelism[measpred].keys()):
                    value += parallelism[measpred][time]
                    vcd['values'][time][pcode] = value

        # Create output file
        fh.write("$version Generated by verilator_gantt $end\n")
        fh.write("$timescale 1ns $end\n")
        fh.write("\n")

        all_codes = {}
        fh.write(" $scope module gantt $end\n")
        for module in sorted(vcd['sigs'].keys()):
            fh.write("  $scope module %s $end\n" % module)
            for sig in sorted(vcd['sigs'][module].keys()):
                code = vcd['sigs'][module][sig]
                fh.write("   $var wire 32 v%x %s [31:0] $end\n" % (code, sig))
                all_codes[code] = 1
            fh.write("  $upscope $end\n")
        fh.write(" $upscope $end\n")
        fh.write("$enddefinitions $end\n")
        fh.write("\n")

        first = True
        for time in sorted(vcd['values']):
            if first:
                first = False
                # Start with Z for any signals without time zero data
                for code in sorted(all_codes.keys()):
                    if code not in vcd['values'][time]:
                        vcd['values'][time][code] = None
            fh.write("#%d\n" % time)
            for code in sorted(vcd['values'][time].keys()):
                value = vcd['values'][time][code]
                if value is None:
                    fh.write("bz v%x\n" % code)
                else:
                    fh.write("b%s v%x\n" % (format(value, 'b'), code))


######################################################################

# Command line definition.  The epilog is shown verbatim by --help
# (RawDescriptionHelpFormatter keeps its line breaks).
parser = argparse.ArgumentParser(
    allow_abbrev=False,
    formatter_class=argparse.RawDescriptionHelpFormatter,
    description="""Create Gantt chart of multi-threaded execution""",
    epilog=
    """Verilator_gantt creates a visual representation to help analyze Verilator
multithreaded simulation performance, by showing when each macro-task
starts and ends, and showing when each thread is busy or idle.

For documentation see
https://verilator.org/guide/latest/exe_verilator_gantt.html

Copyright 2018-2022 by Wilson Snyder. This program is free software; you
can redistribute it and/or modify it under the terms of either the GNU
Lesser General Public License Version 3 or the Perl Artistic License
Version 2.0.

SPDX-License-Identifier: LGPL-3.0-only OR Artistic-2.0""")

parser.add_argument('--debug', action='store_true', help='enable debug')
parser.add_argument('--no-vcd',
                    help='disable creating vcd',
                    action='store_true')
parser.add_argument('--vcd',
                    help='filename for vcd output',
                    default='profile_exec.vcd')
# nargs='?' makes the positional optional; without it argparse always
# requires a value and the default can never apply
parser.add_argument('filename',
                    nargs='?',
                    help='input profile_exec.dat filename to process',
                    default='profile_exec.dat')

# Parsed options stay in the module-level name Args: read_data() consults
# Args.debug when it meets a line it does not recognize.
Args = parser.parse_args()

# Read and report first: write_vcd() uses Global['predict_last_end'], which
# report() computes.
process(Args.filename)
if not Args.no_vcd:
    write_vcd(Args.vcd)

######################################################################
# Local Variables:
# compile-command: "./verilator_gantt ../test_regress/t/t_gantt_io.dat"
# End: