#!/usr/bin/env python

# SPDX-License-Identifier: BSD-3-Clause
# SPDX-FileCopyrightText: Czech Technical University in Prague

"""Extract pointclouds to PCD files.

<pre>
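Usage: extract_pcd [-h] [-n NAME_TEMPLATE] [--zip-name-template ZIP_NAME_TEMPLATE] [-z]
                   [-c ZIP_COMPRESSION] [-v] [-q]
                   bag_file output_dir [pcl_topics [pcl_topics ...]]
</pre>

The parameters are:

* `bag_file`: The bag to read.
* `output_dir`: Directory where all generated files should be stored.
* `pcl_topics`: Zero or more topics to convert. If zero, all PointCloud2 topics are converted.
* `-h`: Shows help.
* `-n`, `--name-template`: Python format string used for naming the output files.
* `--zip-name-template`: Python format string used for naming the output ZIP files.
* `-z`, `--zip`: Store the output PCD files directly into a ZIP.
* `-c`, `--zip-compression`: ZIP compression level (0-9).
* `-v`, `--verbose`: Print various details during execution.
* `-q`, `--no-progress`: Disable progress bars.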

Example command:

<pre>
# Convert all pointcloud topics in the bag to a series of PCD files in folder pcd/
rosrun cras_bag_tools extract_pcd spot_2022-10-27-10-35-46.pcl.bag pcd

# Convert topic /points to PCD files:
rosrun cras_bag_tools extract_pcd spot_2022-10-27-10-35-46.pcl.bag . /points
</pre>
"""

from __future__ import division
from __future__ import print_function

import argparse
import os
import re
import sys
import zipfile
from collections import defaultdict

import rosbag
import rospy
from cras_bag_tools import convert_PointCloud2_to_pcd, TqdmBag
from sensor_msgs.msg import PointCloud2


class Parser(argparse.ArgumentParser):
    """Command-line argument parser of the extract_pcd script.

    Positional arguments are the input bag file, the output directory and an optional list of pointcloud
    topics. Optional arguments control naming of the output files, ZIP output and verbosity.
    """

    def __init__(self):
        """Create the parser and register all supported arguments."""
        # Adjacent literals are joined verbatim, so each one has to carry its own separating space.
        description = "Extract pointclouds from a ROS bag. If you pass no topics, all pointcloud " \
                      "topics will be exported."

        super(Parser, self).__init__(description=description)

        # Shared tail of the help texts of both naming templates.
        template_vars = "Available variables are bag, topic, secs, nsecs, msg_num, t, format."

        self.add_argument("bag_file", help="Input bagfile.")
        self.add_argument("output_dir", help="Output directory.")
        self.add_argument("-n", "--name-template", default="{topic}-{secs}.{nsecs:09d}.{format}",
                          help="Template for naming the output files. This is a Python format string. " +
                          template_vars)
        self.add_argument("--zip-name-template", default="{topic}.{format}",
                          help="Template for naming the output ZIP files. This is a Python format string. " +
                          template_vars)
        self.add_argument("-z", "--zip", action="store_true", help="Store the output PCD files directly into a ZIP")
        # type=int makes the parsed value an int regardless of whether it comes from the default or from the
        # command line, and lets argparse report non-numeric values as a usage error.
        self.add_argument("-c", "--zip-compression", default=6, type=int, help="ZIP compression level [0-9]")
        self.add_argument("-v", "--verbose", action="store_true", help="Enable debug prints.")
        self.add_argument("-q", "--no-progress", action="store_false", dest="progress", default=True,
                          help="Disable progress bars.")
        self.add_argument('pcl_topics', nargs="*", help="The pointcloud topics to extract.")


def get_filename(bag, topic, name_pattern="{topic}-{secs}.{nsecs:09d}.{format}", msg=None, t=None, msg_num=None,
                 format="pcd"):
    """Build an output file name from a naming template.

    The template is a Python format string that may reference `bag`, `topic`, `secs`, `nsecs`, `msg_num`, `t` and
    `format`. If the template does not end with `.{format}`, that suffix is appended.

    When `msg` is None, the result is itself a template: the per-message placeholders (`secs`, `nsecs`, `msg_num`
    and `t`) are kept verbatim and only the message-independent ones are substituted.

    :param bag: Value substituted for `{bag}`.
    :param str topic: Topic name. Slashes are replaced by underscores and leading underscores are stripped.
    :param str name_pattern: The naming template.
    :param msg: The message to name. Its `header.stamp.secs` and `header.stamp.nsecs` fill `{secs}` and `{nsecs}`.
    :param t: Value substituted for `{t}`.
    :param msg_num: Value substituted for `{msg_num}`.
    :param str format: Value substituted for `{format}`.
    :return: The file name (or file name template if `msg` is None).
    :rtype: str
    """
    if not name_pattern.endswith(".{format}"):
        name_pattern += ".{format}"

    if msg is not None:
        stamp = msg.header.stamp
        secs, nsecs = stamp.secs, stamp.nsecs
    else:
        # If msg is None, we want to get just a template that is not dependent on a particular message data, so we
        # escape the per-message substitutions by doubling the curly braces. `t` is per-message data too; the
        # lookahead keeps `{t}`/`{t.secs}` from also matching `{topic}`.
        name_pattern = re.sub(
            r'\{(?:secs|nsecs|msg_num|t(?![A-Za-z0-9_]))[^}]*\}', r'{\g<0>}', name_pattern)
        # The escaped placeholders are never rendered, so the values passed for them do not matter.
        secs = nsecs = 0

    return name_pattern.format(
        bag=bag, topic=topic.replace("/", "_").lstrip("_"), format=format,
        secs=secs, nsecs=nsecs, msg_num=msg_num, t=t)


# Cached topic info of the bag file; filled on the first call of get_topic_info().
__topic_info = None


def get_topic_info(bag):
    """Return the topic info of the bag, reading it only on the first call.

    The result of the first call is cached in a module-level variable and returned by all subsequent calls
    regardless of the `bag` argument passed to them.

    :param rosbag.Bag bag: The bag file to get topic info from.
    :return: The `topics` attribute of `bag.get_type_and_topic_info()`.
    """
    global __topic_info
    if __topic_info is not None:
        return __topic_info

    # Reading the info might take time, so it is done just once.
    __topic_info = bag.get_type_and_topic_info().topics
    return __topic_info


def main():
    """Run the extract_pcd command-line tool.

    Parses the command line, opens the bag and converts every message read from the selected topics to PCD
    data. Each converted message is written either to its own file in the output directory, or, when `--zip`
    is given, appended to a ZIP archive in the output directory. I/O errors of individual messages are
    reported on stderr and do not stop the extraction.
    """
    # For throttled loggers to work.
    rospy.rostime.set_rostime_initialized(True)

    args = Parser().parse_args()

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    pcl_topics = list(args.pcl_topics)
    name_template = str(args.name_template)
    # Bag file name without directory and extension; available as {bag} in the naming templates.
    basename = os.path.splitext(os.path.basename(args.bag_file))[0]

    bag_class = TqdmBag if args.progress else rosbag.Bag
    with bag_class(args.bag_file, "r") as bag:
        # If no topics were specified, select all pcl topics
        if not pcl_topics:
            topic_info = get_topic_info(bag)
            pcl_topics = [t for t in topic_info if topic_info[t].msg_type == PointCloud2._type]

        if not pcl_topics:
            # Nothing would pass the topic check in the loop below, so do not scan the bag at all.
            print("No pointcloud topics found in %s." % (args.bag_file,), file=sys.stderr)
            return

        if args.verbose:
            num_pcls = bag.get_message_count(pcl_topics)
            print("Extracting %i messages from %s into %s on the following topics:" % (
                num_pcls, args.bag_file, args.output_dir))
            for t in pcl_topics:
                print("- " + t)

        # Read and convert the topics

        wanted_topics = set(pcl_topics)  # constant-time membership test inside the message loop
        msg_nums = defaultdict(lambda: -1)  # per-topic zero-based message counter
        reported_outputs = set()  # topics whose output template has already been printed
        for topic, msg, t in bag.read_messages(topics=pcl_topics):
            if topic not in wanted_topics:  # Should not happen, but to be sure
                continue
            msg_nums[topic] += 1

            try:
                out_filename = get_filename(basename, topic, name_template, msg, t, msg_nums[topic])
                out_file = os.path.join(args.output_dir, out_filename)
                if args.verbose and topic not in reported_outputs:
                    reported_outputs.add(topic)
                    out_template = os.path.join(args.output_dir, get_filename(basename, topic, name_template))
                    print("Exporting topic %s as pointclouds %s." % (topic, out_template))

                pcd_data = convert_PointCloud2_to_pcd(msg)

                if args.zip:
                    zip_filename = get_filename(basename, topic, args.zip_name_template, msg, t, msg_nums[topic],
                                                format="zip")
                    zip_file = os.path.join(args.output_dir, zip_filename)
                    # Append mode adds the PCD to the archive, creating the archive if it does not exist yet.
                    with zipfile.ZipFile(zip_file, "a", zipfile.ZIP_DEFLATED,
                                         compresslevel=int(args.zip_compression)) as zip_f:
                        zip_f.writestr(out_filename, pcd_data)
                else:
                    with open(out_file, 'wb') as f:
                        f.write(pcd_data)

            except IOError as e:
                # `file=` is required; a positional sys.stderr would just be printed to stdout as text.
                print("Error writing pointcloud from topic %s: %s" % (topic, str(e)), file=sys.stderr)


# Run the extraction only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
