#!/usr/bin/env python

# SPDX-License-Identifier: BSD-3-Clause
# SPDX-FileCopyrightText: Czech Technical University in Prague

"""Merge two or more bag files into one."""

from __future__ import print_function

import argparse
import os

from rosbag import Bag, Compression

from cras_bag_tools import TqdmBag


if __name__ == "__main__":
    # Command-line entry point: copy every message of every input bag, in the order the inputs are given,
    # into a single newly created output bag.
    parser = argparse.ArgumentParser(description='Merge two or more bag files.')
    parser.add_argument('outbag', help='Output bag file.')
    parser.add_argument('inbag', nargs='+', help='Input bag file(s).')
    parser.add_argument('-v', '--verbose', action="store_true", default=False, help='Verbose output')
    parser.add_argument('-c', '--compress', action="store_true", default=False,
                        help='Compress output bag with LZ4 compression')
    parser.add_argument('-b', '--bz2', action="store_true", default=False,
                        help='Compress output bag with BZ2 compression')

    args = parser.parse_args()

    # The output bag is opened for writing before any input is read, so an input that resolves to the same
    # file would be overwritten before its messages could be copied. Refuse to run in that case.
    out_path = os.path.realpath(args.outbag)
    for in_file in args.inbag:
        if os.path.realpath(in_file) == out_path:
            parser.error("Output bag {} is also given as an input bag.".format(args.outbag))

    if args.verbose:
        print("Writing bag file: {}.".format(args.outbag))

    # If both -c and -b are given, LZ4 (-c) takes precedence.
    compression = Compression.NONE
    if args.compress:
        compression = Compression.LZ4
    elif args.bz2:
        compression = Compression.BZ2
    if args.verbose:
        if compression == Compression.NONE:
            print("Not compressing output bag (use -c or -b to add compression).")
        else:
            print("Compressing output bag with {} compression.".format(compression))

    num_inputs = len(args.inbag)
    total_message_count = 0
    with Bag(args.outbag, 'w', compression=compression) as out:
        # Input bags are numbered from 1 in the progress messages.
        for i, in_file in enumerate(args.inbag, 1):
            if args.verbose:
                print("{}/{}: Reading bag file: {}.".format(i, num_inputs, in_file))
            message_count = 0
            with TqdmBag(in_file, 'r') as in_bag:
                # Messages are copied in raw (non-deserialized) form.
                # passing connection header is important to retain topic latching status
                for topic, msg, t, conn_header in in_bag.read_messages(raw=True, return_connection_header=True):
                    out.write(topic, msg, t, raw=True, connection_header=conn_header)
                    message_count += 1
            total_message_count += message_count
            if args.verbose:
                print("{}/{}: Read {} messages.".format(i, num_inputs, message_count))
    if args.verbose:
        print("Total: Read {} messages.".format(total_message_count))
