#!/usr/bin/env python

# SPDX-License-Identifier: BSD-3-Clause
# SPDX-FileCopyrightText: Czech Technical University in Prague

"""Print total cumulative serialized message size per topic."""

from __future__ import print_function

import argparse
import sys

from cras_bag_tools.tqdm_bag import TqdmBag


def _sizeof_fmt(num, suffix='B'):
    """Return a human-readable string representing a byte size in the most suitable unit.

    The value is repeatedly divided by 1024 until it is smaller than 1024 or the largest known prefix (Yi) is
    reached. Values that fit a prefix are right-aligned to 6 characters with one decimal place; values too large
    even for the largest prefix are printed unpadded in multiples of that prefix.

    :param int|float num: The size to express as a string.
    :param str suffix: The actual unit without any kilo- etc. prefixes.
    :return: The human-readable string, e.g. "   1.5 KiB".
    :rtype: str
    """
    units = ('', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi', 'Yi')
    # Fallback used when the value does not fit any of the smaller units.
    unit = units[-1]
    # Scale down only while there is a larger unit to move to. Iterating over the last unit as well would divide
    # the value once more than the printed prefix accounts for (1024 YiB would be reported as 1.0 YiB).
    for u in units[:-1]:
        if abs(num) < 1024.0:
            unit = u
            break
        num /= 1024.0
    # Values beyond the largest unit may need more than 6 characters, so they are not padded.
    num_format = "%6.1f" if abs(num) < 1024.0 else "%.1f"
    num_str = num_format % num
    return "%s %s%s" % (num_str, unit, suffix)


parser = argparse.ArgumentParser(description='Print total cumulative serialized message size per topic.')
parser.add_argument('bag', help='Bag file.')
parser.add_argument('-c', '--csv', action="store_true", default=False,
                    help='Output as CSV.')
parser.add_argument('-a', '--sort-alphabetical', action="store_true", default=False,
                    help='Sort by topic names (default is by topic sizes).')

args = parser.parse_args()

# We have to read all messages from the bag, there is probably no way around it
topic_size_dict = {}  # topic name -> accumulated size of the raw messages
num_msgs_dict = {}  # topic name -> number of messages
# Open the bag given by the parsed positional argument. Using sys.argv[1] would pick up an option flag instead of
# the bag path whenever options are passed first (e.g. `-c file.bag`).
# NOTE(review): the `with` statement assumes TqdmBag keeps the context-manager support of rosbag.Bag - confirm.
with TqdmBag(args.bag, 'r') as bag:
    # With raw=True the messages are not deserialized; msg[1] is taken to be the serialized message data
    # (see the rosbag API), so its length is the message size.
    for topic, msg, _ in bag.read_messages(raw=True):
        topic_size_dict[topic] = topic_size_dict.get(topic, 0) + len(msg[1])
        num_msgs_dict[topic] = num_msgs_dict.get(topic, 0) + 1

# Convert the dictionary to a list of (topic, size) pairs so that we can define ordering
topic_size = list(topic_size_dict.items())
sort_key_fn = (lambda x: x[0]) if args.sort_alphabetical else (lambda x: x[1])
topic_size.sort(key=sort_key_fn)

if args.csv:
    print("topic,size,messages,avg_size")
else:
    # Width of the topic column; a bag without messages has no topics, which would make max() raise ValueError.
    longest_topic = max(len(t) for t in topic_size_dict) if topic_size_dict else 0

for topic, size in topic_size:
    num_msgs = num_msgs_dict[topic]
    # Every topic listed here has at least one message, so the division is safe.
    avg_size = float(size) / num_msgs
    if args.csv:
        print("%s,%i,%i,%i" % (topic, size, num_msgs, int(avg_size)))
    else:
        # Pad the topic name with dots so that the size columns line up.
        t = topic + " " + ("." * (longest_topic - len(topic)))
        print("%s %s, %7i messages, avg. message size %s" % (
            t, _sizeof_fmt(size).ljust(10), num_msgs, _sizeof_fmt(avg_size)))
