# SPDX-License-Identifier: BSD-3-Clause
# SPDX-FileCopyrightText: Czech Technical University in Prague

"""Implementations of common message filters."""

from __future__ import absolute_import, division, print_function

import copy
import csv
import math
import os.path
import re
import sys

import matplotlib.cm as cmap
import numpy as np
import skimage.draw
import yaml

from bisect import bisect_left
from collections import deque
from enum import Enum
from functools import reduce
from lxml import etree
from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Type, Union
from numpy.linalg import inv

import genpy
import rospy
from angles import normalize_angle, normalize_angle_positive
from camera_calibration_parsers import readCalibration
from cras.distortion_models import PLUMB_BOB, RATIONAL_POLYNOMIAL, EQUIDISTANT
from cras.geometry_utils import quat_msg_from_rpy
from cras.image_encodings import isColor, isMono, isBayer, isDepth, bitDepth, numChannels, MONO8, RGB8, BGR8,\
    TYPE_16UC1, TYPE_32FC1, YUV422
from cras.log_utils import rosconsole_notifyLoggerLevelsChanged, rosconsole_set_logger_level, RosconsoleLevel
from cras.message_utils import raw_to_msg, msg_to_raw
from cras.string_utils import to_str, STRING_TYPE, to_valid_ros_name
import cv2  # Workaround for https://github.com/opencv/opencv/issues/14884 on Jetsons.
from cv_bridge import CvBridge, CvBridgeError
from dynamic_reconfigure.encoding import decode_config
from dynamic_reconfigure.msg import Config
from geometry_msgs.msg import Quaternion, Transform, TransformStamped, Twist, TwistStamped, Vector3
from gps_common.msg import GPSFix
from image_transport_codecs import decode, encode
from image_transport_codecs.compressed_depth_codec import has_rvl
from image_transport_codecs.parse_compressed_format import guess_any_compressed_image_transport_format
from kdl_parser_py.urdf import treeFromUrdfModel
from nav_msgs.msg import Odometry
from ros_numpy import msgify, numpify
from sensor_msgs.msg import CameraInfo, CompressedImage, Image, Imu, JointState, MagneticField, NavSatFix
from std_msgs.msg import Float64MultiArray, Header, String
from tf2_msgs.msg import TFMessage
from tf2_py import BufferCore, TransformException
from urdf_parser_py import urdf, xml_reflection
from vision_msgs.msg import Detection2DArray, Detection2D, ObjectHypothesisWithPose

from .bag_utils import bag_msg_type_to_standard_type
from .message_filter import ConnectionHeader, DeserializedMessageData, DeserializedMessageFilter, \
    DeserializedMessageFilterWithTF, MessageTags, NoMessageFilter, RawMessageFilter, Tags, TopicSet, \
    cleanup_requeue_tags, deserialize_header, normalize_filter_result, num_used_requeues, \
    tags_for_changed_msg, tags_for_generated_msg, tags_for_requeuing_msg
from .message_filters_base import ImageTransportFilter, MessageToCSVExporterBase, MessageToYAMLExporterBase

# Short alias for the string type used throughout this module (cras.string_utils.STRING_TYPE).
STR = STRING_TYPE
# Tuple type describing one image message handled by image-related filters. The meaning of the individual items is
# defined by the code producing these tuples (not visible here) -- presumably topic, datatype, image message,
# camera info message, transport/format string, receive stamp, connection header and tags; TODO confirm.
FilteredImage = Tuple[STR, STR, genpy.Message, genpy.Message, STR, rospy.Time, ConnectionHeader, Tags]
# Either a FilteredImage tuple or the standard deserialized message data tuple.
FilteredImageOrAnyMsg = Union[FilteredImage, DeserializedMessageData]


def urdf_error(message):
    """Report a URDF parsing problem on standard error, ignoring ``selfCollide`` complaints.

    :param str message: The error message produced by the URDF parser.
    """
    if "selfCollide" not in message:
        print(message, file=sys.stderr)


# Install urdf_error() as the URDF parser's error hook so that selfCollide complaints are silenced.
xml_reflection.core.on_error = urdf_error


def items_to_str(items, sep='='):
    """Format key/value pairs as a brace-enclosed, comma-separated string.

    Dict values are formatted recursively with :func:`dict_to_str` using the same separator.

    :param items: Iterable of ``(key, value)`` pairs.
    :param str sep: Separator placed between each key and its value.
    :return: A string like ``{a=1, b=2}``.
    :rtype: str
    """
    formatted = []
    for key, value in items:
        value_str = dict_to_str(value, sep) if isinstance(value, dict) else str(value)
        formatted.append('%s%s%s' % (key, sep, value_str))
    return "{" + ', '.join(formatted) + '}'


def dict_to_str(d, sep='='):
    """Format a dict as ``{key<sep>value, ...}`` using :func:`items_to_str`.

    :param dict d: The dictionary to format; nested dict values are formatted recursively.
    :param str sep: Separator between keys and values.
    :return: The formatted string.
    :rtype: str
    """
    pairs = d.items()
    return items_to_str(pairs, sep=sep)


def create_connection_header(topic, msg_type, latch=False):
    # type: (STRING_TYPE, Type[genpy.Message], bool) -> Dict[STRING_TYPE, STRING_TYPE]
    """Build a connection header dict for messages of type ``msg_type`` published on ``topic``.

    :param str topic: The topic name stored under the ``topic`` key.
    :param msg_type: Message class; its ``_full_text``, ``_type`` and ``_md5sum`` attributes are copied.
    :param bool latch: If True, the header also contains ``latching: "1"``.
    :return: The connection header with ``callerid`` set to ``/bag_filter``.
    :rtype: dict
    """
    header = dict(
        callerid="/bag_filter",
        topic=topic,
        message_definition=msg_type._full_text,  # noqa
        type=msg_type._type,  # noqa
        md5sum=msg_type._md5sum,  # noqa
    )
    if latch:
        header["latching"] = "1"
    return header


class SetFields(DeserializedMessageFilter):
    """Change values of some fields of a message (pass the fields to change as kwargs)."""

    def __init__(self, include_topics=None, exclude_topics=None, include_types=None, exclude_types=None,
                 min_stamp=None, max_stamp=None, include_time_ranges=None, exclude_time_ranges=None,
                 include_tags=None, exclude_tags=None, **kwargs):
        """
        :param list include_topics: If nonempty, the filter will only work on these topics.
        :param list exclude_topics: If nonempty, the filter will skip these topics (but pass them further).
        :param list include_types: If nonempty, the filter will only work on these message types.
        :param list exclude_types: If nonempty, the filter will skip these message types (but pass them further).
        :param rospy.Time min_stamp: If set, the filter will only work on messages after this timestamp.
        :param rospy.Time max_stamp: If set, the filter will only work on messages before this timestamp.
        :param include_time_ranges: Time ranges that specify which regions of the bag should be processed.
                                    List of pairs (start, end_or_duration) or a TimeRanges object.
        :type include_time_ranges: list or TimeRanges
        :param exclude_time_ranges: Time ranges that specify which regions of the bag should be skipped.
                                    List of pairs (start, end_or_duration) or a TimeRanges object.
        :type exclude_time_ranges: list or TimeRanges
        :param list include_tags: If nonempty, the filter will only work on messages with these tags.
        :param list exclude_tags: If nonempty, the filter will skip messages with these tags.
        :param dict kwargs: The fields to set. Keys are field name, values are the new values to set.
        """
        super(SetFields, self).__init__(
            include_topics, exclude_topics, include_types, exclude_types, min_stamp, max_stamp,
            include_time_ranges, exclude_time_ranges, include_tags, exclude_tags)
        self.field_values = kwargs

    def filter(self, topic, msg, stamp, header, tags):
        """Overwrite every configured field the message has; fields the message lacks are ignored.

        The message is tagged as changed whenever at least one field was set.
        """
        present_fields = (name for name in self.field_values if name in msg.__slots__)
        for name in present_fields:
            setattr(msg, name, self.field_values[name])
            tags.add(MessageTags.CHANGED)
        return topic, msg, stamp, header, tags

    def _str_params(self):
        """Describe the configured field values followed by the parent filter's parameters."""
        candidates = [dict_to_str(self.field_values), super(SetFields, self)._str_params()]
        return ", ".join(part for part in candidates if len(part) > 0)


class FixHeader(DeserializedMessageFilter):
    """Change the header of a message.

    Handles messages whose first field is ``header`` and array-only messages (e.g. TF or Path) whose single field is
    an array of messages with headers. Messages without any header only get their receive stamp offset by
    ``receive_stamp_offset``.
    """

    def __init__(self, frame_id="", frame_id_prefix="", frame_id_suffix="", stamp_from_receive_time=False,
                 stamp_offset=0.0, receive_stamp_offset=0.0, receive_stamp_from_header=False, *args, **kwargs):
        """
        :param str frame_id: If nonempty, frame_id will be set to this value.
        :param str frame_id_prefix: If nonempty, header will be prefixed with this value.
        :param str frame_id_suffix: If nonempty, header will be suffixed with this value.
        :param bool stamp_from_receive_time: If true, set stamp from the connection header receive time.
        :param stamp_offset: If nonzero, offset the stamp by this duration (seconds).
        :type stamp_offset: float or rospy.Duration
        :param receive_stamp_offset: If nonzero, offset the receive stamp by this duration (seconds). This is applied
                                     to all processed messages, including those without a header.
        :type receive_stamp_offset: float or rospy.Duration
        :param bool receive_stamp_from_header: If true, set the receive stamp in connection header from message stamp.
        :param args: Standard include/exclude and stamp args.
        :param kwargs: Standard include/exclude and stamp kwargs.
        """
        super(FixHeader, self).__init__(*args, **kwargs)
        self.frame_id = frame_id
        self.frame_id_prefix = frame_id_prefix
        self.frame_id_suffix = frame_id_suffix
        self.stamp_from_receive_time = stamp_from_receive_time
        self.stamp_offset = rospy.Duration(stamp_offset)
        self.receive_stamp_from_header = receive_stamp_from_header
        self.receive_stamp_offset = rospy.Duration(receive_stamp_offset)

    def filter(self, topic, msg, stamp, header, tags):
        # Remember the original receive time. Every header inside one message is fixed relative to this value, even
        # after `stamp` has been updated from the first header (array-only messages).
        receive_stamp = stamp
        slots = msg.__slots__
        if len(slots) > 0 and slots[0] == 'header':
            self.fix_header(msg.header, receive_stamp)
            stamp = self.fix_receive_stamp(msg.header, stamp)
            tags.add(MessageTags.CHANGED)
        # Support for TF, Path and similar array-only messages
        elif len(slots) == 1 and msg._get_types()[0].endswith("[]"):
            receive_stamp_updated = False
            for m in getattr(msg, slots[0]):
                # Items of primitive arrays (numbers, bytes, strings) have no __slots__ and therefore no header.
                item_slots = getattr(m, '__slots__', ())
                if len(item_slots) > 0 and item_slots[0] == 'header':
                    self.fix_header(m.header, receive_stamp)
                    # The receive stamp is derived only from the first header found.
                    if not receive_stamp_updated:
                        stamp = self.fix_receive_stamp(m.header, stamp)
                        receive_stamp_updated = True
                    tags.add(MessageTags.CHANGED)
            if not receive_stamp_updated:
                stamp += self.receive_stamp_offset
        else:
            # Messages without a header still have their receive time shifted (e.g. by --receive-stamp-offset).
            stamp += self.receive_stamp_offset

        return topic, msg, stamp, header, tags

    def fix_header(self, header, stamp):
        """Rewrite ``header`` in place.

        ``frame_id`` (if set) replaces the frame; otherwise the prefix/suffix are applied. The stamp is optionally
        replaced by ``stamp`` and then offset by ``stamp_offset``.

        :param header: The header to modify.
        :param rospy.Time stamp: The receive time of the message.
        """
        if len(self.frame_id) > 0:
            header.frame_id = self.frame_id
        else:
            if len(self.frame_id_prefix) > 0:
                header.frame_id = self.frame_id_prefix + header.frame_id
            if len(self.frame_id_suffix) > 0:
                header.frame_id += self.frame_id_suffix

        if self.stamp_from_receive_time:
            header.stamp = stamp
        header.stamp += self.stamp_offset

    def fix_receive_stamp(self, header, stamp):
        """Compute the new receive stamp.

        :param header: The (already fixed) message header; its stamp is used if receive_stamp_from_header is set.
        :param rospy.Time stamp: The current receive stamp.
        :return: The new receive stamp, offset by ``receive_stamp_offset``.
        :rtype: rospy.Time
        """
        if self.receive_stamp_from_header:
            stamp = header.stamp
        stamp += self.receive_stamp_offset
        return stamp

    @staticmethod
    def add_cli_args(parser):
        parser.add_argument(
            '--stamp-offset', nargs='?', default=0.0, type=float, metavar="OFFSET",
            help="Offset header.stamp of messages by the given value [s]")
        parser.add_argument(
            '--receive-stamp-offset', nargs='?', default=0.0, type=float, metavar="OFFSET",
            help="Offset receive timestamp of messages by the given value [s].")

    @staticmethod
    def process_cli_args(filters, args):
        if args.stamp_offset != 0.0 or args.receive_stamp_offset != 0.0:
            filters.append(FixHeader(stamp_offset=args.stamp_offset, receive_stamp_offset=args.receive_stamp_offset))

    def _str_params(self):
        parts = []
        if len(self.frame_id) > 0:
            parts.append('frame_id=' + self.frame_id)
        if len(self.frame_id_prefix) > 0:
            parts.append('frame_id_prefix=' + self.frame_id_prefix)
        if len(self.frame_id_suffix) > 0:
            parts.append('frame_id_suffix=' + self.frame_id_suffix)
        if self.stamp_from_receive_time:
            parts.append('stamp_from_receive_time')
        if self.stamp_offset != rospy.Duration(0, 0):
            parts.append('stamp_offset=' + to_str(self.stamp_offset).rstrip('0'))
        if self.receive_stamp_from_header:
            parts.append('receive_stamp_from_header')
        if self.receive_stamp_offset != rospy.Duration(0, 0):
            parts.append('receive_stamp_offset=' + to_str(self.receive_stamp_offset).rstrip('0'))
        parent_params = super(FixHeader, self)._str_params()
        if len(parent_params) > 0:
            parts.append(parent_params)
        return ", ".join(parts)


class Deduplicate(RawMessageFilter):
    """Discard all messages except each first changed.

    Works on the serialized message bytes: a message is dropped when it equals the last passed message of the same
    topic (and frame_id if per_frame_id is set), with the header seq/stamp comparisons relaxed as configured.
    """

    def __init__(self, ignore_seq=False, ignore_stamp=False, ignore_stamp_difference=None, per_frame_id=False,
                 max_ignored_duration=None, *args, **kwargs):
        """
        :param bool ignore_seq: If True, differing header.seq will not make a difference.
        :param bool ignore_stamp: If True, differing header.stamp will not make a difference.
        :param float ignore_stamp_difference: If set, stamps differing by up to this duration are considered equal.
        :param bool per_frame_id: If True, messages will be clustered by header.frame_id and comparisons will only be
                                  made between messages with equal frame_id.
        :param float max_ignored_duration: If set, a duplicate will pass if its stamp is further from the last passed
                                           message than the given duration (in seconds).
        :param args: Standard include/exclude and stamp args.
        :param kwargs: Standard include/exclude and stamp kwargs.
        """
        super(Deduplicate, self).__init__(*args, **kwargs)
        self._ignore_seq = ignore_seq
        self._ignore_stamp = ignore_stamp
        self._ignore_stamp_difference = \
            rospy.Duration(ignore_stamp_difference) if ignore_stamp_difference is not None else None
        self._per_frame_id = per_frame_id
        self._max_ignored_duration = rospy.Duration(max_ignored_duration) if max_ignored_duration is not None else None
        # Last passed message per key: (serialized data, receive stamp).
        self._last_msgs = {}

    def filter(self, topic, datatype, data, md5sum, pytype, stamp, header, tags):
        # Messages with no fields have empty __slots__; they cannot have a header.
        slots = pytype.__slots__
        has_header = len(slots) > 0 and slots[0] == 'header'
        key = topic
        msg_header = None
        if has_header and self._per_frame_id:
            msg_header = deserialize_header(data, pytype)
            key = "%s@%s" % (topic, msg_header.frame_id)

        if key not in self._last_msgs:
            self._last_msgs[key] = data, stamp
            return topic, datatype, data, md5sum, pytype, stamp, header, tags

        last_msg, last_msg_stamp = self._last_msgs[key]

        seq_ok = True
        stamp_ok = True
        stamp_diff_ok = True
        compare_idx = 0

        if has_header:
            # The serialized header starts with seq (bytes 0-4) followed by stamp (bytes 4-12).
            seq_ok = self._ignore_seq or data[0:4] == last_msg[0:4]
            stamp_ok = self._ignore_stamp or data[4:12] == last_msg[4:12]
            if not self._ignore_stamp and self._ignore_stamp_difference is not None and not stamp_ok:
                if msg_header is None:
                    msg_header = deserialize_header(data, pytype)
                last_msg_header = deserialize_header(last_msg, pytype)
                diff = abs(msg_header.stamp - last_msg_header.stamp)
                if diff < self._ignore_stamp_difference:
                    stamp_ok = True

            # seq and stamp were handled above; compare only the rest of the message.
            compare_idx = 12
        if self._max_ignored_duration is not None:
            stamp_diff_ok = stamp - last_msg_stamp < self._max_ignored_duration

        if seq_ok and stamp_ok and stamp_diff_ok and data[compare_idx:] == last_msg[compare_idx:]:
            return None

        self._last_msgs[key] = data, stamp

        return topic, datatype, data, md5sum, pytype, stamp, header, tags

    def reset(self):
        """Forget all remembered messages so that the next message on each topic passes."""
        self._last_msgs.clear()
        super(Deduplicate, self).reset()

    def _str_params(self):
        parts = []
        parts.append('ignore_seq=%r' % (self._ignore_seq,))
        parts.append('ignore_stamp=%r' % (self._ignore_stamp,))
        if self._ignore_stamp_difference is not None:
            parts.append('ignore_stamp_difference=%s' % (to_str(self._ignore_stamp_difference),))
        parts.append('per_frame_id=%r' % (self._per_frame_id,))
        parts.append('max_ignored_duration=%r' % (self._max_ignored_duration,))
        parent_params = super(Deduplicate, self)._str_params()
        if len(parent_params) > 0:
            parts.append(parent_params)
        return ", ".join(parts)


class DeduplicateJointStates(DeserializedMessageFilter):
    """Discard all messages except each first changed.

    Works per joint: a joint is removed from a JointState message when its position, velocity and effort equal the
    last kept values of that joint and the stamp conditions hold. Messages left without any joint are dropped.
    """

    def __init__(self, max_ignored_duration=None, include_topics=None, ignore_stamp=False, ignore_stamp_difference=None,
                 *args, **kwargs):
        """
        :param float max_ignored_duration: If set, a duplicate will pass if its stamp is further from the last passed
                                           joint state than the given duration (in seconds).
        :param list include_topics: Topics to work on (defaults to 'joint_states').
        :param bool ignore_stamp: If True, differing header.stamp will not make a difference.
        :param float ignore_stamp_difference: If set, stamps differing by up to this duration are considered equal.
        :param args: Standard include/exclude and stamp args.
        :param kwargs: Standard include/exclude and stamp kwargs.
        """
        super(DeduplicateJointStates, self).__init__(
            include_topics=['joint_states'] if include_topics is None else include_topics,
            include_types=[JointState._type], *args, **kwargs)  # noqa
        self._ignore_stamp = ignore_stamp
        self._ignore_stamp_difference = \
            rospy.Duration(ignore_stamp_difference) if ignore_stamp_difference is not None else None
        self._max_ignored_duration = rospy.Duration(max_ignored_duration) if max_ignored_duration is not None else None

        # Last kept state of each joint, keyed by joint name (tuple layout: see _get_state()).
        self._last_states = {}

    def filter(self, topic, msg, stamp, header, tags):
        if len(msg.name) == 0:
            return topic, msg, stamp, header, tags

        # Iterate backwards so that deleting a joint does not shift the indices still to be visited.
        for i in reversed(range(len(msg.name))):
            name = msg.name[i]
            state = self._get_state(msg, i, stamp)

            last_state = self._last_states.get(name)
            if last_state is None:
                self._last_states[name] = state
                continue

            stamp_ok = self._ignore_stamp or msg.header.stamp == last_state[0]
            if not stamp_ok and self._ignore_stamp_difference is not None:
                diff = abs(msg.header.stamp - last_state[0])
                if diff < self._ignore_stamp_difference:
                    stamp_ok = True

            stamp_diff_ok = True
            if self._max_ignored_duration is not None:
                stamp_diff_ok = stamp - last_state[5] < self._max_ignored_duration

            if stamp_ok and stamp_diff_ok and self._values_equal(state[2:5], last_state[2:5]):
                self._remove_joint(msg, i)
            else:
                # The joint is kept, so it becomes the reference for detecting following duplicates.
                self._last_states[name] = state

        if len(msg.name) == 0:
            return None

        return topic, msg, stamp, header, tags

    @staticmethod
    def _values_equal(values, last_values):
        """Compare (position, velocity, effort) triples; a missing value (None) only equals another missing value.

        Present values are compared with ``np.allclose`` (NaN equals NaN).
        """
        for value, last_value in zip(values, last_values):
            if value is None or last_value is None:
                if value is not last_value:
                    return False
            elif not np.allclose(value, last_value, equal_nan=True):
                return False
        return True

    @staticmethod
    def _remove_joint(msg, i):
        """Delete the i-th entry of name and of each of position/velocity/effort long enough to have one."""
        for field in ('name', 'position', 'velocity', 'effort'):
            values = getattr(msg, field)
            if len(values) > i:
                # Deserialized arrays may be tuples, which do not support deletion.
                if isinstance(values, tuple):
                    values = list(values)
                    setattr(msg, field, values)
                del values[i]

    def _get_state(self, msg, i, stamp):
        """Return (header stamp, name, position, velocity, effort, receive stamp) of joint ``i``.

        Values missing from the message (arrays shorter than ``name``) are None.
        """
        return (
            msg.header.stamp,
            msg.name[i],
            msg.position[i] if len(msg.position) > i else None,
            msg.velocity[i] if len(msg.velocity) > i else None,
            msg.effort[i] if len(msg.effort) > i else None,
            stamp,
        )

    def reset(self):
        """Forget all remembered joint states."""
        self._last_states.clear()
        super(DeduplicateJointStates, self).reset()

    def _str_params(self):
        parts = []
        parts.append('ignore_stamp=%r' % (self._ignore_stamp,))
        if self._ignore_stamp_difference is not None:
            parts.append('ignore_stamp_difference=%s' % (to_str(self._ignore_stamp_difference),))
        parts.append('max_ignored_duration=%r' % (self._max_ignored_duration,))
        parent_params = super(DeduplicateJointStates, self)._str_params()
        if len(parent_params) > 0:
            parts.append(parent_params)
        return ", ".join(parts)


class DeduplicateTF(DeserializedMessageFilter):
    """Discard all messages except each first changed.

    Individual transforms are removed from TF messages when a transform with the same topic, parent frame, child
    frame and header stamp is in the per-transform cache. Messages left without any transform are dropped.
    """

    def __init__(self, max_ignored_duration=None, cache_size=1, include_topics=None, *args, **kwargs):
        """
        :param float max_ignored_duration: If set, a duplicate will pass if its stamp is further from the last passed
                                           message than the given duration (in seconds).
        :param int cache_size: Number of previous messages to remember for each TF
                               (duplicates are searched in this cache).
        :param list include_topics: Topics to work on (defaults to standard TF topics).
        :param args: Standard include/exclude and stamp args.
        :param kwargs: Standard include/exclude and stamp kwargs.
        """
        super(DeduplicateTF, self).__init__(
            include_topics=['tf', 'tf_static'] if include_topics is None else include_topics,
            include_types=[TFMessage._type], *args, **kwargs)  # noqa
        self._max_ignored_duration = rospy.Duration(max_ignored_duration) if max_ignored_duration is not None else None
        self._cache_size = max(1, cache_size)
        # Per "topic@parent@child" key: deque of (header stamp, receive stamp, tags) of recently passed transforms.
        self._last_msgs = {}

    def filter(self, topic, msg, stamp, header, tags):
        tfs = msg.transforms
        if len(tfs) == 0:
            return topic, msg, stamp, header, tags
        # Transforms are deleted in place below, which requires a mutable list.
        if isinstance(tfs, tuple):
            tfs = list(tfs)
            msg.transforms = tfs

        # Iterate backwards so that deleting a transform does not shift the indices still to be visited.
        for i in reversed(range(len(tfs))):
            transform = tfs[i]
            key = "%s@%s@%s" % (topic, transform.header.frame_id, transform.child_frame_id)
            msg_stamp = transform.header.stamp

            cache = self._last_msgs.get(key)
            if cache is None:
                cache = self._last_msgs[key] = deque(maxlen=self._cache_size)
                cache.append((msg_stamp, stamp, copy.deepcopy(tags)))
                continue

            same_msgs = [entry for entry in cache if entry[0] == msg_stamp]
            is_duplicate = len(same_msgs) > 0

            # max_ignored_duration only matters for already-seen stamps (and max() of an empty list would raise).
            if is_duplicate and self._max_ignored_duration is not None:
                last_msg_stamp = max(same_msgs, key=lambda entry: entry[1])[1]
                # if stamp would somehow be before last_msg_stamp, we want to succeed
                last_msg_stamp = min(stamp, last_msg_stamp)
                is_duplicate = stamp - last_msg_stamp < self._max_ignored_duration

            if is_duplicate:
                del tfs[i]
            else:
                cache.append((msg_stamp, stamp, copy.deepcopy(tags)))

        if len(tfs) == 0:
            return None

        return topic, msg, stamp, header, tags

    def reset(self):
        self._last_msgs.clear()
        super(DeduplicateTF, self).reset()

    def _str_params(self):
        parts = []
        parts.append('max_ignored_duration=%r' % (self._max_ignored_duration,))
        parts.append('cache_size=%d' % (self._cache_size,))
        parent_params = super(DeduplicateTF, self)._str_params()
        if len(parent_params) > 0:
            parts.append(parent_params)
        return ", ".join(parts)


class Remap(RawMessageFilter):
    """Remap topics."""

    def __init__(self, min_stamp=None, max_stamp=None, include_time_ranges=None, exclude_time_ranges=None,
                 include_tags=None, exclude_tags=None, **kwargs):
        """
        :param dict kwargs: The mapping to use. Keys are topics to be remapped, values are their new names.
                            Keys are matched verbatim first and then with a leading slash added.
        """
        super(Remap, self).__init__(kwargs.keys(), None, None, None, min_stamp, max_stamp,
                                    include_time_ranges, exclude_time_ranges, include_tags, exclude_tags)
        self.remap = kwargs
        # The same mapping keyed by absolute topic names, so that e.g. key 'a' also remaps topic '/a'
        # (Copy normalizes its keys the same way).
        self._remap_absolute = {self._absolute_topic(t): new_topic for t, new_topic in kwargs.items()}

    @staticmethod
    def _absolute_topic(topic):
        """Return ``topic`` with a leading slash."""
        return topic if topic.startswith('/') else '/' + topic

    def filter(self, topic, datatype, data, md5sum, pytype, stamp, header, tags):
        if topic in self.remap:
            new_topic = self.remap[topic]
        else:
            new_topic = self._remap_absolute.get(self._absolute_topic(topic), topic)
        return new_topic, datatype, data, md5sum, pytype, stamp, header, tags

    def _str_params(self):
        return dict_to_str(self.remap, '=>')

    @staticmethod
    def add_cli_args(parser):
        parser.add_argument(
            '--remap', nargs='+', metavar="FROM TO",
            help="Remap topics. This argument should be an even-sized list of pairs [FROM TO].")

    @staticmethod
    def process_cli_args(filters, args):
        if args.remap:
            topics_from = args.remap[0::2]
            topics_to = args.remap[1::2]
            filters.append(Remap(**dict(zip(topics_from, topics_to))))


class Copy(RawMessageFilter):
    """Copy topics to other topics.

    Each matching message is passed on unchanged and additionally emitted on the configured target topic.
    """

    def __init__(self, min_stamp=None, max_stamp=None, include_time_ranges=None, exclude_time_ranges=None,
                 include_tags=None, exclude_tags=None, add_tags=None, **kwargs):
        """
        :param set add_tags: If set, these tags will be added to the copies of messages.
        :param dict kwargs: The mapping to use. Keys are topics to be copied, values are their new names.
        """
        super(Copy, self).__init__(kwargs.keys(), None, None, None, min_stamp, max_stamp,
                                   include_time_ranges, exclude_time_ranges, include_tags, exclude_tags)
        self._add_tags = add_tags
        # Source topics are stored with a leading slash so that 'a' and '/a' refer to the same topic.
        self.copy = dict()
        for topic_from, topic_to in kwargs.items():
            if not topic_from.startswith('/'):
                topic_from = '/' + topic_from
            self.copy[topic_from] = topic_to

    def filter(self, topic, datatype, data, md5sum, pytype, stamp, header, tags):
        copy_tags = tags_for_generated_msg(tags)
        if self._add_tags:
            copy_tags = copy_tags.union(self._add_tags)
        topic_from = topic
        if not topic_from.startswith('/'):
            topic_from = '/' + topic_from
        return [
            (topic, datatype, data, md5sum, pytype, stamp, header, tags),
            (self.copy[topic_from], (datatype, data, md5sum, pytype), stamp, header, copy_tags),
        ]

    def _str_params(self):
        return dict_to_str(self.copy, '=>')

    @staticmethod
    def add_cli_args(parser):
        parser.add_argument(
            '--copy', nargs='+', metavar="FROM TO",
            help="Copy topics. This argument should be an even-sized list of pairs [FROM TO].")

    @staticmethod
    def process_cli_args(filters, args):
        # Read the pairs given to --copy (argparse stores them in args.copy), not those of --remap.
        if args.copy:
            topics_from = args.copy[0::2]
            topics_to = args.copy[1::2]
            filters.append(Copy(**dict(zip(topics_from, topics_to))))


class Throttle(RawMessageFilter):
    """Throttle messages on topics.

    A message is dropped if it arrives sooner (by receive stamp) after the last passed message on its topic than the
    configured maximum frequency allows.
    """

    def __init__(self, min_stamp=None, max_stamp=None, include_time_ranges=None, exclude_time_ranges=None,
                 include_tags=None, exclude_tags=None, **kwargs):
        """
        :param kwargs: Keys are topics to be throttled, values are their maximum frequencies. Keys are matched
                       verbatim first and then with a leading slash added.
        :type kwargs: Dict[str, float]
        """
        super(Throttle, self).__init__(kwargs.keys(), None, None, None, min_stamp, max_stamp,
                                       include_time_ranges, exclude_time_ranges, include_tags, exclude_tags)
        self.hz = kwargs
        # The same rates keyed by absolute topic names, so that e.g. key 'a' also applies to topic '/a'
        # (Copy normalizes its keys the same way).
        self._hz_absolute = {self._absolute_topic(t): rate for t, rate in kwargs.items()}
        self.prev_t = {}

    @staticmethod
    def _absolute_topic(topic):
        """Return ``topic`` with a leading slash."""
        return topic if topic.startswith('/') else '/' + topic

    def _max_rate(self, topic):
        """Return the maximum frequency configured for ``topic``.

        :raises KeyError: If no rate is configured for the topic.
        """
        if topic in self.hz:
            return self.hz[topic]
        return self._hz_absolute[self._absolute_topic(topic)]

    def filter(self, topic, datatype, data, md5sum, pytype, stamp, header, tags):
        t = stamp.to_sec()
        if topic in self.prev_t:
            period = t - self.prev_t[topic]
            # Messages with the same or an older receive stamp than the last passed one are dropped.
            if period <= 0.:
                return None
            hz = 1. / period
            if hz > self._max_rate(topic):
                return None
        self.prev_t[topic] = t
        return topic, datatype, data, md5sum, pytype, stamp, header, tags

    def reset(self):
        self.prev_t = {}
        super(Throttle, self).reset()

    def _str_params(self):
        return dict_to_str(self.hz, '@')

    @staticmethod
    def add_cli_args(parser):
        parser.add_argument(
            '--throttle', '--hz', nargs='+', metavar="TOPIC RATE",
            help="Throttle messages. This argument should be an even-sized list of pairs [TOPIC RATE].")

    @staticmethod
    def yaml_config_args():
        return 'throttle',

    @staticmethod
    def process_cli_args(filters, args):
        if args.throttle:
            topics = args.throttle[0::2]
            hz = [float(hz) for hz in args.throttle[1::2]]
            filters.append(Throttle(**dict(zip(topics, hz))))


class Topics(RawMessageFilter):
    """Select topics that will be retained or removed. This works as a global filter."""

    def __init__(self, include_topics=None, exclude_topics=None):
        """
        :param list include_topics: If nonempty, all topics not on this list will be dropped.
        :param list exclude_topics: If nonempty, all topics on this list will be dropped.
        """
        super(Topics, self).__init__()
        # The base-class _include_topics/_exclude_topics would only make this filter skip messages (passing them
        # further), not reject them, so dedicated sets are kept and evaluated in topic_filter().
        self.include = TopicSet(include_topics)
        self.exclude = TopicSet(exclude_topics)

    def topic_filter(self, topic):
        """Decide whether messages on ``topic`` are kept.

        :return: False if the topic is missing from a nonempty include set or present in a nonempty exclude set.
        :rtype: bool
        """
        return not ((self.include and topic not in self.include) or (self.exclude and topic in self.exclude))

    def filter(self, topic, datatype, data, md5sum, pytype, stamp, header, tags):
        # Rejection already happened in topic_filter(); everything reaching this point passes unchanged.
        return topic, datatype, data, md5sum, pytype, stamp, header, tags

    def _str_params(self):
        named_sets = (('include_topics', self.include), ('exclude_topics', self.exclude))
        return ", ".join(label + '=' + str(topic_set) for label, topic_set in named_sets if topic_set)

    @staticmethod
    def add_cli_args(parser):
        parser.add_argument('-i', '--include-topics', nargs='+', help="Retain only these topics")
        parser.add_argument('-e', '--exclude-topics', nargs='+', help="Remove these topics")

    @staticmethod
    def yaml_config_args():
        return ('include_topics', 'exclude_topics')

    @staticmethod
    def process_cli_args(filters, args):
        included, excluded = args.include_topics, args.exclude_topics
        if included or excluded:
            filters.append(Topics(include_topics=included, exclude_topics=excluded))


class TopicTypes(RawMessageFilter):
    """Select topic types that will be retained or removed. This works as a global filter."""

    def __init__(self, include_types=None, exclude_types=None):
        """
        :param list include_types: If nonempty, messages of types not on this list will be dropped.
        :param list exclude_types: If nonempty, messages of types on this list will be dropped.
        """
        super(TopicTypes, self).__init__()
        # The base-class _include_types/_exclude_types would only make this filter skip messages (passing them
        # further), not reject them, so dedicated sets are kept and evaluated in connection_filter().
        self.include = TopicSet(include_types)
        self.exclude = TopicSet(exclude_types)

    def connection_filter(self, topic, datatype, md5sum, msg_def, header):
        """Decide whether a connection carrying messages of type ``datatype`` is kept.

        :return: False if the type is missing from a nonempty include set or present in a nonempty exclude set.
        :rtype: bool
        """
        return not ((self.include and datatype not in self.include) or (self.exclude and datatype in self.exclude))

    def filter(self, topic, datatype, data, md5sum, pytype, stamp, header, tags):
        # Rejection already happened in connection_filter(); everything reaching this point passes unchanged.
        return topic, datatype, data, md5sum, pytype, stamp, header, tags

    def _str_params(self):
        named_sets = (('include_types', self.include), ('exclude_types', self.exclude))
        return ", ".join(label + '=' + str(type_set) for label, type_set in named_sets if type_set)

    @staticmethod
    def add_cli_args(parser):
        parser.add_argument('--include-types', nargs='+', help="Retain only messages of these types")
        parser.add_argument('--exclude-types', nargs='+', help="Remove messages of these types")

    @staticmethod
    def yaml_config_args():
        return ('include_types', 'exclude_types')

    @staticmethod
    def process_cli_args(filters, args):
        included, excluded = args.include_types, args.exclude_types
        if included or excluded:
            filters.append(TopicTypes(include_types=included, exclude_types=excluded))


class Drop(RawMessageFilter):
    """Drop every message this filter is applied to.

    Unlike Topics (which works globally), Drop acts locally, so it can be placed in the middle of a filter chain.
    """

    def consider_message(self, topic, datatype, stamp, header, tags):
        # Defer to the base class unchanged, so messages from extra time ranges get dropped as well.
        return super(Drop, self).consider_message(topic, datatype, stamp, header, tags)

    def filter(self, topic, datatype, data, md5sum, pytype, stamp, header, tags):
        # A None result means the message is not passed on.
        return None


class Transforms(DeserializedMessageFilter):
    """Filter or change transforms."""

    def __init__(self, include_parents=(), exclude_parents=(), include_children=(), exclude_children=(),
                 change=None, include_topics=None, exclude_topics=None, include_types=None, exclude_types=None,
                 *args, **kwargs):
        """
        :param list include_parents: If nonempty, only TFs with one of the listed frames as parent will be retained.
        :param list exclude_parents: If nonempty, TFs with one of the listed frames as parent will be dropped.
        :param list include_children: If nonempty, only TFs with one of the listed frames as child will be retained.
        :param list exclude_children: If nonempty, TFs with one of the listed frames as child will be dropped.
        :param dict change: A multilevel dictionary. First key is parent frame. Second key is child frame. Values are
                            dicts with optional keys 'translation', 'rotation', 'frame_id', 'child_frame_id'.
                            'translation' is a dict with optional keys 'x', 'y', 'z'.
                            'rotation' is a dict with optional keys 'x', 'y', 'z', 'w'.
                            If any of these optional keys is defined, it overwrites the values in the TF message.
        :param list include_topics: The topics on which this filter operates. Leave empty for the default set of topics.
        :param list exclude_topics: Do not operate on these topics.
        :param list include_types: Ignored.
        :param list exclude_types: Ignored.
        :param args: Standard include/exclude and stamp args.
        :param kwargs: Standard include/exclude and stamp kwargs.
        """
        super(Transforms, self).__init__(
            include_topics=include_topics if include_topics is not None else ['/tf', '/tf_static'],
            exclude_topics=exclude_topics, include_types=['tf2_msgs/TFMessage'], *args, **kwargs)
        self.include_parents = TopicSet(include_parents)
        self.exclude_parents = TopicSet(exclude_parents)
        self.include_children = TopicSet(include_children)
        self.exclude_children = TopicSet(exclude_children)
        self.change = change if change is not None else {}
        self.changed_parents = TopicSet(self.change.keys())

    @staticmethod
    def _retain_transforms(msg, keep, tags):
        """Keep only the transforms of `msg` for which `keep(transform)` is true.

        MessageTags.CHANGED is added to `tags` if at least one transform was removed.
        """
        num_tfs = len(msg.transforms)
        msg.transforms = [tf for tf in msg.transforms if keep(tf)]
        if len(msg.transforms) < num_tfs:
            tags.add(MessageTags.CHANGED)

    def _apply_change(self, transform, tags):
        """Overwrite the fields of `transform` that self.change configures for its parent/child frame pair."""
        parent = transform.header.frame_id
        if parent not in self.changed_parents or transform.child_frame_id not in self.change[parent]:
            return
        changes = self.change[parent][transform.child_frame_id]
        t = transform.transform
        if 'translation' in changes:
            translation = changes['translation']
            t.translation.x = float(translation.get('x', t.translation.x))
            t.translation.y = float(translation.get('y', t.translation.y))
            t.translation.z = float(translation.get('z', t.translation.z))
            tags.add(MessageTags.CHANGED)
        if 'rotation' in changes:
            rotation = changes['rotation']
            t.rotation.x = float(rotation.get('x', t.rotation.x))
            t.rotation.y = float(rotation.get('y', t.rotation.y))
            t.rotation.z = float(rotation.get('z', t.rotation.z))
            t.rotation.w = float(rotation.get('w', t.rotation.w))
            tags.add(MessageTags.CHANGED)
        if 'frame_id' in changes:
            transform.header.frame_id = changes['frame_id']
            tags.add(MessageTags.CHANGED)
        if 'child_frame_id' in changes:
            transform.child_frame_id = changes['child_frame_id']
            tags.add(MessageTags.CHANGED)

    def filter(self, topic, msg, stamp, header, tags):
        if self.include_parents:
            self._retain_transforms(msg, lambda tf: tf.header.frame_id in self.include_parents, tags)
        if self.exclude_parents:
            self._retain_transforms(msg, lambda tf: tf.header.frame_id not in self.exclude_parents, tags)
        if self.include_children:
            self._retain_transforms(msg, lambda tf: tf.child_frame_id in self.include_children, tags)
        if self.exclude_children:
            self._retain_transforms(msg, lambda tf: tf.child_frame_id not in self.exclude_children, tags)
        if self.change:
            for transform in msg.transforms:
                self._apply_change(transform, tags)
        # Drop the whole message when no transform survived the filtering.
        if not msg.transforms:
            return None
        return topic, msg, stamp, header, tags

    def _str_params(self):
        parts = []
        if self.include_parents:
            parts.append('include_parents=' + str(self.include_parents))
        if self.exclude_parents:
            parts.append('exclude_parents=' + str(self.exclude_parents))
        if self.include_children:
            parts.append('include_children=' + str(self.include_children))
        if self.exclude_children:
            parts.append('exclude_children=' + str(self.exclude_children))
        if len(self.change) > 0:
            parts.append('change=' + dict_to_str(self.change))
        parent_params = super(Transforms, self)._str_params()
        if len(parent_params) > 0:
            parts.append(parent_params)
        return ", ".join(parts)

    @staticmethod
    def add_cli_args(parser):
        parser.add_argument('--include-tf-parents', nargs='+', help="Retain only TFs with these frames as parents")
        parser.add_argument('--exclude-tf-parents', nargs='+', help="Remove TFs with these frames as parents")
        parser.add_argument('--include-tf-children', nargs='+', help="Retain only TFs with these frames as children")
        parser.add_argument('--exclude-tf-children', nargs='+', help="Remove TFs with these frames as children")

    @staticmethod
    def yaml_config_args():
        return 'include_tf_parents', 'exclude_tf_parents', 'include_tf_children', 'exclude_tf_children'

    @staticmethod
    def process_cli_args(filters, args):
        if args.include_tf_parents or args.exclude_tf_parents or args.include_tf_children or args.exclude_tf_children:
            filters.append(Transforms(
                include_parents=args.include_tf_parents, exclude_parents=args.exclude_tf_parents,
                include_children=args.include_tf_children, exclude_children=args.exclude_tf_children))


class MergeInitialStaticTf(DeserializedMessageFilter):
    """Merge all /tf_static messages from the beginning of bag files into a single message."""

    def __init__(self, delay=5.0, change_stamps=True, add_tags=None):
        """
        :param float delay: Duration (in seconds) of the initial bag section whose /tf_static messages are merged.
        :param bool change_stamps: If true, the merged transforms and the merged message get the bag start time as
                                   their stamp.
        :param set add_tags: Tags to add to the merged message.
        """
        super(MergeInitialStaticTf, self).__init__(include_topics=["/tf_static"], include_types=["tf2_msgs/TFMessage"])
        self.delay = rospy.Duration(delay)
        self.start_time = None
        self.end_time = None
        self.change_stamps = change_stamps
        self.merged_message_published = False
        self.merged_transforms = {}
        self.add_tags = add_tags

    def set_multibag(self, bag):
        super(MergeInitialStaticTf, self).set_multibag(bag)
        self.start_time = rospy.Time(bag.get_start_time())
        self.end_time = self.start_time + self.delay
        for topic, msg, stamp in bag.read_messages(
                topics=['/tf_static'], start_time=self.start_time, end_time=self.end_time):
            for tf in msg.transforms:
                if self.change_stamps:
                    tf = copy.deepcopy(tf)
                    tf.header.stamp = self.start_time
                # Keyed by child frame, so a later transform of the same child replaces an earlier one.
                self.merged_transforms[tf.child_frame_id] = tf

    def consider_message(self, topic, datatype, stamp, header, tags):
        # Only messages from the initial section are replaced; everything later passes through untouched.
        if self.end_time is None or self.end_time <= stamp:
            return False
        return super(MergeInitialStaticTf, self).consider_message(topic, datatype, stamp, header, tags)

    def filter(self, topic, msg, stamp, header, tags):
        # The first considered message is replaced by the merged one; all other initial messages are dropped.
        if self.merged_message_published:
            return None
        self.merged_message_published = True
        merged_tags = tags_for_generated_msg(tags)
        if self.add_tags:
            merged_tags = merged_tags.union(self.add_tags)
        if self.change_stamps:
            stamp = self.start_time
        return topic, TFMessage(list(self.merged_transforms.values())), stamp, header, merged_tags

    def reset(self):
        self.merged_message_published = False
        self.merged_transforms.clear()
        self.start_time = None
        self.end_time = None
        super(MergeInitialStaticTf, self).reset()

    def _str_params(self):
        parts = ['delay=%f' % (self.delay.to_sec(),), 'change_stamps=' + repr(self.change_stamps)]
        parent_params = super(MergeInitialStaticTf, self)._str_params()
        if len(parent_params) > 0:
            parts.append(parent_params)
        return ", ".join(parts)

    @staticmethod
    def add_cli_args(parser):
        # With nargs='?', argparse yields `const` when the flag is given without a value. Without an explicit const
        # that would be None and process_cli_args() would silently skip the filter, contradicting the help text.
        parser.add_argument('--merge-initial-static-tf', nargs='?', required=False, type=float, metavar="DURATION",
                            const=5.0,
                            help="Merge a few initial static TFs into one. DURATION specifies the duration of the "
                                 "initial bag section to be considered for the merging. DURATION defaults to 5 secs.")

    @staticmethod
    def yaml_config_args():
        return ['merge_initial_static_tf']

    @staticmethod
    def process_cli_args(filters, args):
        if hasattr(args, 'merge_initial_static_tf') and args.merge_initial_static_tf is not None:
            # YAML configs may give a plain bool (enable with the default delay) instead of a duration.
            if isinstance(args.merge_initial_static_tf, bool):
                if args.merge_initial_static_tf:
                    filters.append(MergeInitialStaticTf())
            else:
                filters.append(MergeInitialStaticTf(args.merge_initial_static_tf))


class CopyFirstTfToBagStart(NoMessageFilter):
    """For dynamic TFs, copies the first transform of each frame to the very start of the bag."""

    def __init__(self, max_delay=5.0, add_tags=None, tf_topic="/tf"):
        """
        :param float max_delay: Only transforms recorded within this many seconds from the bag start are copied.
        :param set add_tags: Tags to add to the generated message.
        :param str tf_topic: Topic from which transforms are read and on which the generated message is emitted.
        """
        super(CopyFirstTfToBagStart, self).__init__()

        self.max_delay = rospy.Duration(max_delay)
        self.start_time = None
        self.max_time = None
        self.tf_topic = tf_topic
        self.merged_transforms = {}
        self.add_tags = add_tags

    def set_bag(self, bag):
        super(CopyFirstTfToBagStart, self).set_bag(bag)
        self.start_time = rospy.Time(bag.get_start_time())
        self.max_time = self.start_time + self.max_delay
        for topic, msg, stamp in bag.read_messages(
                topics=[self.tf_topic], start_time=self.start_time, end_time=self.max_time):
            for tf in msg.transforms:
                # Only the first transform seen for each child frame is copied.
                if tf.child_frame_id in self.merged_transforms:
                    continue
                tf = copy.deepcopy(tf)
                tf.header.stamp = self.start_time
                self.merged_transforms[tf.child_frame_id] = tf

    def extra_initial_messages(self):
        # Nothing to copy (no TF in the initial window, or set_bag() was not called): do not emit an empty TF message
        # (which would also carry a None stamp in the latter case).
        if not self.merged_transforms or self.start_time is None:
            return []
        merged_tags = {MessageTags.GENERATED}
        if self.add_tags:
            merged_tags = merged_tags.union(self.add_tags)
        header = create_connection_header(self.tf_topic, TFMessage, latch=False)
        return [(self.tf_topic, TFMessage(list(self.merged_transforms.values())), self.start_time, header, merged_tags)]

    def reset(self):
        self.merged_transforms.clear()
        self.start_time = None
        self.max_time = None
        super(CopyFirstTfToBagStart, self).reset()

    def _str_params(self):
        parts = ['max_delay=%f' % (self.max_delay.to_sec(),), 'tf_topic=' + self.tf_topic]
        parent_params = super(CopyFirstTfToBagStart, self)._str_params()
        if len(parent_params) > 0:
            parts.append(parent_params)
        return ", ".join(parts)


class RemoveInvalidTF(DeserializedMessageFilter):
    """Remove all invalid TFs."""

    # Maximum allowed deviation of the squared rotation quaternion norm from 1.
    _QUATERNION_NORM_TOLERANCE = 1e-6

    def __init__(self, include_topics=None, *args, **kwargs):
        """
        :param list include_topics: Topics to work on (defaults to standard TF topics).
        :param args: Standard include/exclude and stamp args.
        :param kwargs: Standard include/exclude and stamp kwargs.
        """
        super(RemoveInvalidTF, self).__init__(
            include_topics=['tf', 'tf_static'] if include_topics is None else include_topics,
            include_types=[TFMessage._type], *args, **kwargs)  # noqa

    @classmethod
    def _is_valid(cls, tf_msg):
        """Tell whether a transform has nonempty frame names, only finite values and a unit rotation quaternion."""
        if not tf_msg.header.frame_id or not tf_msg.child_frame_id:
            return False

        translation = tf_msg.transform.translation
        rot = tf_msg.transform.rotation
        values = (translation.x, translation.y, translation.z, rot.x, rot.y, rot.z, rot.w)
        if not all(math.isfinite(v) for v in values):
            return False

        norm_sq = pow(rot.x, 2) + pow(rot.y, 2) + pow(rot.z, 2) + pow(rot.w, 2)
        return abs(norm_sq - 1) < cls._QUATERNION_NORM_TOLERANCE

    def filter(self, topic, msg, stamp, header, tags):
        # Messages that carry no transforms at all are passed through untouched.
        if len(msg.transforms) == 0:
            return topic, msg, stamp, header, tags

        valid_tfs = [tf_msg for tf_msg in msg.transforms if self._is_valid(tf_msg)]
        if len(valid_tfs) == len(msg.transforms):
            return topic, msg, stamp, header, tags

        # Some transforms were invalid; drop the message if nothing valid is left.
        if not valid_tfs:
            return None

        msg.transforms = valid_tfs
        # The message content was modified, so mark it the same way Transforms does when it removes TFs.
        tags.add(MessageTags.CHANGED)
        return topic, msg, stamp, header, tags


class MagnetometerGaussToTesla(DeserializedMessageFilter):
    """Scale magnetometer data from Gauss to Tesla (e.g. XSens IMUs)."""

    # 1 Gauss = 1e-4 Tesla.
    _GAUSS_TO_TESLA = 1e-4

    def __init__(self, scale_covariance=True, include_topics=None, add_tags=None, *args, **kwargs):
        """
        :param bool scale_covariance: If true, covariance will be also scaled by 10^-8 (the square of the field scale).
        :param list include_topics: Topics to work on (defaults to ['imu/mag']).
        :param add_tags: If not None, these tags are added to every modified message.
        :param args: Standard include/exclude and stamp args.
        :param kwargs: Standard include/exclude and stamp kwargs.
        """
        super(MagnetometerGaussToTesla, self).__init__(
            include_topics=['imu/mag'] if include_topics is None else include_topics,
            include_types=[MagneticField._type], *args, **kwargs)  # noqa

        self._scale_covariance = scale_covariance
        self._add_tags = set(add_tags) if add_tags is not None else None

    def filter(self, topic, msg, stamp, header, tags):
        field_scale = self._GAUSS_TO_TESLA
        msg.magnetic_field.x *= field_scale
        msg.magnetic_field.y *= field_scale
        msg.magnetic_field.z *= field_scale

        if self._scale_covariance:
            # Covariance has the units of the field squared, so it scales with the square of the field scale.
            covariance_scale = pow(field_scale, 2)
            msg.magnetic_field_covariance = [f * covariance_scale for f in msg.magnetic_field_covariance]

        return topic, msg, stamp, header, tags_for_changed_msg(tags, self._add_tags)


def set_transform_from_KDL_frame(transform, frame):
    """Copy the pose held by a KDL frame into a Transform message, modifying `transform` in place.

    :param transform: Message whose `translation` (x, y, z) and `rotation` (x, y, z, w) fields are overwritten.
    :param frame: KDL frame; its position vector `p` and the quaternion of its rotation `M` are copied.
    """
    position = frame.p
    quat = frame.M.GetQuaternion()

    dst_translation = transform.translation
    dst_translation.x, dst_translation.y, dst_translation.z = position.x(), position.y(), position.z()

    dst_rotation = transform.rotation
    dst_rotation.x, dst_rotation.y, dst_rotation.z, dst_rotation.w = quat[0], quat[1], quat[2], quat[3]


class RecomputeTFFromJointStates(DeserializedMessageFilter):
    """Recompute some TFs from the URDF model and joint states."""

    def __init__(self, include_topics=None, description_param=None, description_file=None, joint_state_cache_size=100,
                 include_parents=(), exclude_parents=(), include_children=(), exclude_children=(), include_joints=(),
                 exclude_joints=(), publish_new_tfs=False, discard_failed_transforms=True, add_tags=None,
                 *args, **kwargs):
        """
        :param list include_topics: Topics to handle. It should contain both the JointState and TF topics. If no TF
                                    topic is given, /tf and /tf_static are added automatically. Defaults to
                                    /joint_states (plus the TF topics).
        :param str description_param: Name of the ROS parameter that holds the URDF model. If neither this nor
                                      description_file is given, "robot_description" is used.
        :param str description_file: Path to a file with the URDF model (has precedence over description_param).
        :param int joint_state_cache_size: Number of JointState messages that will be cached to be searchable when a
                                           TF message comes and needs to figure out the JointState message it was
                                           created from.
        :param list include_parents: If nonempty, only TFs with one of the listed frames as parent will be processed.
        :param list exclude_parents: If nonempty, TFs with one of the listed frames as parent will not be processed.
        :param list include_children: If nonempty, only TFs with one of the listed frames as child will be processed.
        :param list exclude_children: If nonempty, TFs with one of the listed frames as child will not be processed.
        :param list include_joints: If nonempty, only joints with the listed names will be processed.
        :param list exclude_joints: If nonempty, joints with one of the listed names will not be processed.
        :param bool publish_new_tfs: If true, the mode of operation is changed so that every joint_states message
                                     will create a new TF message for the contained joints (i.e. new messages are added
                                     to the bag).
        :param bool discard_failed_transforms: If true and a transform fails to get recomputed, it will be discarded.
                                               If false, original transform is passed instead of the recomputed one.
        :param set add_tags: Add these tags to all modified TF messages.
        :param args: Standard include/exclude and stamp args.
        :param kwargs: Standard include/exclude and stamp kwargs.
        """
        if include_topics is None:
            include_topics = ("/joint_states",)
        # Crude detection of a TF topic among the given ones; if there is none, add the standard TF topics.
        if not any("tf" in topic for topic in include_topics):
            include_topics = list(include_topics)
            include_topics.append("/tf")
            include_topics.append("/tf_static")

        super(RecomputeTFFromJointStates, self).__init__(
            include_topics=include_topics, include_types=(TFMessage._type, JointState._type), *args, **kwargs)  # noqa

        self._urdf_model = None
        self._kdl_model = None
        self._mimic_joints = {}
        self._segments = {}
        self._segments_fixed = {}
        self._child_to_joint_map = {}
        self._joint_to_child_map = {}
        self._joint_state_cache = deque(maxlen=joint_state_cache_size)
        self._discard_failed_transforms = discard_failed_transforms
        self._add_tags = add_tags

        self.include_parents = TopicSet(include_parents)
        self.exclude_parents = TopicSet(exclude_parents)
        self.include_children = TopicSet(include_children)
        self.exclude_children = TopicSet(exclude_children)
        self.include_joints = TopicSet(include_joints)
        self.exclude_joints = TopicSet(exclude_joints)
        self.publish_new_tfs = publish_new_tfs

        self.description_param = description_param

        self._model_from_file = None
        if description_file is not None:
            if os.path.exists(description_file):
                try:
                    with open(description_file, "r") as f:
                        self._model_from_file = f.read()
                    print("Loaded URDF model from file " + description_file)
                except Exception as e:
                    print('Could not read URDF model from file %s: %s' % (description_file, str(e)), file=sys.stderr)
            else:
                print('URDF model file "%s" does not exist.' % (description_file,), file=sys.stderr)

    def set_bag(self, bag):
        super(RecomputeTFFromJointStates, self).set_bag(bag)

        model = self._model_from_file
        if model is None:
            if self.description_param is None:
                self.description_param = "robot_description"
            model = self._get_param(self.description_param)
            if model is not None:
                print("Read URDF model from ROS parameter " + self.description_param)

        if model is None:
            raise RuntimeError("RecomputeTFFromJointStates requires robot URDF either as a file or ROS parameter.")

        # Forget everything derived from a previously loaded model.
        self._mimic_joints = {}
        self._segments = {}
        self._segments_fixed = {}
        self._child_to_joint_map = {}
        self._joint_to_child_map = {}

        self._urdf_model = urdf.URDF.from_xml_string(model)
        ok, self._kdl_model = treeFromUrdfModel(self._urdf_model, quiet=True)
        if not ok:
            self._kdl_model = None
            print('Could not parse URDF model into KDL model.', file=sys.stderr)
            # Without a KDL model, filter() passes all messages through unchanged; nothing more to set up here.
            return

        for joint_name, joint in self._urdf_model.joint_map.items():
            if joint.mimic is not None:
                self._mimic_joints[joint_name] = joint.mimic

        for child, (joint_name, parent) in self._urdf_model.parent_map.items():
            if self.include_joints and joint_name not in self.include_joints:
                continue
            if self.exclude_joints and joint_name in self.exclude_joints:
                continue
            segment = self._kdl_model.getChain(parent, child).getSegment(0)
            joint = segment.getJoint()
            self._child_to_joint_map[child] = joint_name
            self._joint_to_child_map[joint_name] = child
            # KDL reports fixed joints with type name "None".
            if joint.getTypeName() == "None":
                self._segments_fixed[child] = segment
            else:
                self._segments[child] = segment

    def filter(self, topic, msg, stamp, header, tags):
        if self._kdl_model is None:
            return topic, msg, stamp, header, tags

        if msg._type == TFMessage._type:  # noqa
            changed_tags = {MessageTags.CHANGED}
            if self._add_tags:
                changed_tags = changed_tags.union(self._add_tags)

            modified = False
            # Iterate backwards so that deleting a transform does not shift the indices still to be visited.
            for i in reversed(range(len(msg.transforms))):
                transform = msg.transforms[i]
                if self.include_parents and transform.header.frame_id not in self.include_parents:
                    continue
                if self.exclude_parents and transform.header.frame_id in self.exclude_parents:
                    continue
                if self.include_children and transform.child_frame_id not in self.include_children:
                    continue
                if self.exclude_children and transform.child_frame_id in self.exclude_children:
                    continue

                is_dynamic = transform.child_frame_id in self._segments
                is_static = transform.child_frame_id in self._segments_fixed
                if not is_static and not is_dynamic:
                    continue  # not a TF from the robot model

                if is_static:
                    success = self._recompute_static_transform(transform)
                else:
                    success = self._recompute_dynamic_transform(transform)
                if success:
                    modified = True
                else:
                    print("Could not recompute transform %s -> %s at time %s." % (
                        transform.header.frame_id, transform.child_frame_id, to_str(stamp)), file=sys.stderr)
                    if self._discard_failed_transforms:
                        del msg.transforms[i]
                        modified = True

            if len(msg.transforms) == 0:
                return None

            if modified:
                # Keep the tags the message already had (e.g. GENERATED) and add the "changed" ones.
                tags = set(tags).union(changed_tags)

            return topic, msg, stamp, header, tags

        else:  # the message is a joint state
            self._joint_state_cache.append(copy.deepcopy(msg))
            result = [(topic, msg, stamp, header, tags)]

            if self.publish_new_tfs:
                tf_msg = TFMessage()
                tf_static_msg = TFMessage()
                for i in range(len(msg.name)):
                    joint_name = msg.name[i]
                    if len(msg.position) <= i or not math.isfinite(msg.position[i]):
                        continue
                    if joint_name not in self._joint_to_child_map:
                        continue
                    child_name = self._joint_to_child_map[joint_name]
                    _, parent_name = self._urdf_model.parent_map[child_name]
                    # Append just zero-initialized TF; it will be passed to this filter later and it will get recomputed
                    t = TransformStamped()
                    t.header.frame_id = parent_name
                    t.header.stamp = msg.header.stamp
                    t.child_frame_id = child_name
                    if child_name in self._segments:
                        tf_msg.transforms.append(t)
                    else:
                        tf_static_msg.transforms.append(t)

                tf_tags = tags_for_generated_msg(tags)
                if self._add_tags:
                    tf_tags = tf_tags.union(self._add_tags)
                if len(tf_msg.transforms) > 0:
                    connection_header = create_connection_header("/tf", TFMessage, False)
                    result.append(("/tf", tf_msg, stamp, connection_header, tf_tags))
                if len(tf_static_msg.transforms) > 0:
                    connection_header = create_connection_header("/tf_static", TFMessage, True)
                    result.append(("/tf_static", tf_static_msg, stamp, connection_header, tf_tags))

            return result

    def _recompute_static_transform(self, transform):
        if transform.child_frame_id not in self._segments_fixed:
            return False
        segment = self._segments_fixed[transform.child_frame_id]
        frame = segment.pose(0.0)
        set_transform_from_KDL_frame(transform.transform, frame)
        return True

    def _recompute_dynamic_transform(self, transform):
        """Recompute a movable-joint transform from a cached JointState with exactly the same stamp.

        :return: True if the transform was recomputed, False if no usable joint position was found.
        """
        child = transform.child_frame_id
        if child not in self._child_to_joint_map or child not in self._segments:
            return False

        stamp = transform.header.stamp
        joint_name = self._child_to_joint_map[child]
        mimic = None
        multiplier = 1.0
        offset = 0.0
        if joint_name in self._mimic_joints:
            mimic = self._mimic_joints[joint_name]
            joint_name = mimic.joint
            # URDF makes the mimic multiplier and offset optional (defaults 1 and 0); missing ones are left as None.
            if mimic.multiplier is not None:
                multiplier = float(mimic.multiplier)
            if mimic.offset is not None:
                offset = float(mimic.offset)

        for msg in reversed(self._joint_state_cache):
            if msg.header.stamp != stamp:
                continue
            # zip() stops at the shorter list, so a position array shorter than name[] cannot cause an IndexError.
            for name, pos in zip(msg.name, msg.position):
                if name != joint_name or not math.isfinite(pos):
                    continue
                if mimic is not None:
                    pos = pos * multiplier + offset
                segment = self._segments[child]
                frame = segment.pose(pos)
                set_transform_from_KDL_frame(transform.transform, frame)
                return True
        return False

    def _str_params(self):
        parts = []
        if self.include_parents:
            parts.append('include_parents=' + str(self.include_parents))
        if self.exclude_parents:
            parts.append('exclude_parents=' + str(self.exclude_parents))
        if self.include_children:
            parts.append('include_children=' + str(self.include_children))
        if self.exclude_children:
            parts.append('exclude_children=' + str(self.exclude_children))
        if self.include_joints:
            parts.append('include_joints=' + str(self.include_joints))
        if self.exclude_joints:
            parts.append('exclude_joints=' + str(self.exclude_joints))
        if self.publish_new_tfs:
            parts.append("publish_new_tfs")
        if self._discard_failed_transforms:
            parts.append("discard_failed_transforms")
        parent_params = self._default_str_params(include_types=False)
        if len(parent_params) > 0:
            parts.append(parent_params)
        return ", ".join(parts)


class FixJointStates(DeserializedMessageFilter):
    """Adjust some joint states. If TFs are computed from them, use RecomputeTFFromJointStates to recompute TF."""

    class _JointStateType(Enum):
        """Kinds of per-joint changes that can be requested in the `changes` dict."""
        POSITION = 1
        VELOCITY = 2
        EFFORT = 3
        INTEGRATE_POSITION = 4
        DERIVATE_VELOCITY = 5

    class _OperationType(Enum):
        """Operations that can be applied to a single position/velocity/effort value."""
        VALUE = 1
        MULTIPLIER = 2
        OFFSET = 3
        NORMALIZE_ANGLE = 4
        NORMALIZE_ANGLE_POSITIVE = 5

    def __init__(self, changes=None, add_tags=None, *args, **kwargs):
        # type: (Optional[Dict[STRING_TYPE, Dict[STRING_TYPE, Dict[STRING_TYPE, Any]]]], Optional[Set[STRING_TYPE]], Any, Any) -> None  # noqa
        """
        :param changes: Dict specifying what should be changed. Keys are joint names. Values are dicts with possible
                        keys "position", "velocity", "effort", "integrate_position", "derivate_velocity".
                        Values of the first 3 dicts can be
                        "value" (set absolute value), "offset" (add offset), "multiplier" (multiply value),
                        "normalize_angle" (normalizes value to (-pi, pi)), "normalize_angle_positive" (normalizes
                        value to (0, 2*pi)).
                        If "value" is specified, then "offset" and "multiplier" are ignored. Otherwise, "multiplier" is
                        applied first and "offset" is added to the result.
                        "integrate_position" dict can contain keys "min_dt", "normalize_angle" and
                        "normalize_angle_positive". "min_dt" specifies the minimum time delta between consecutive
                        joint states to consider them an update. It defaults to 0.01 s. The normalization keys
                        normalize the integrated position the same way as described above.
                        The first position when integrating position is the normal value contained in the message (it
                        can be influenced by the specified "position" change (e.g. set to 0)).
                        "derivate_velocity" behaves similar to "integrate_position", but does the opposite computation
                        (only "min_dt" is used). Derivation starts with the first message with a finite position;
                        later messages with a non-finite position do not update the derivation state and just receive
                        the last derived velocity.
        :param add_tags: Tags to be added to modified joint states messages.
        :param args: Standard include/exclude and stamp args.
        :param kwargs: Standard include/exclude and stamp kwargs.
        """
        super(FixJointStates, self).__init__(include_types=(JointState._type,), *args, **kwargs)  # noqa

        JointStateType = FixJointStates._JointStateType
        OperationType = FixJointStates._OperationType

        def convert_operations(operations):  # type: (Dict[STRING_TYPE, float]) -> Dict[OperationType, float]
            """Convert the user-facing operation dict to a dict keyed by _OperationType with typed values."""
            result = {}
            if "value" in operations:
                result[OperationType.VALUE] = float(operations["value"])
            if "multiplier" in operations:
                result[OperationType.MULTIPLIER] = float(operations["multiplier"])
            if "offset" in operations:
                result[OperationType.OFFSET] = float(operations["offset"])
            if "normalize_angle" in operations:
                result[OperationType.NORMALIZE_ANGLE] = bool(operations["normalize_angle"])
            if "normalize_angle_positive" in operations:
                result[OperationType.NORMALIZE_ANGLE_POSITIVE] = bool(operations["normalize_angle_positive"])
            return result

        self._changes = {}
        if changes is not None:
            for joint_name, change in changes.items():
                self._changes[joint_name] = {}
                if "position" in change:
                    self._changes[joint_name][JointStateType.POSITION] = convert_operations(change["position"])
                if "velocity" in change:
                    self._changes[joint_name][JointStateType.VELOCITY] = convert_operations(change["velocity"])
                if "effort" in change:
                    self._changes[joint_name][JointStateType.EFFORT] = convert_operations(change["effort"])
                if "integrate_position" in change:
                    self._changes[joint_name][JointStateType.INTEGRATE_POSITION] = change["integrate_position"]
                if "derivate_velocity" in change:
                    self._changes[joint_name][JointStateType.DERIVATE_VELOCITY] = change["derivate_velocity"]
        self._changed_joint_names = TopicSet(self._changes.keys())  # for efficient filtering

        self._add_tags = add_tags

        self._reset_state()

    def _reset_state(self):
        """Forget all per-joint integration and derivation state."""
        self._integrated_positions = {}
        self._last_velocities = {}
        self._last_integration_stamps = {}

        self._last_positions = {}
        self._derived_velocities = {}
        self._last_derivation_stamps = {}

    def filter(self, topic, msg, stamp, header, tags):
        JointStateType = FixJointStates._JointStateType

        changed_tags = copy.deepcopy(tags)
        changed_tags.add(MessageTags.CHANGED)
        if self._add_tags:
            changed_tags = changed_tags.union(self._add_tags)

        for i, name in enumerate(msg.name):
            if name not in self._changed_joint_names:
                continue
            changes = self._changes[name]

            # Deserialized arrays may be tuples; they are converted to lists before modifying them in place.
            if len(msg.position) > i and JointStateType.POSITION in changes:
                if isinstance(msg.position, tuple):
                    msg.position = list(msg.position)
                msg.position[i] = self._apply_changes(msg.position[i], changes[JointStateType.POSITION])
                tags = changed_tags

            if len(msg.velocity) > i and JointStateType.VELOCITY in changes:
                if isinstance(msg.velocity, tuple):
                    msg.velocity = list(msg.velocity)
                msg.velocity[i] = self._apply_changes(msg.velocity[i], changes[JointStateType.VELOCITY])
                tags = changed_tags

            if len(msg.effort) > i and JointStateType.EFFORT in changes:
                if isinstance(msg.effort, tuple):
                    msg.effort = list(msg.effort)
                msg.effort[i] = self._apply_changes(msg.effort[i], changes[JointStateType.EFFORT])
                tags = changed_tags

            if len(msg.position) > i and len(msg.velocity) > i and JointStateType.INTEGRATE_POSITION in changes:
                params = changes[JointStateType.INTEGRATE_POSITION]
                if name not in self._last_integration_stamps:
                    # The first message only initializes the integrator; the message is left untouched.
                    self._last_integration_stamps[name] = msg.header.stamp
                    self._integrated_positions[name] = msg.position[i]
                    self._last_velocities[name] = msg.velocity[i]
                else:
                    min_dt = params.get("min_dt", 0.01)
                    dt = (msg.header.stamp - self._last_integration_stamps[name]).to_sec()
                    if dt >= min_dt:
                        # Trapezoidal integration of velocity over the interval.
                        vel = (msg.velocity[i] + self._last_velocities[name]) / 2.0
                        distance = vel * dt
                        self._integrated_positions[name] += distance
                        if params.get("normalize_angle", False):
                            self._integrated_positions[name] = normalize_angle(self._integrated_positions[name])
                        if params.get("normalize_angle_positive", False):
                            self._integrated_positions[name] = normalize_angle_positive(
                                self._integrated_positions[name])
                        self._last_velocities[name] = msg.velocity[i]
                        self._last_integration_stamps[name] = msg.header.stamp
                    if isinstance(msg.position, tuple):
                        msg.position = list(msg.position)
                    msg.position[i] = self._integrated_positions[name]
                    tags = changed_tags

            if len(msg.position) > i and len(msg.velocity) > i and JointStateType.DERIVATE_VELOCITY in changes:
                params = changes[JointStateType.DERIVATE_VELOCITY]
                if name not in self._last_derivation_stamps:
                    # Derivation starts only with a finite position sample.
                    if math.isfinite(msg.position[i]):
                        self._last_derivation_stamps[name] = msg.header.stamp
                        self._derived_velocities[name] = msg.velocity[i]
                        self._last_positions[name] = msg.position[i]
                else:
                    min_dt = params.get("min_dt", 0.01)
                    dt = (msg.header.stamp - self._last_derivation_stamps[name]).to_sec()
                    # A non-finite position would be stored as the last position and poison every future derived
                    # velocity, so such samples do not update the derivation state.
                    if dt >= min_dt and math.isfinite(msg.position[i]):
                        distance = msg.position[i] - self._last_positions[name]
                        vel = distance / dt
                        self._derived_velocities[name] = vel
                        self._last_positions[name] = msg.position[i]
                        self._last_derivation_stamps[name] = msg.header.stamp
                    if isinstance(msg.velocity, tuple):
                        msg.velocity = list(msg.velocity)
                    msg.velocity[i] = self._derived_velocities[name]
                    tags = changed_tags

        return topic, msg, stamp, header, tags

    def _apply_changes(self, value, changes):
        """Apply the operations from `changes` (a dict keyed by _OperationType) to `value` and return the result.

        "value" replaces the input; otherwise the multiplier is applied before the offset. Angle normalizations are
        applied last.
        """
        if len(changes) == 0:
            return value

        OperationType = FixJointStates._OperationType

        if OperationType.VALUE in changes:
            result = changes[OperationType.VALUE]
        else:
            result = value * changes.get(OperationType.MULTIPLIER, 1.0) + changes.get(OperationType.OFFSET, 0.0)

        if changes.get(OperationType.NORMALIZE_ANGLE, False):
            result = normalize_angle(result)
        if changes.get(OperationType.NORMALIZE_ANGLE_POSITIVE, False):
            result = normalize_angle_positive(result)

        return result

    def reset(self):
        # Both the integration and the derivation state have to be forgotten.
        self._reset_state()
        super(FixJointStates, self).reset()

    def _str_params(self):
        parts = []
        parts.append('changes=' + dict_to_str(self._changes))
        parent_params = self._default_str_params(include_types=False)
        if len(parent_params) > 0:
            parts.append(parent_params)
        return ", ".join(parts)


class InterpolateJointStates(DeserializedMessageFilter):
    """Interpolate some joint states if their frequency is too low or some messages are missing."""

    instance_id = 0  # Counter giving each instance a unique self-tag.

    def __init__(self, joints, max_dt, dt_tolerance=0.05, include_topics=None, add_tags=None, exclude_tags=None,
                 *args, **kwargs):
        # type: (List[STRING_TYPE], float, float, STRING_TYPE, Optional[Set[STRING_TYPE]], List[Any], Any, Any) -> None
        """
        :param joints: The joints to interpolate. If empty, interpolate all joints.
        :param max_dt: The desired time difference between two consecutive messages. Has to be positive.
        :param dt_tolerance: If the actual time delta is within this number from max_dt, no interpolation is done.
        :param include_topics: Topics with joint states (defaults to ['joint_states']).
        :param add_tags: Tags to be added to generated joint states messages.
        :param list exclude_tags: If nonempty, the filter will skip messages with these tags. Each element of
                                  this list is itself a set. For a message to be skipped, at least one set has to be
                                  a subset of the tags list of the message.
                                  This filter automatically adds a tag that prevents an instance of the filter to
                                  process messages generated by itself.
        :param args: Standard include/exclude and stamp args.
        :param kwargs: Standard include/exclude and stamp kwargs.
        :raises ValueError: If max_dt is not positive (generating messages would never terminate).
        """
        max_dt_duration = rospy.Duration(max_dt)
        if max_dt_duration.to_sec() <= 0.0:
            raise ValueError("max_dt has to be positive, but %s was given" % (max_dt,))

        InterpolateJointStates.instance_id += 1
        self._self_tag = "interpolate_joint_states_" + str(InterpolateJointStates.instance_id)
        _exclude_tags = [({t} if isinstance(t, STRING_TYPE) else set(t)) for t in exclude_tags] if exclude_tags else []
        _exclude_tags.append({self._self_tag})

        super(InterpolateJointStates, self).__init__(
            include_topics=('joint_states',) if include_topics is None else include_topics,
            include_types=(JointState._type,), exclude_tags=tuple(_exclude_tags), *args, **kwargs)  # noqa

        self._joints = TopicSet(joints)
        self._max_dt = max_dt_duration
        self._dt_tolerance = rospy.Duration(dt_tolerance)

        self._add_tags = ((set(add_tags) if add_tags else set()).union({self._self_tag}))

        # Last seen state of each joint: (stamp, name, position, velocity, effort).
        self._last_states = {}

    def filter(self, topic, msg, stamp, header, tags):
        gen_tags = tags_for_generated_msg(tags)
        if self._add_tags:
            gen_tags = gen_tags.union(self._add_tags)

        gen_states = []

        for i, name in enumerate(msg.name):
            if self._joints and name not in self._joints:
                continue

            state = self._get_state(msg, i)

            if name not in self._last_states:
                self._last_states[name] = state
                continue

            prev_state = self._last_states[name]
            self._last_states[name] = state

            dt = state[0] - prev_state[0]
            if dt < self._max_dt + self._dt_tolerance:
                continue

            # Fill the gap with states spaced by max_dt; stop when the next one would be too close to `state`.
            gen_stamp = prev_state[0] + self._max_dt
            while gen_stamp + self._dt_tolerance < state[0]:
                gen_states.append(self._interpolate(gen_stamp, prev_state, state))
                gen_stamp = gen_stamp + self._max_dt

        # Group the generated states of all joints by stamp into one message per stamp.
        gen_stamps = sorted(list(set([s[0] for s in gen_states])))
        gen_msgs = []
        for gen_stamp in gen_stamps:
            states = [s for s in gen_states if s[0] == gen_stamp]
            gen_msg = JointState()
            gen_msg.header.frame_id = msg.header.frame_id
            gen_msg.header.stamp = gen_stamp
            gen_msg.name = [s[1] for s in states]
            gen_msg.position = self._aligned_values(states, 2)
            gen_msg.velocity = self._aligned_values(states, 3)
            gen_msg.effort = self._aligned_values(states, 4)
            gen_msgs.append((topic, gen_msg, gen_stamp, header, gen_tags))

        return [(topic, msg, stamp, header, tags)] + gen_msgs

    @staticmethod
    def _aligned_values(states, idx):
        """Return the `idx`-th element of all states as a list index-aligned with the joint names.

        If no state has this element, an empty list is returned (the field is omitted from the message). If only some
        states have it, the missing values are filled with NaN so that the list stays aligned with the names.
        """
        values = [s[idx] for s in states]
        if all(v is None for v in values):
            return []
        return [float('nan') if v is None else v for v in values]

    def _get_state(self, msg, i):
        """Return the state of the i-th joint as (stamp, name, position, velocity, effort); missing values are None."""
        return (
            msg.header.stamp,
            msg.name[i],
            msg.position[i] if len(msg.position) > i else None,
            msg.velocity[i] if len(msg.velocity) > i else None,
            msg.effort[i] if len(msg.effort) > i else None,
        )

    @staticmethod
    def _lerp(prev_value, value, ratio):
        """Linearly interpolate between two values; None if either of them is missing."""
        if prev_value is None or value is None:
            return None
        return prev_value + ratio * (value - prev_value)

    def _interpolate(self, stamp, prev_state, state):
        """Linearly interpolate a joint state at `stamp` between `prev_state` and `state`."""
        # TODO support proper angular interpolation
        ratio = (stamp - prev_state[0]).to_sec() / (state[0] - prev_state[0]).to_sec()
        return (
            stamp,
            state[1],
            self._lerp(prev_state[2], state[2], ratio),
            self._lerp(prev_state[3], state[3], ratio),
            self._lerp(prev_state[4], state[4], ratio),
        )

    def reset(self):
        self._last_states = {}
        super(InterpolateJointStates, self).reset()

    def _str_params(self):
        parts = []
        parts.append('joints=' + repr(self._joints))
        parts.append('max_dt=' + to_str(self._max_dt).rstrip('0'))
        parts.append('dt_tolerance=' + to_str(self._dt_tolerance).rstrip('0'))
        parent_params = self._default_str_params(include_types=False)
        if len(parent_params) > 0:
            parts.append(parent_params)
        return ", ".join(parts)


class ExtractJointStatesVelocityAsCommands(DeserializedMessageFilter):
    """Extract joint states velocity to pretend they are velocity commands."""

    def __init__(self, joints, cmd_topic, min_velocity=0.0, joint_states_topic="joint_states", add_tags=None,
                 *args, **kwargs):
        # type: (List[STRING_TYPE], STRING_TYPE, float, STRING_TYPE, Optional[Set[STRING_TYPE]], Any, Any) -> None
        """
        :param joints: The joints to extract. Their order defines the order of values in the generated commands.
        :param cmd_topic: The topic to which the extracted commands should be published.
        :param min_velocity: Velocities whose absolute value is smaller than this are replaced by 0.
        :param joint_states_topic: The topic on which to look for joint states.
        :param add_tags: Tags to be added to generated messages.
        :param args: Standard include/exclude and stamp args.
        :param kwargs: Standard include/exclude and stamp kwargs.
        """
        super(ExtractJointStatesVelocityAsCommands, self).__init__(
            include_topics=(joint_states_topic,), include_types=(JointState._type,), *args, **kwargs)  # noqa

        self._joints = list(joints)
        self._cmd_topic = cmd_topic
        self._min_velocity = min_velocity
        self._joint_states_topic = joint_states_topic

        self._add_tags = set(add_tags) if add_tags else set()

        # The connection header does not depend on the processed message, so it is created only once.
        self._connection_header = create_connection_header(self._cmd_topic, Float64MultiArray)

    def filter(self, topic, msg, stamp, header, tags):
        """Pass the joint states through and, if all joints have a velocity in it, also emit a command message.

        If any of the configured joints is missing in the message (or has no velocity), only the original message is
        returned.
        """
        out_msg = Float64MultiArray()
        for joint in self._joints:
            if joint not in msg.name:
                return topic, msg, stamp, header, tags
            joint_idx = msg.name.index(joint)
            if len(msg.velocity) <= joint_idx:
                return topic, msg, stamp, header, tags
            velocity = msg.velocity[joint_idx]
            if abs(velocity) < self._min_velocity:
                velocity = 0.0
            out_msg.data.append(velocity)

        gen_tags = tags_for_generated_msg(tags)
        if self._add_tags:
            gen_tags = gen_tags.union(self._add_tags)

        return [
            (topic, msg, stamp, header, tags),
            (self._cmd_topic, out_msg, stamp, self._connection_header, gen_tags),
        ]

    def _str_params(self):
        parts = []
        parts.append('joints=' + repr(self._joints))
        parts.append('cmd_topic=' + self._cmd_topic)
        if self._min_velocity != 0.0:
            parts.append('min_velocity=' + str(self._min_velocity))
        parts.append('joint_states_topic=' + self._joint_states_topic)
        parent_params = self._default_str_params(include_types=False)
        if len(parent_params) > 0:
            parts.append(parent_params)
        return ", ".join(parts)


class RecomputeAckermannOdometry(DeserializedMessageFilter):
    """Compute Ackermann-style odometry from joint states."""

    def __init__(self, wheel_radius, wheel_separation, traction_joint, steering_joint, joint_state_topic="joint_states",
                 frame_id="base_link", odom_frame_id="odom", odom_topic="odom", cmd_vel_topic=None,
                 cmd_vel_frame_id=None, min_dt=0.01, heading_change_coef=1.0, twist_covariance=None, add_tags=None,
                 *args, **kwargs):
        """
        :param float wheel_radius: Radius of the traction wheel [m].
        :param float wheel_separation: Separation of the steering and traction axles (wheelbase) [m].
        :param str traction_joint: Joint whose velocity will be used as traction velocity.
        :param str steering_joint: Joint whose position will be used as steering angle.
        :param str joint_state_topic: Topic with joint states.
        :param str frame_id: Child frame ID of the odom messages.
        :param str odom_frame_id: Parent frame ID of the odom messages.
        :param str odom_topic: Odometry topic to publish.
        :param str cmd_vel_topic: If not None, the computed linear and angular velocities are published on this topic
                                  as cmd_vel messages.
        :param str cmd_vel_frame_id: If not None, the published cmd_vel is a stamped message with this frame_id.
        :param float min_dt: If the computed dt is not larger than this, the joint state message is ignored for
                             odometry.
        :param float heading_change_coef: This is a hack. The angular distance added to yaw is multiplied by this number
                                          after being computed from angular velocity and dt.
        :param list twist_covariance: Covariance of twist used in the odometry message. A 36-element list of floats.
        :param set add_tags: Tags to be added to the generated odometry messages.
        :param args: Standard include/exclude and stamp args.
        :param kwargs: Standard include/exclude and stamp kwargs.
        :raises RuntimeError: If twist_covariance does not have 36 elements.
        """
        super(RecomputeAckermannOdometry, self).__init__(
            include_topics=(joint_state_topic,), include_types=(JointState._type,), *args, **kwargs)  # noqa
        self._wheel_radius = wheel_radius
        self._wheel_separation = wheel_separation
        self._traction_joint = traction_joint
        self._steering_joint = steering_joint
        self._frame_id = frame_id
        self._odom_frame_id = odom_frame_id
        self._odom_topic = odom_topic
        self._cmd_vel_topic = cmd_vel_topic
        self._cmd_vel_frame_id = cmd_vel_frame_id
        self._min_dt = min_dt
        self._heading_change_coef = heading_change_coef
        self._twist_covariance = list(twist_covariance) if twist_covariance is not None else [0.0] * 36
        self._add_tags = add_tags

        if len(self._twist_covariance) != 36:
            raise RuntimeError("Twist covariance should be a 36 element list, but %i elements were given!" %
                               (len(self._twist_covariance),))

        self._connection_header = create_connection_header(self._odom_topic, Odometry, False)
        if self._cmd_vel_topic:
            self._connection_header_cmd_vel = create_connection_header(
                self._cmd_vel_topic, TwistStamped if self._cmd_vel_frame_id else Twist, False)

        self._reset_state()

    def _reset_state(self):
        """Forget the integrated pose, the velocities, the steering angle and the time of the last update."""
        self._x = 0.0
        self._y = 0.0
        self._yaw = 0.0
        self._lin_vel = 0.0
        self._ang_vel = 0.0

        self._steering_angle = 0.0
        self._last_time = None

    def filter(self, topic, msg, stamp, header, tags):
        """Update the odometry from a joint states message.

        Returns the original message followed by the generated odometry (and cmd_vel, if configured) messages when an
        odometry update happened.
        """
        result = [(topic, msg, stamp, header, tags)]

        if self._steering_joint in msg.name:
            i = msg.name.index(self._steering_joint)
            # Keep the previous steering angle if this message carries no usable position; a NaN angle would
            # otherwise poison the integrated yaw forever.
            if len(msg.position) > i and math.isfinite(msg.position[i]):
                self._steering_angle = msg.position[i]

        if self._traction_joint not in msg.name:
            return result

        if self._last_time is None:
            # The first traction message only sets the time reference and publishes the current pose.
            self._last_time = msg.header.stamp
            result.append(self._construct_odom(msg.header.stamp, stamp, tags))
            return result

        i = msg.name.index(self._traction_joint)
        if len(msg.velocity) <= i or not math.isfinite(msg.velocity[i]):
            # No usable traction velocity in this message; skip it without touching the odometry state.
            return result
        wheel_ang_vel = msg.velocity[i]

        dt = (msg.header.stamp - self._last_time).to_sec()
        if dt <= self._min_dt:
            return result
        self._last_time = msg.header.stamp

        self._lin_vel = wheel_ang_vel * self._wheel_radius
        linear = self._lin_vel * dt

        steering_angle = normalize_angle(self._steering_angle)
        angular = math.tan(steering_angle) * linear / self._wheel_separation
        self._ang_vel = angular / dt
        # The hack coefficient scales only the integrated heading change, not the reported angular velocity.
        angular *= self._heading_change_coef

        # Move along the mean heading of the interval (midpoint integration).
        direction = self._yaw + angular * 0.5
        self._x += linear * math.cos(direction)
        self._y += linear * math.sin(direction)
        self._yaw += angular

        result.append(self._construct_odom(msg.header.stamp, stamp, tags))
        if self._cmd_vel_topic:
            result.append(self._construct_cmd_vel(msg.header.stamp, stamp, tags))

        return result

    def _construct_odom(self, header_stamp, receive_stamp, tags):
        """Create the odometry message (and its bag record tuple) from the current state."""
        msg = Odometry()
        msg.header.frame_id = self._odom_frame_id
        msg.header.stamp = header_stamp
        msg.child_frame_id = self._frame_id
        msg.pose.pose.position.x = self._x
        msg.pose.pose.position.y = self._y
        msg.pose.pose.orientation = quat_msg_from_rpy(0, 0, self._yaw)
        # TODO pose covariance
        msg.twist.twist.linear.x = self._lin_vel
        msg.twist.twist.angular.z = self._ang_vel
        # Copy so that a consumer modifying one message does not change the covariance of all other messages.
        msg.twist.covariance = list(self._twist_covariance)
        odom_tags = tags_for_generated_msg(tags)
        if self._add_tags:
            odom_tags = odom_tags.union(self._add_tags)
        return self._odom_topic, msg, receive_stamp, self._connection_header, odom_tags

    def _construct_cmd_vel(self, header_stamp, receive_stamp, tags):
        """Create the (possibly stamped) cmd_vel message (and its bag record tuple) from the current velocities."""
        twist = Twist()
        twist.linear.x = self._lin_vel
        twist.angular.z = self._ang_vel

        if self._cmd_vel_frame_id:
            msg = TwistStamped()
            msg.header.frame_id = self._cmd_vel_frame_id
            msg.header.stamp = header_stamp
            msg.twist = twist
        else:
            msg = twist

        cmd_vel_tags = tags_for_generated_msg(tags)
        if self._add_tags:
            cmd_vel_tags = cmd_vel_tags.union(self._add_tags)

        return self._cmd_vel_topic, msg, receive_stamp, self._connection_header_cmd_vel, cmd_vel_tags

    def reset(self):
        self._reset_state()
        super(RecomputeAckermannOdometry, self).reset()

    def _str_params(self):
        parts = []
        parts.append('wheel_radius=' + str(self._wheel_radius))
        parts.append('wheel_separation=' + str(self._wheel_separation))
        parts.append('traction_joint=' + self._traction_joint)
        parts.append('steering_joint=' + self._steering_joint)
        parts.append('frame_id=' + self._frame_id)
        parts.append('odom_frame_id=' + self._odom_frame_id)
        parts.append('odom_topic=' + self._odom_topic)
        if self._cmd_vel_frame_id is not None:
            parts.append('cmd_vel_frame_id=' + self._cmd_vel_frame_id)
        if self._cmd_vel_topic is not None:
            parts.append('cmd_vel_topic=' + self._cmd_vel_topic)
        if self._twist_covariance != [0.0] * 36:
            parts.append('twist_covariance=' + ','.join(map(str, self._twist_covariance)))
        parts.append('min_dt=' + to_str(self._min_dt).rstrip('0'))
        if self._heading_change_coef != 1.0:
            parts.append('heading_change_coef=' + str(self._heading_change_coef))

        parent_params = self._default_str_params(include_types=False)
        if len(parent_params) > 0:
            parts.append(parent_params)
        return ", ".join(parts)


class EstimateJointStatesFromMotionModelAckermann(DeserializedMessageFilter):
    """Take a linear and angular velocity source and compute the corresponding joint states via Ackermann-style
    odometry."""

    def __init__(self, wheel_radius, wheel_separation, traction_joints, steering_joints, linear_velocity_topic,
                 linear_velocity_field, angular_velocity_topic, angular_velocity_field,
                 joint_state_topic="joint_states", frame_id="base_link", separate_messages=False,
                 min_linear_velocity=0.01, add_tags=None, *args, **kwargs):
        """
        :param float wheel_radius: Radius of the traction wheel [m].
        :param float wheel_separation: Separation of the steering and traction axles (wheelbase) [m].
        :param list traction_joints: Joints whose velocity will be set to the computed traction velocity.
        :param list steering_joints: Joints whose position will be set to the computed steering angle.
        :param str linear_velocity_topic: Topic from which linear velocity should be read.
        :param str linear_velocity_field: Field in the linear velocity topic that specifies linear speed. Nested fields
                                          are separated by dots (e.g. "twist.linear.x").
        :param str angular_velocity_topic: Topic from which angular velocity should be read. It may be the same topic
                                           as linear_velocity_topic.
        :param str angular_velocity_field: Field in the angular velocity topic that specifies angular speed.
        :param str joint_state_topic: Topic on which the generated joint states are published.
        :param str frame_id: Frame ID of the joint states messages.
        :param bool separate_messages: If true, the steering and traction joints will each be published in a separate
                                       message (on the same topic).
        :param float min_linear_velocity: Minimum linear velocity to consider the robot non-static. If static, the
                                          steering angle doesn't change (to avoid division by close to zero).
        :param set add_tags: Tags to be added to the generated joint states messages.
        :param args: Standard include/exclude and stamp args.
        :param kwargs: Standard include/exclude and stamp kwargs.
        """
        super(EstimateJointStatesFromMotionModelAckermann, self).__init__(
            include_topics=(linear_velocity_topic, angular_velocity_topic), *args, **kwargs)  # noqa
        self._wheel_radius = wheel_radius
        self._wheel_separation = wheel_separation
        self._traction_joints = list(traction_joints)
        self._steering_joints = list(steering_joints)
        self._linear_velocity_topic = TopicSet([linear_velocity_topic])
        self._linear_velocity_field = linear_velocity_field
        self._angular_velocity_topic = TopicSet([angular_velocity_topic])
        self._angular_velocity_field = angular_velocity_field
        self._joint_state_topic = joint_state_topic
        self._frame_id = frame_id
        self._separate_messages = separate_messages
        self._min_linear_velocity = min_linear_velocity
        self._add_tags = add_tags

        self._connection_header = create_connection_header(self._joint_state_topic, JointState, False)

        self._last_ang_vel = None
        self._last_steering_angle = 0.0

    def filter(self, topic, msg, stamp, header, tags):
        """Remember angular velocity and generate joint states for each linear velocity message.

        Joint states are generated only after the first angular velocity has been received.
        """
        result = [(topic, msg, stamp, header, tags)]

        if topic in self._angular_velocity_topic:
            self._last_ang_vel = float(self._get_field(msg, self._angular_velocity_field))

        # Not returning right after reading the angular velocity allows both velocities to come from the same topic.
        if topic not in self._linear_velocity_topic or self._last_ang_vel is None:
            return result

        lin_vel = float(self._get_field(msg, self._linear_velocity_field))

        wheel_ang_vel = lin_vel / self._wheel_radius
        if abs(lin_vel) > self._min_linear_velocity:
            # Inverse of the Ackermann model ang_vel = lin_vel * tan(steering_angle) / wheelbase.
            steering_angle = math.atan(self._last_ang_vel * self._wheel_separation / lin_vel)
        else:
            steering_angle = self._last_steering_angle
        self._last_steering_angle = steering_angle

        msg_stamp = msg.header.stamp if "header" in msg.__slots__ else stamp

        gen_tags = tags_for_generated_msg(tags)
        if self._add_tags:
            gen_tags = gen_tags.union(self._add_tags)

        # Quantities that are not estimated are filled with NaN.
        traction_msg = JointState()
        traction_msg.header.stamp = msg_stamp
        traction_msg.header.frame_id = self._frame_id
        traction_msg.name = list(self._traction_joints)
        traction_msg.position = [float('nan')] * len(self._traction_joints)
        traction_msg.velocity = [wheel_ang_vel] * len(self._traction_joints)
        traction_msg.effort = [float('nan')] * len(self._traction_joints)

        steering_msg = JointState()
        steering_msg.header.stamp = msg_stamp
        steering_msg.header.frame_id = self._frame_id
        steering_msg.name = list(self._steering_joints)
        steering_msg.position = [normalize_angle(steering_angle)] * len(self._steering_joints)
        steering_msg.velocity = [float('nan')] * len(self._steering_joints)
        steering_msg.effort = [float('nan')] * len(self._steering_joints)

        if self._separate_messages:
            traction_tags = gen_tags.union({"traction_joint"})
            result += [(self._joint_state_topic, traction_msg, stamp, self._connection_header, traction_tags)]
            steering_tags = gen_tags.union({"steering_joint"})
            result += [(self._joint_state_topic, steering_msg, stamp, self._connection_header, steering_tags)]
        else:
            joints_msg = JointState()
            joints_msg.header = traction_msg.header
            joints_msg.name = traction_msg.name + steering_msg.name
            joints_msg.position = traction_msg.position + steering_msg.position
            joints_msg.velocity = traction_msg.velocity + steering_msg.velocity
            joints_msg.effort = traction_msg.effort + steering_msg.effort
            result += [(self._joint_state_topic, joints_msg, stamp, self._connection_header, gen_tags)]

        return result

    def _get_field(self, msg, field):
        """Return the value of a (possibly dot-separated nested) field of a message.

        :raises RuntimeError: If any component of the field path is not a slot of the corresponding (sub)message.
        """
        value = msg
        for part in field.split('.'):
            if part not in getattr(value, '__slots__', ()):
                raise RuntimeError("Invalid field '%s'" % (part,))
            value = getattr(value, part)
        return value

    def reset(self):
        self._last_ang_vel = None
        self._last_steering_angle = 0.0
        super(EstimateJointStatesFromMotionModelAckermann, self).reset()

    def _str_params(self):
        parts = []
        parts.append('wheel_radius=' + str(self._wheel_radius))
        parts.append('wheel_separation=' + str(self._wheel_separation))
        parts.append('traction_joints=' + to_str(self._traction_joints))
        parts.append('steering_joints=' + to_str(self._steering_joints))
        parts.append('linear_velocity_topic=' + str(self._linear_velocity_topic))
        parts.append('linear_velocity_field=' + self._linear_velocity_field)
        parts.append('angular_velocity_topic=' + str(self._angular_velocity_topic))
        parts.append('angular_velocity_field=' + self._angular_velocity_field)
        parts.append('joint_state_topic=' + self._joint_state_topic)
        parts.append('frame_id=' + self._frame_id)
        if self._separate_messages:
            parts.append('separate_messages')
        if self._min_linear_velocity != 0.01:
            parts.append('min_linear_velocity=' + str(self._min_linear_velocity))

        parent_params = self._default_str_params(include_types=False)
        if len(parent_params) > 0:
            parts.append(parent_params)
        return ", ".join(parts)


class OdomToTf(DeserializedMessageFilter):
    """Convert odometry to TF.

    Each odometry message is passed through unchanged. Additionally, a TF message with a single transform
    ``header.frame_id -> child_frame_id`` built from the odometry pose is emitted on the TF topic.
    """

    def __init__(self, odom_topic, tf_topic="tf", add_tags=None, *args, **kwargs):
        # type: (STRING_TYPE, STRING_TYPE, Optional[Set[STRING_TYPE]], Any, Any) -> None
        """
        :param odom_topic: Odometry topic to convert.
        :param tf_topic: The TF topic.
        :param add_tags: Tags to be added to the generated TF messages.
        :param args: Standard include/exclude and stamp args.
        :param kwargs: Standard include/exclude and stamp kwargs.
        """
        super(OdomToTf, self).__init__(
            include_topics=(odom_topic,), include_types=(Odometry._type,), *args, **kwargs)  # noqa

        self._odom_topic = odom_topic
        self._tf_topic = tf_topic
        self._add_tags = add_tags
        self._connection_header = create_connection_header(tf_topic, TFMessage, False)

    def filter(self, topic, msg, stamp, header, tags):
        """Pass the odometry message through and append a TF message generated from its pose.

        :return: List with the original message and the generated TF message.
        """
        position = msg.pose.pose.position
        orientation = msg.pose.pose.orientation

        odom_tf = TransformStamped()
        odom_tf.header.frame_id = msg.header.frame_id
        odom_tf.header.stamp = msg.header.stamp
        odom_tf.child_frame_id = msg.child_frame_id
        # Create new Vector3/Quaternion instances instead of reusing the odometry's Point/Quaternion objects.
        # Point is not the declared type of Transform.translation, and sharing the objects would make any later
        # in-place modification of one output message silently change the other one.
        odom_tf.transform.translation = Vector3(position.x, position.y, position.z)
        odom_tf.transform.rotation = Quaternion(orientation.x, orientation.y, orientation.z, orientation.w)

        tf_msg = TFMessage()
        tf_msg.transforms.append(odom_tf)

        tf_tags = tags_for_generated_msg(tags)
        if self._add_tags:
            tf_tags = tf_tags.union(self._add_tags)

        return [
            (topic, msg, stamp, header, tags),
            (self._tf_topic, tf_msg, stamp, self._connection_header, tf_tags),
        ]

    def _str_params(self):
        parts = []
        parts.append('odom_topic=' + self._odom_topic)
        if self._tf_topic != "tf":
            parts.append('tf_topic=' + self._tf_topic)

        parent_params = self._default_str_params(include_types=False)
        if len(parent_params) > 0:
            parts.append(parent_params)
        return ", ".join(parts)


# Value type of the `calibrations` dict accepted by FixCameraCalibration: either a path to a calibration YAML file, or a
# tuple (path to calibration YAML file, camera name) that selects one camera from a multi-camera calibration file.
CalibrationYAMLType = Union[STRING_TYPE, Tuple[STRING_TYPE, STRING_TYPE]]


class FixCameraCalibration(DeserializedMessageFilter):
    """Adjust some camera calibrations.

    The calibration fields (distortion_model, D, K, R, P) of `sensor_msgs/CameraInfo` messages on the configured topics
    are replaced by values loaded from calibration files. Width and height of the messages are kept as they are.
    Supported file formats are camera_info_manager YAML, ikalibr intrinsics/parameter files and kalibr camchains.
    """

    def __init__(self, calibrations=None, warn_size_change=True, *args, **kwargs):
        # type: (Optional[Dict[STRING_TYPE, CalibrationYAMLType]], bool, Any, Any) -> None  # noqa
        """
        :param calibrations: Dictionary with camera_info topic names as keys and YAML files with calibrations as values.
                             If the calibration is in kalibr (or ikalibr parameter) format and the camera is not cam0,
                             then pass a tuple (YAML file, cam_name) instead of just directly YAML file.
                             Calibrations that cannot be interpreted are reported on stderr and skipped.
        :param warn_size_change: If True (default), warn if the fixed camera info has different width or height than the
                                 original.
        :param args: Standard include/exclude and stamp args.
        :param kwargs: Standard include/exclude and stamp kwargs.
        """
        super(FixCameraCalibration, self).__init__(include_topics=list(calibrations.keys()) if calibrations else None,
                                                   include_types=(CameraInfo._type,), *args, **kwargs)  # noqa

        self.warn_size_change = warn_size_change
        self._calibrations = {}
        self._cam_names = {}

        if calibrations is None:
            return
        for camera_info, calib_file in calibrations.items():
            calib_file, cam_name = (calib_file, "cam0") if isinstance(calib_file, STRING_TYPE) else calib_file

            try:
                calib = self.interpret_calibration(self.resolve_file(calib_file), cam_name)
            except (AssertionError, KeyError, RuntimeError, ValueError) as e:
                # interpret_calibration() signals unsupported formats and missing cameras with RuntimeError; malformed
                # file contents surface as KeyError/ValueError. Such calibrations are reported and skipped.
                print("Could not interpret the given calibration file for camera %s: %s." % (camera_info, str(e)),
                      file=sys.stderr)
                continue
            if calib is None:
                print("Could not interpret the given calibration file for camera %s." % (camera_info,),
                      file=sys.stderr)
                continue

            self._calibrations[camera_info] = calib

    def interpret_calibration(self, calib_file, cam_name):
        """Load the calibration of one camera from a calibration file.

        The camera_info_manager format is tried first. If it cannot be read, the file is loaded as YAML and interpreted
        either as an ikalibr file (single-camera intrinsics or the overall ikalibr_param.yaml) or as a kalibr camchain.

        :param str calib_file: Path to the calibration file.
        :param str cam_name: Name of the camera inside multi-camera ikalibr/kalibr files.
        :return: The calibration. For ikalibr/kalibr files, only width, height, distortion_model, D, K, R and P are
                 filled (R is identity and P is computed from K and D).
        :rtype: CameraInfo
        :raises RuntimeError: If the file format is not supported or the camera (or its model) cannot be found.
        """
        msg = None
        try:
            # Only let FATAL messages of the parser through; failing to parse is expected for the other formats.
            if rosconsole_set_logger_level("ros.camera_calibration_parsers", RosconsoleLevel.FATAL):
                rosconsole_notifyLoggerLevelsChanged()
            _, msg = readCalibration(calib_file)  # camera_info_manager format
            print("Interpreted calibration file %s using camera_info_manager format." % (calib_file,))
        except Exception:
            pass  # Deliberately best-effort: fall through to the YAML-based formats below.

        if msg is not None:
            return msg

        with open(calib_file, 'r') as f:
            calib_data = yaml.safe_load(f)

        if not isinstance(calib_data, dict):
            # E.g. an empty file (safe_load() returns None) or a top-level list; none of the supported formats.
            raise RuntimeError("Unsupported camera calibration format: " + calib_file)

        if "Intrinsics" in calib_data or "CalibParam" in calib_data:  # ikalibr format
            data, cam_type = self._find_ikalibr_camera(calib_data, cam_name)
            if data is None:
                raise RuntimeError("Could not find camera %s in calibration file %s." % (cam_name, calib_file))
            if cam_type is None:
                raise RuntimeError("Could not determine the model of camera %s in calibration file %s." % (
                    cam_name, calib_file))

            msg = self._make_camera_info(
                data["img_width"], data["img_height"], data["disto_param"],
                data["focal_length"][0], data["focal_length"][1],
                data["principal_point"][0], data["principal_point"][1],
                equidistant=(cam_type == "pinhole_fisheye"))
            print("Interpreted calibration file %s using ikalibr format." % (calib_file,))
            return msg

        first_value = next(iter(calib_data.values()), None)
        if isinstance(first_value, dict) and "cam_overlaps" in first_value:  # kalibr format
            if cam_name not in calib_data:
                raise RuntimeError("Could not find camera %s in calibration file %s." % (cam_name, calib_file))
            data = calib_data[cam_name]
            w, h = data["resolution"]
            intrinsics = data["intrinsics"]  # (fx, fy, cx, cy) as placed into K below
            # TODO T_cn_cnm1 is not taken into account; R stays identity.
            msg = self._make_camera_info(
                w, h, data["distortion_coeffs"], intrinsics[0], intrinsics[1], intrinsics[2], intrinsics[3],
                equidistant=(data["distortion_model"] == "equidistant"))
            print("Interpreted calibration file %s using kalibr format." % (calib_file,))
            return msg

        raise RuntimeError("Unsupported camera calibration format: " + calib_file)

    @staticmethod
    def _find_ikalibr_camera(calib_data, cam_name):
        """Find the intrinsics data and the camera model name of a camera in a loaded ikalibr YAML.

        :param dict calib_data: The loaded YAML contents (containing "Intrinsics" or "CalibParam").
        :param str cam_name: Camera name (only used for ikalibr_param.yaml files with multiple cameras).
        :return: Tuple (data, cam_type). `data` is None if the camera was not found, `cam_type` is None if the camera
                 model could not be determined.
        """
        if "Intrinsics" in calib_data:  # single-camera intrinsics file
            intrinsics = calib_data["Intrinsics"]
            return intrinsics["ptr_wrapper"]["data"], intrinsics["polymorphic_name"]

        # ikalibr_param.yaml file with the overall result of calibration. Each camera entry either names its model
        # ("polymorphic_name", which is remembered) or refers to an already named model by its index
        # ("polymorphic_id"). Index 0 is pre-seeded with "pinhole_brown_t2".
        polymorphic_map = ["pinhole_brown_t2"]
        for item in calib_data["CalibParam"]["INTRI"]["Camera"]:
            value = item["value"]
            if "polymorphic_name" in value:
                cam_type = value["polymorphic_name"]
                polymorphic_map.append(cam_type)
            else:
                polymorphic_id = int(value["polymorphic_id"])
                # Valid indices are 0 .. len - 1; anything else means the model is unknown.
                cam_type = polymorphic_map[polymorphic_id] if 0 <= polymorphic_id < len(polymorphic_map) else None
            if item["key"] == cam_name:
                return value["ptr_wrapper"]["data"], cam_type
        return None, None

    @staticmethod
    def _make_camera_info(width, height, distortion, fx, fy, cx, cy, equidistant):
        """Build a CameraInfo from pinhole intrinsics and distortion coefficients.

        The distortion model is equidistant if `equidistant` is True, otherwise plumb_bob for fewer than 6 coefficients
        and rational_polynomial for more. R is identity. P is computed by OpenCV with alpha (or fisheye balance) 0.

        :return: CameraInfo with width, height, distortion_model, D, K, R and P filled.
        """
        msg = CameraInfo()
        msg.width = width
        msg.height = height
        msg.D = distortion
        msg.distortion_model = \
            EQUIDISTANT if equidistant else (PLUMB_BOB if len(msg.D) < 6 else RATIONAL_POLYNOMIAL)
        msg.K = [
            fx, 0.0, cx,
            0.0, fy, cy,
            0.0, 0.0, 1.0,
        ]
        msg.R = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]

        K = np.array(msg.K).reshape((3, 3))
        R = np.array(msg.R).reshape((3, 3))
        D = np.array(msg.D)
        if msg.distortion_model != EQUIDISTANT:
            P, _ = cv2.getOptimalNewCameraMatrix(K, D, (width, height), 0.0)
        else:
            P = cv2.fisheye.estimateNewCameraMatrixForUndistortRectify(K, D, (width, height), R, balance=0.0)
        # Convert numpy scalars to plain floats so that the message can be handled e.g. by yaml.safe_dump().
        msg.P = [float(v) for v in (
            P[0, 0], P[0, 1], P[0, 2], 0.0,
            P[1, 0], P[1, 1], P[1, 2], 0.0,
            P[2, 0], P[2, 1], P[2, 2], 0.0,
        )]
        return msg

    def filter(self, topic, msg, stamp, header, tags):
        """Replace the calibration fields of camera infos on configured topics (width/height are left untouched)."""
        if topic in self._calibrations:
            # Deep copy so that the emitted message does not share lists with the stored calibration.
            calib = copy.deepcopy(self._calibrations[topic])
            if self.warn_size_change and (calib.width != msg.width or calib.height != msg.height):
                print("Fixed camera info size (%i, %i) differs from original camera info (%i, %i)" % (
                    calib.width, calib.height, msg.width, msg.height))
            msg.distortion_model = calib.distortion_model
            msg.D = calib.D
            msg.R = calib.R
            msg.K = calib.K
            msg.P = calib.P
            tags.add(MessageTags.CHANGED)

        return topic, msg, stamp, header, tags

    def _str_params(self):
        parts = []
        parts.append('calibrations=' + ",".join(self._calibrations.keys()))
        parent_params = self._default_str_params(include_types=False)
        if len(parent_params) > 0:
            parts.append(parent_params)
        return ", ".join(parts)


class StaticImageMask(RawMessageFilter):
    """Create static image masks.

    Every message on the source topic is passed through unchanged. For each such message, a compressed mask image
    (read from a file when filtering starts) is emitted on each of the mask topics with the header of the source
    message.
    """

    def __init__(self, source_topic, mask_topics, mask_file, format_field, add_tags=None, *args, **kwargs):
        """
        :param str source_topic: Topic which triggers generation of the mask images. This is usually the camera info or
                                 the image topic.
        :param list mask_topics: Topics for which the mask will be generated (if multiple topics are specified, the mask
                                 is copied to all of them). Include the `/compressed` suffix to each topic.
        :param str mask_file: Path to the file with the mask. It should be an imread()-compatible compressed image.
        :param str format_field: The contents of the `format` field in the mask image. It should follow the pattern
                                 '{raw_pixel_format}; {image_format} compressed {compressed_pixel_format}'. If the
                                 compression uses the same pixel format as the raw image, leave out
                                 `compressed_pixel_format`, but keep the preceding space.
        :param set add_tags: Tags to be added to the generated mask messages.
        :param args: Standard include/exclude and stamp args.
        :param kwargs: Standard include/exclude and stamp kwargs.
        """
        super(StaticImageMask, self).__init__(include_topics=(source_topic,), *args, **kwargs)
        self._mask_topics = mask_topics
        self._mask_file = mask_file
        # Template message; its data is filled in on_filtering_start().
        self._mask_msg = CompressedImage()
        self._mask_msg.format = format_field
        self._add_tags = add_tags

        self._conn_headers = {}
        for t in mask_topics:
            self._conn_headers[t] = create_connection_header(t, CompressedImage)

    def on_filtering_start(self):
        mask_file = self.resolve_file(self._mask_file)
        if not os.path.exists(mask_file):
            raise RuntimeError("File " + mask_file + " does not exist.")
        with open(mask_file, 'rb') as f:
            self._mask_msg.data = f.read()
        print("Loaded camera mask", mask_file)

    def filter(self, topic, datatype, data, md5sum, pytype, stamp, header, tags):
        """Pass the source message through and add one mask message per mask topic.

        :raises RuntimeError: If the source message type has no header.
        """
        msg_header = deserialize_header(data, pytype)
        if msg_header is None:
            raise RuntimeError("Message type doesn't have header: " + datatype)

        self._mask_msg.header = msg_header

        result = [(topic, datatype, data, md5sum, pytype, stamp, header, tags)]

        gen_tags = tags_for_generated_msg(tags, self._add_tags)

        for t in self._mask_topics:
            # Emit an independent message (and tag set) per topic and per call. Previously, one shared object was
            # re-mutated on every call, so any downstream change to one output leaked into all the others. The image
            # bytes themselves are immutable and can be shared.
            mask_msg = CompressedImage()
            mask_msg.header = copy.deepcopy(msg_header)
            mask_msg.format = self._mask_msg.format
            mask_msg.data = self._mask_msg.data
            result.append((t, mask_msg, stamp, self._conn_headers[t], copy.copy(gen_tags)))

        return result

    def _str_params(self):
        parts = []
        parts.append('mask_topics=%r' % (self._mask_topics,))
        parts.append('mask_file=%s' % (self._mask_file,))
        parts.append('format_field=%s' % (self._mask_msg.format,))
        parent_params = super(StaticImageMask, self)._str_params()
        if len(parent_params) > 0:
            parts.append(parent_params)
        return ", ".join(parts)


class FixStaticTF(DeserializedMessageFilterWithTF):
    """Adjust some static transforms.

    Transforms in /tf_static messages whose (frame_id, child_frame_id) pair is configured get their transform replaced
    by the configured one. Header stamps and all other transforms are left untouched.
    """

    def __init__(self, transforms=None, add_tags=None, *args, **kwargs):
        # type: (Optional[Sequence[Dict[STRING_TYPE, STRING_TYPE]]], Optional[Set[STRING_TYPE]], Any, Any) -> None
        """
        :param transforms: The new transforms. The dicts have to contain keys "frame_id", "child_frame_id", "transform".
                           The transform has to be a 6-tuple (x, y, z, roll, pitch, yaw)
                           or 7-tuple (x, y, z, qx, qy, qz, qw).
        :param add_tags: Tags to be added to the modified TF messages.
        :param args: Standard include/exclude and stamp args.
        :param kwargs: Standard include/exclude and stamp kwargs.
        :raises RuntimeError: If a transform is neither a 6-tuple nor a 7-tuple.
        """
        super(FixStaticTF, self).__init__(
            include_topics=["tf_static"], include_types=(TFMessage._type,), *args, **kwargs)  # noqa

        # (frame_id, child_frame_id) -> Transform
        self._transforms = dict()
        self._add_tags = add_tags

        for t in (transforms if transforms is not None else list()):
            frame_id = t["frame_id"]
            child_frame_id = t["child_frame_id"]
            data = t["transform"]
            if len(data) == 7:
                transform = Transform(Vector3(*data[:3]), Quaternion(*data[3:]))
            elif len(data) == 6:
                transform = Transform(Vector3(*data[:3]), quat_msg_from_rpy(*data[3:]))
            else:
                raise RuntimeError("'transform' has to be either a 6-tuple or 7-tuple.")

            self._transforms[(frame_id, child_frame_id)] = transform

    def filter(self, topic, msg, stamp, header, tags):
        """Replace configured transforms in the message; tag the message as changed if anything was replaced."""
        changed = False
        for transform in msg.transforms:
            key = (transform.header.frame_id, transform.child_frame_id)
            if key not in self._transforms:
                continue
            # Assign a copy so that later in-place edits of this output message cannot alter the configured transform
            # (which would then propagate to all subsequent messages).
            transform.transform = copy.deepcopy(self._transforms[key])
            print("Adjusted transform %s->%s." % (key[0], key[1]))
            changed = True

        if changed:
            tags.add(MessageTags.CHANGED)
            if self._add_tags:
                tags = tags.union(self._add_tags)

        return topic, msg, stamp, header, tags

    def _str_params(self):
        parts = []
        parts.append('transforms=' + repr(list(self._transforms.keys())))
        parent_params = self._default_str_params(include_types=False)
        if len(parent_params) > 0:
            parts.append(parent_params)
        return ", ".join(parts)


class AddStaticTF(DeserializedMessageFilterWithTF):
    """Add static transforms to the first message on /tf_static.

    The configured transforms are appended to the first considered /tf_static message only; later messages are not
    considered by this filter at all.
    """

    def __init__(self, transforms=None, add_tags=None, *args, **kwargs):
        # type: (Optional[Sequence[Dict[STRING_TYPE, STRING_TYPE]]], Optional[Set[STRING_TYPE]], Any, Any) -> None
        """
        :param transforms: The new transforms. The dicts have to contain keys "frame_id", "child_frame_id", "transform".
                           The transform has to be a 6-tuple (x, y, z, roll, pitch, yaw)
                           or 7-tuple (x, y, z, qx, qy, qz, qw).
        :param add_tags: Tags to be added to the modified TF messages.
        :param args: Standard include/exclude and stamp args.
        :param kwargs: Standard include/exclude and stamp kwargs.
        :raises RuntimeError: If a transform is neither a 6-tuple nor a 7-tuple.
        """
        super(AddStaticTF, self).__init__(
            include_topics=["tf_static"], include_types=(TFMessage._type,), *args, **kwargs)  # noqa

        self._transforms = list()
        self._add_tags = add_tags
        self._added = False

        for t in (transforms if transforms is not None else list()):
            frame_id = t["frame_id"]
            child_frame_id = t["child_frame_id"]
            data = t["transform"]
            if len(data) == 7:
                transform = Transform(Vector3(*data[:3]), Quaternion(*data[3:]))
            elif len(data) == 6:
                transform = Transform(Vector3(*data[:3]), quat_msg_from_rpy(*data[3:]))
            else:
                raise RuntimeError("'transform' has to be either a 6-tuple or 7-tuple.")

            tf_stamped = TransformStamped()
            tf_stamped.header.frame_id = frame_id
            tf_stamped.child_frame_id = child_frame_id
            tf_stamped.transform = transform
            self._transforms.append(tf_stamped)

    def consider_message(self, topic, datatype, stamp, header, tags):
        if self._added:
            return False
        return super(AddStaticTF, self).consider_message(topic, datatype, stamp, header, tags)

    def filter(self, topic, msg, stamp, header, tags):
        """Append copies of the configured transforms, stamped with the earliest stamp found in the message."""
        if len(msg.transforms) > 0:
            msg_stamp = min(t.header.stamp for t in msg.transforms)
        else:
            # min() of an empty sequence would raise ValueError; use the message's bag stamp instead.
            msg_stamp = stamp
        for tf_stamped in self._transforms:
            # Append a copy so that the output message does not alias (and the stamp update does not mutate) the
            # configured transforms.
            added = copy.deepcopy(tf_stamped)
            added.header.stamp = msg_stamp
            msg.transforms.append(added)
            print("Added transform %s->%s." % (added.header.frame_id, added.child_frame_id))
        self._added = True

        return topic, msg, stamp, header, tags_for_changed_msg(tags, self._add_tags)

    def on_filtering_end(self):
        if not self._added:
            print("AddStaticTF did not receive any message on /tf_static, so no static TF was added.", file=sys.stderr)
        super(AddStaticTF, self).on_filtering_end()

    def reset(self):
        self._added = False
        super(AddStaticTF, self).reset()

    def _str_params(self):
        parts = []
        parts.append('transforms=' + repr([(t.header.frame_id, t.child_frame_id) for t in self._transforms]))
        parent_params = self._default_str_params(include_types=False)
        if len(parent_params) > 0:
            parts.append(parent_params)
        return ", ".join(parts)


class ExportTFTrajectory(DeserializedMessageFilterWithTF):
    """Export TF trajectory to a CSV file.

    All TF messages are fed into an internal TF buffer. A trajectory point is created each time a trigger occurs. A
    trigger is either a dynamic/static TF message updating the trigger frame, or, if a trigger topic is set, any
    message on that topic. For each trigger, the transform odom_frame_id -> child_frame_id at the trigger stamp is
    stored as one CSV row. The CSV file is written when filtering ends.
    """

    def __init__(self, odom_frame_id, child_frame_id, csv_file, trigger_frame_id=None, trigger_topic=None,
                 max_frequency=None, tf_buffer_length=None, *args, **kwargs):
        # type: (STR, STR, STR, Optional[STR], Optional[STR], Optional[float], Optional[float], Any, Any) -> None
        """
        :param odom_frame_id: Frame ID of the odometry frame.
        :param child_frame_id: Frame ID of the tracked frame.
        :param csv_file: Path to the CSV file to store the trajectory.
        :param trigger_frame_id: Frame ID of the dynamic TF child frame whose update triggers trajectory point creation.
        :param trigger_topic: If set, trajectory points will be generated when this message is received instead of being
                              triggered by trigger_frame_id.
        :param max_frequency: The maximum frequency on which the transform should be published.
        :param tf_buffer_length: Length of the TF buffer (in seconds).
        :param args: Standard include/exclude and stamp args.
        :param kwargs: Standard include/exclude and stamp kwargs.
        """
        if trigger_topic is None:
            include_topics = ["/tf", "/tf_static"]
            include_types = (TFMessage._type,)
        else:
            # The type of the trigger topic is unknown, so every type has to be accepted.
            include_topics = ["/tf", "/tf_static", trigger_topic]
            include_types = None

        super(ExportTFTrajectory, self).__init__(
            include_topics=include_topics, include_types=include_types, *args, **kwargs)  # noqa

        self._odom_frame_id = odom_frame_id
        self._child_frame_id = child_frame_id
        self._trigger_topic = trigger_topic
        if trigger_topic is None:
            self._trigger_frame_id = child_frame_id if trigger_frame_id is None else trigger_frame_id
        else:
            self._trigger_topics = TopicSet((trigger_topic,))
            self._trigger_frame_id = None

        if max_frequency is None:
            self._min_dt = None
        else:
            self._min_dt = rospy.Duration(1.0 / float(max_frequency))

        if tf_buffer_length is None:
            self._tf_buffer = BufferCore()
        else:
            self._tf_buffer = BufferCore(rospy.Duration.from_sec(float(tf_buffer_length)))

        self._csv_file_path = csv_file
        self._csv_rows = []

        # Stamp of the last stored trajectory point (used for max_frequency limiting).
        self._last_traj_time = None

    def on_filtering_start(self):
        self._csv_rows = []

    def on_filtering_end(self):
        out_path = self.resolve_file(self._csv_file_path)
        columns = ("stamp", "tx", "ty", "tz", "rx", "ry", "rz", "rw")
        with open(out_path, 'w', newline='') as out:
            writer = csv.DictWriter(out, fieldnames=columns)
            writer.writeheader()
            writer.writerows(dict(zip(columns, row)) for row in self._csv_rows)
        print("Saved CSV with", len(self._csv_rows), "rows:", out_path)

    def _find_trigger_stamp(self, topic, msg):
        """Return the trigger stamp carried by msg, or None if msg does not trigger a trajectory point.

        :raises RuntimeError: If msg is neither a trigger topic message nor a TF message.
        """
        if self._trigger_topic is not None and topic in self._trigger_topics:
            return msg.header.stamp
        if msg._type != TFMessage._type:
            raise RuntimeError("Unexpected message received on topic " + topic + " with type " + msg._type)
        for t in msg.transforms:
            if t.child_frame_id == self._trigger_frame_id:
                return t.header.stamp
        return None

    def filter(self, topic, msg, stamp, header, tags):
        passthrough = (topic, msg, stamp, header, tags)

        if topic in self._tf_static_topics:
            store = self._tf_buffer.set_transform_static
        elif topic in self._tf_topics:
            store = self._tf_buffer.set_transform
        else:
            store = None
        if store is not None:
            for t in msg.transforms:
                # converting to standard type message avoids warnings from BufferCore, but is not strictly needed
                store(bag_msg_type_to_standard_type(t), "bag")

        if MessageTags.EXTRA_TIME_RANGE in tags:
            return passthrough

        trigger_stamp = self._find_trigger_stamp(topic, msg)
        if trigger_stamp is None:
            return passthrough

        rate_limited = self._min_dt is not None and self._last_traj_time is not None and \
            self._last_traj_time + self._min_dt > trigger_stamp
        if rate_limited:
            return passthrough

        try:
            tf = self._tf_buffer.lookup_transform_core(self._odom_frame_id, self._child_frame_id, trigger_stamp)
        except TransformException as e:
            print("Transform error: " + str(e), file=sys.stderr)
            return passthrough
        self._last_traj_time = trigger_stamp

        translation = tf.transform.translation
        rotation = tf.transform.rotation
        self._csv_rows.append((
            trigger_stamp.to_sec(),
            translation.x, translation.y, translation.z,
            rotation.x, rotation.y, rotation.z, rotation.w,
        ))

        return passthrough

    def reset(self):
        super(ExportTFTrajectory, self).reset()
        self._last_traj_time = None
        self._csv_rows = []
        self._tf_buffer.clear()

    def _str_params(self):
        parts = ['odom_frame_id=' + self._odom_frame_id, 'child_frame_id=' + self._child_frame_id]
        if self._trigger_topic is None:
            parts.append('trigger_frame_id=' + self._trigger_frame_id)
        else:
            parts.append('trigger_topic=' + self._trigger_topic)
        parts.append('csv_file=' + self.resolve_file(self._csv_file_path))
        inherited = self._default_str_params(include_topics=False, include_types=False)
        if len(inherited) > 0:
            parts.append(inherited)
        return ", ".join(parts)


class ExportFinalTFTreeToYAML(DeserializedMessageFilterWithTF):
    """Export TF tree at the end of the bag as a YAML file.

    Received TF messages are stored in an internal TF buffer. When filtering ends, the latest transform of each
    requested (or each known) parent -> child pair is written as a 4x4 matrix under the key ``T_{parent}__{child}``.
    """

    def __init__(self, yaml_file, static_only=True, tf_buffer_length=None, transforms=None, *args, **kwargs):
        # type: (STR, bool, Optional[float], Optional[Sequence[Tuple[STR, STR]]], Any, Any) -> None
        """
        :param yaml_file: Path to the YAML file to store the tree.
        :param static_only: If true, only read static transforms.
        :param tf_buffer_length: Length of the TF buffer (in seconds).
        :param transforms: If set, only export the given transforms. This is a list of pairs (parent, child).
        :param args: Standard include/exclude and stamp args.
        :param kwargs: Standard include/exclude and stamp kwargs.
        """
        include_topics = ["/tf_static"]
        if not static_only:
            include_topics.insert(0, "/tf")

        super(ExportFinalTFTreeToYAML, self).__init__(
            include_topics=include_topics, include_types=(TFMessage._type,), *args, **kwargs)

        self._yaml_file_path = yaml_file
        self._static_only = static_only
        self._transforms = transforms

        if tf_buffer_length is None:
            self._tf_buffer = BufferCore()
        else:
            self._tf_buffer = BufferCore(rospy.Duration.from_sec(float(tf_buffer_length)))

    def on_filtering_end(self):
        if self._transforms is None:
            # Every frame known to the buffer is exported together with its parent.
            tree = yaml.safe_load(self._tf_buffer.all_frames_as_yaml())
            frames = [(tree[child]['parent'], child) for child in tree]
        else:
            frames = self._transforms

        data = {}
        for parent, child in frames:
            try:
                tf = self._tf_buffer.lookup_transform_core(parent, child, rospy.Time(0))
            except TransformException as e:
                print('Transform error: ' + str(e), file=sys.stderr)
                continue
            matrix = numpify(tf.transform)
            matrix[np.abs(matrix) < 1e-8] = 0.0  # zero out values tiny in magnitude (e.g. rounding noise)
            data['T_{parent}__{child}'.format(parent=parent, child=child)] = {
                'rows': 4,
                'cols': 4,
                'data': [float(v) for v in matrix.flat],
            }

        out_path = self.resolve_file(self._yaml_file_path)
        with open(out_path, 'w') as out:
            yaml.safe_dump(data, out)
        print("Saved YAML file with", len(data), "keys:", out_path)

    def filter(self, topic, msg, stamp, header, tags):
        if topic in self._tf_static_topics:
            store = self._tf_buffer.set_transform_static
        elif topic in self._tf_topics:
            store = self._tf_buffer.set_transform
        else:
            store = None
        if store is not None:
            for t in msg.transforms:
                # converting to standard type message avoids warnings from BufferCore, but is not strictly needed
                store(bag_msg_type_to_standard_type(t), "bag")

        return topic, msg, stamp, header, tags

    def reset(self):
        super(ExportFinalTFTreeToYAML, self).reset()
        self._tf_buffer.clear()

    def _str_params(self):
        parts = ['yaml=' + self.resolve_file(self._yaml_file_path)]
        if self._static_only:
            parts.append('static_only')
        if self._transforms is not None:
            parts.append(items_to_str(self._transforms, '->'))
        inherited = self._default_str_params(include_topics=False, include_types=False)
        if len(inherited) > 0:
            parts.append(inherited)
        return ", ".join(parts)


class ExportCmdVelToCSV(MessageToCSVExporterBase):
    """Export cmd_vel commands to a CSV file.

    Each row contains the stamp, the forward velocity (``linear.x``) and the yaw rate (``angular.z``).
    """

    def __init__(self, topic, csv_file, max_frequency=None, frequency_from_header_stamp=False, *args, **kwargs):
        # type: (STRING_TYPE, STRING_TYPE, Optional[float], bool, Any, Any) -> None
        """
        :param topic: The topic to export (`geometry_msgs/Twist` or `geometry_msgs/TwistStamped` type).
        :param csv_file: Path to the CSV file to store the velocity commands.
        :param max_frequency: The maximum frequency at which rows should be exported.
        :param frequency_from_header_stamp: If True, the header stamp will be used as the timestamp for frequency
                                            checking. If False, receive stamp is used. Only works with `TwistStamped`.
        :param args: Standard include/exclude and stamp args.
        :param kwargs: Standard include/exclude and stamp kwargs.
        """
        super(ExportCmdVelToCSV, self).__init__(
            csv_file=csv_file, max_frequency=max_frequency, frequency_from_header_stamp=frequency_from_header_stamp,
            include_topics=(topic,), include_types=(Twist._type, TwistStamped._type), *args, **kwargs)  # noqa

    def _get_fields(self):
        return ("stamp", "linear", "angular")

    def _msg_to_csv_row(self, topic, msg, stamp, header, tags):
        # Stamped commands carry their own time; plain Twists use the receive stamp.
        if msg._type == TwistStamped._type:
            row_stamp = msg.header.stamp
        else:
            row_stamp = stamp

        if msg._type == Twist._type:
            velocity = msg
        else:
            velocity = msg.twist

        return row_stamp.to_sec(), velocity.linear.x, velocity.angular.z


class ExportCameraInfoToYAML(MessageToYAMLExporterBase):
    """Export the first camera info on a topic to a YAML file.

    The output uses the camera_info_manager calibration layout (image size, camera name, camera matrix, distortion,
    rectification and projection matrices).
    """

    def __init__(self, topic, yaml_file, yaml_dump_options=None, *args, **kwargs):
        # type: (STRING_TYPE, STRING_TYPE, Optional[Dict[STRING_TYPE, Any]], Any, Any) -> None
        """
        :param topic: The topic to export (`sensor_msgs/CameraInfo`).
        :param yaml_file: Path to the YAML file to store the camera info.
        :param yaml_dump_options: Optional options passed to yaml.safe_dump() as kwargs.
        :param args: Standard include/exclude and stamp args.
        :param kwargs: Standard include/exclude and stamp kwargs.
        """
        super(ExportCameraInfoToYAML, self).__init__(
            yaml_file=yaml_file, yaml_dump_options=yaml_dump_options,
            include_topics=(topic,), include_types=(CameraInfo._type,), *args, **kwargs)  # noqa

        # Set once a camera info has been stored; further messages are then ignored.
        self._finished = False

    def consider_message(self, topic, datatype, stamp, header, tags):
        if self._finished:
            return False
        return super(ExportCameraInfoToYAML, self).consider_message(topic, datatype, stamp, header, tags)

    def _init_data(self):
        return {}

    def _append_msg_data(self, topic, msg, stamp, header, tags):
        """Store the calibration contained in msg; only the first message is used."""
        self._finished = True

        def to_floats(values):
            # Plain Python floats keep yaml.safe_dump() working even if a previous filter stored numpy scalars.
            return [float(v) for v in values]

        self._data['image_width'] = msg.width
        self._data['image_height'] = msg.height
        self._data['camera_name'] = msg.header.frame_id
        self._data['camera_matrix'] = {
            'rows': 3,
            'cols': 3,
            'data': to_floats(msg.K),
        }
        self._data['distortion_model'] = msg.distortion_model
        self._data['distortion_coefficients'] = {
            'rows': 1,
            'cols': len(msg.D),
            'data': to_floats(msg.D),
        }
        # R and P are part of the camera_info_manager layout as well; without them, the rectification and projection
        # stored in the message would be lost.
        self._data['rectification_matrix'] = {
            'rows': 3,
            'cols': 3,
            'data': to_floats(msg.R),
        }
        self._data['projection_matrix'] = {
            'rows': 3,
            'cols': 4,
            'data': to_floats(msg.P),
        }

    def _write_data_to_file(self, out_file):
        if len(self._data) == 0:
            print("YAML", out_file, "has no data, not writing.")
            return
        super(ExportCameraInfoToYAML, self)._write_data_to_file(out_file)

    def reset(self):
        self._finished = False
        super(ExportCameraInfoToYAML, self).reset()


class ExportImuToCSV(MessageToCSVExporterBase):
    """Export Imu messages to a CSV file.

    Every row contains the header stamp, the orientation quaternion, the angular velocity, the linear acceleration
    and the 9 elements of each of the three covariance arrays.
    """

    def __init__(self, topic, csv_file, *args, **kwargs):
        # type: (STRING_TYPE, STRING_TYPE, Any, Any) -> None
        """
        :param topic: The topic to export (`sensor_msgs/Imu` type).
        :param csv_file: Path to the CSV file to store the IMU data.
        :param args: Standard include/exclude and stamp args.
        :param kwargs: Standard include/exclude and stamp kwargs.
        """
        super(ExportImuToCSV, self).__init__(
            csv_file=csv_file, max_frequency=None, frequency_from_header_stamp=False,
            include_topics=(topic,), include_types=(Imu._type,), *args, **kwargs)  # noqa

    def _get_fields(self):
        base_columns = ("stamp", "qx", "qy", "qz", "qw", "wx", "wy", "wz", "ax", "ay", "az")
        # cq0..cq8 (orientation), cw0..cw8 (angular velocity), ca0..ca8 (linear acceleration) covariance columns.
        covariance_columns = tuple("%s%i" % (prefix, idx) for prefix in ("cq", "cw", "ca") for idx in range(9))
        return base_columns + covariance_columns

    def _msg_to_csv_row(self, topic, msg, stamp, header, tags):
        # type: (STRING_TYPE, Imu, rospy.Time, ConnectionHeader, Tags) -> Iterable[float]
        ori = msg.orientation
        gyro = msg.angular_velocity
        acc = msg.linear_acceleration
        row = [msg.header.stamp.to_sec(), ori.x, ori.y, ori.z, ori.w, gyro.x, gyro.y, gyro.z, acc.x, acc.y, acc.z]
        for covariance in (msg.orientation_covariance, msg.angular_velocity_covariance,
                           msg.linear_acceleration_covariance):
            row.extend(covariance)
        return row


class ExportMagnetometerToCSV(MessageToCSVExporterBase):
    """Export MagneticField messages to a CSV file.

    Columns: header stamp, the magnetic field vector (x, y, z) and the 9 elements of its covariance array (c0..c8).
    """

    def __init__(self, topic, csv_file, *args, **kwargs):
        # type: (STRING_TYPE, STRING_TYPE, Any, Any) -> None
        """
        :param topic: The topic to export (`sensor_msgs/MagneticField` type).
        :param csv_file: Path to the CSV file to store the magnetometer data.
        :param args: Standard include/exclude and stamp args.
        :param kwargs: Standard include/exclude and stamp kwargs.
        """
        super(ExportMagnetometerToCSV, self).__init__(
            csv_file=csv_file, max_frequency=None, frequency_from_header_stamp=False,
            include_topics=(topic,), include_types=(MagneticField._type,), *args, **kwargs)  # noqa

    def _get_fields(self):
        return ("stamp", "x", "y", "z") + tuple("c%i" % idx for idx in range(9))

    def _msg_to_csv_row(self, topic, msg, stamp, header, tags):
        # type: (STRING_TYPE, MagneticField, rospy.Time, ConnectionHeader, Tags) -> Iterable[float]
        field = msg.magnetic_field
        row = [msg.header.stamp.to_sec(), field.x, field.y, field.z]
        row.extend(msg.magnetic_field_covariance)
        return row


class ExportAzimuthToCSV(MessageToCSVExporterBase):
    """Export Azimuth messages to a CSV file.

    The azimuth and its variance are written in radians (rad and rad^2). Messages whose unit is degrees are converted
    so that the values match the column names. The axes orientation (ENU/NED) and the north reference are NOT
    converted.
    """

    def __init__(self, topic, csv_file, *args, **kwargs):
        # type: (STRING_TYPE, STRING_TYPE, Any, Any) -> None
        """
        :param topic: The topic to export (`compass_msgs/Azimuth` type).
        :param csv_file: Path to the CSV file to store the azimuths.
        :param args: Standard include/exclude and stamp args.
        :param kwargs: Standard include/exclude and stamp kwargs.
        """
        from compass_msgs.msg import Azimuth
        super(ExportAzimuthToCSV, self).__init__(
            csv_file=csv_file, max_frequency=None, frequency_from_header_stamp=False,
            include_topics=(topic,), include_types=(Azimuth._type,), *args, **kwargs)  # noqa

        # Cache the unit constant so that compass_msgs does not have to be imported for every message.
        self._unit_deg = Azimuth.UNIT_DEG

    def _get_fields(self):
        return "stamp", "azimuth_rad", "azimuth_rad_cov"

    def _msg_to_csv_row(self, topic, msg, stamp, header, tags):
        # type: (STRING_TYPE, 'compass_msgs.msg.Azimuth', rospy.Time, ConnectionHeader, Tags) -> Iterable[float]
        azimuth = msg.azimuth
        variance = msg.variance
        # The columns are declared in radians, so degree-valued messages have to be converted (variance scales with
        # the square of the unit conversion factor).
        if getattr(msg, 'unit', None) == self._unit_deg:
            azimuth = math.radians(azimuth)
            variance = variance * math.radians(1.0) ** 2
        return [msg.header.stamp.to_sec(), azimuth, variance]


class ExportJointStatesToCSV(MessageToCSVExporterBase):
    """Export JointState messages to a CSV file.

    Messages are buffered in memory and the table is assembled only when the file is written. Every joint seen in
    any message gets `<joint>_pos`, `<joint>_vel` and `<joint>_eff` columns (joints sorted by name). Values that a
    message does not provide are written as NaN.
    """

    def __init__(self, topic, csv_file, *args, **kwargs):
        # type: (STRING_TYPE, STRING_TYPE, Any, Any) -> None
        """
        :param topic: The topic to export (`sensor_msgs/JointState` type).
        :param csv_file: Path to the CSV file to store the joint states.
        :param args: Standard include/exclude and stamp args.
        :param kwargs: Standard include/exclude and stamp kwargs.
        """
        super(ExportJointStatesToCSV, self).__init__(
            csv_file=csv_file, max_frequency=None, frequency_from_header_stamp=False,
            include_topics=(topic,), include_types=(JointState._type,), *args, **kwargs)  # noqa

        # Column names; they are only known after all messages were seen, so _write_data_to_file() builds them.
        self._fields = None
        # Names of all joints seen so far.
        self._joints = set()
        # Joint name -> list of (position, velocity, effort) tuples, one per entry of self._stamps.
        self._msgs = dict()
        # Header stamps (in seconds) of the buffered messages.
        self._stamps = list()

    def _get_fields(self):
        return self._fields

    def _msg_to_csv_row(self, topic, msg, stamp, header, tags):
        # type: (STRING_TYPE, JointState, rospy.Time, ConnectionHeader, Tags) -> Optional[Iterable[Any]]
        missing = (float('nan'), float('nan'), float('nan'))

        self._stamps.append(msg.header.stamp.to_sec())
        # Open a new time slot for every known joint; joints present in this message overwrite it below.
        for history in self._msgs.values():
            history.append(missing)

        num_stamps = len(self._stamps)
        for idx, joint in enumerate(msg.name):
            if joint not in self._joints:
                # A newly seen joint has no values for all the previous stamps.
                self._joints.add(joint)
                self._msgs[joint] = [missing] * num_stamps
            self._msgs[joint][-1] = tuple(
                values[idx] if len(values) > idx else float('nan')
                for values in (msg.position, msg.velocity, msg.effort))

        # The message is only buffered; the rows are produced in _write_data_to_file().
        return None

    def _write_data_to_file(self, out_file):
        self._fields = ["stamp"]
        offsets = dict()
        for joint in sorted(self._joints):
            offsets[joint] = len(self._fields)
            self._fields += [joint + "_pos", joint + "_vel", joint + "_eff"]

        width = len(self._fields)
        for row_idx, stamp in enumerate(self._stamps):
            row = [float('nan')] * width
            row[0] = stamp
            for joint, history in self._msgs.items():
                start = offsets[joint]
                row[start:(start + 3)] = history[row_idx]
            self._data.append(row)

        super(ExportJointStatesToCSV, self)._write_data_to_file(out_file)

    def reset(self):
        self._fields = None
        self._joints = set()
        self._msgs = dict()
        self._stamps = list()
        super(ExportJointStatesToCSV, self).reset()


class ExportGNSSToCSV(MessageToCSVExporterBase):
    """Export GNSS fixes (`gps_common/GPSFix` or `sensor_msgs/NavSatFix` messages) to a CSV file.

    Columns: header stamp, latitude, longitude, altitude, fix status, covariance type and the 9 elements of the
    position covariance array (c0..c8).
    """

    def __init__(self, topic, csv_file, *args, **kwargs):
        # type: (STRING_TYPE, STRING_TYPE, Any, Any) -> None
        """
        :param topic: The topic to export (`gps_common/GPSFix` or `sensor_msgs/NavSatFix` type).
        :param csv_file: Path to the CSV file to store the trajectory.
        :param args: Standard include/exclude and stamp args.
        :param kwargs: Standard include/exclude and stamp kwargs.
        """
        super(ExportGNSSToCSV, self).__init__(
            csv_file=csv_file, max_frequency=None, frequency_from_header_stamp=False,
            include_topics=(topic,), include_types=(GPSFix._type, NavSatFix._type), *args, **kwargs)  # noqa

    def _get_fields(self):
        return ("stamp", "lat", "lon", "alt", "status", "cov_type") + tuple("c%i" % idx for idx in range(9))

    def _msg_to_csv_row(self, topic, msg, stamp, header, tags):
        # type: (STRING_TYPE, Union[GPSFix, NavSatFix], rospy.Time, ConnectionHeader, Tags) -> Optional[Iterable[Any]]
        row = [msg.header.stamp.to_sec(), msg.latitude, msg.longitude, msg.altitude]
        row.append(msg.status.status)
        row.append(msg.position_covariance_type)
        row.extend(msg.position_covariance)
        return row


class DetectDamagedBaslerImages(DeserializedMessageFilter):
    """Detect images from Basler cameras that are damaged by incomplete CompressionBeyond packets. These images are
    mostly white noise with some flat parts. This is a statistical model that tries to find such images.

    Only images whose average brightness lies in [min_brightness, max_brightness] are examined. Such an image is
    considered damaged if the correlation between its color channels is very low (below 0.5), or if the correlation
    is below `min_correlation` and the entropy of its first channel is above `max_entropy`.
    """

    # Images whose cross-channel correlation is below this value are considered damaged regardless of entropy.
    _ALWAYS_DAMAGED_MAX_CORRELATION = 0.5

    def __init__(self, drop_damaged=True, min_correlation=0.7, max_entropy=0.75, min_brightness=100, max_brightness=180,
                 output_csv=None, output_folder=None, add_tags=None, *args, **kwargs):
        # type: (bool, float, float, int, int, STRING_TYPE, STRING_TYPE, Optional[Set[STRING_TYPE]], Any, Any) -> None
        """
        :param drop_damaged: If true, damaged images will be dropped.
        :param min_correlation: Minimum cross-color-channel correlation of valid images.
        :param max_entropy: Maximum entropy of valid images.
        :param min_brightness: Minimum average brightness of damaged images (usually is around 128).
        :param max_brightness: Maximum average brightness of damaged images (usually is around 128).
        :param output_csv: Path to a CSV file where the info about damaged images should be saved.
        :param output_folder: Path to a folder where the damaged images should be saved.
        :param add_tags: Tags intended for the detected images. NOTE(review): currently stored but not used.
        :param args: Standard include/exclude and stamp args.
        :param kwargs: Standard include/exclude and stamp kwargs.
        """
        super(DetectDamagedBaslerImages, self).__init__(
            include_types=(Image._type, CompressedImage._type), *args, **kwargs)  # noqa

        self._drop_damaged = drop_damaged
        self._min_correlation = min_correlation
        self._max_entropy = max_entropy
        self._min_brightness = min_brightness
        self._max_brightness = max_brightness
        self._output_csv = output_csv
        self._output_folder = output_folder
        self._add_tags = add_tags

        # Resolved in on_filtering_start() (they may be relative to the processed bag).
        self._output_csv_resolved = None
        self._output_folder_resolved = None

        # List of (header stamp, topic, correlation, entropy, brightness) of all detected damaged images.
        self._damaged_images = []

        self._cv = CvBridge()

    def on_filtering_start(self):
        if self._output_csv is not None:
            self._output_csv_resolved = self.resolve_file(self._output_csv)
        if self._output_folder is not None:
            self._output_folder_resolved = self.resolve_file(self._output_folder)
            if not os.path.exists(self._output_folder_resolved):
                try:
                    os.makedirs(self._output_folder_resolved)
                except Exception as e:
                    # Saving images is optional; report the problem and just disable it.
                    print(str(e), file=sys.stderr)
                    self._output_folder_resolved = None

    def on_filtering_end(self):
        if self._output_csv_resolved:
            with open(self._output_csv_resolved, 'w', newline='') as csvfile:
                fieldnames = ['stamp_sec', 'stamp_nsec', 'topic', 'correlation', 'entropy', 'brightness']
                writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
                writer.writeheader()
                for stamp, topic, correlation, entropy, brightness in self._damaged_images:
                    writer.writerow({
                        'stamp_sec': stamp.secs,
                        'stamp_nsec': stamp.nsecs,
                        'topic': topic,
                        'correlation': correlation,
                        'entropy': entropy,
                        'brightness': brightness,
                    })
                print("Saved info about %i damaged images to %s" % (
                    len(self._damaged_images), self._output_csv_resolved))

    def filter(self, topic, msg, stamp, header, tags):
        if msg._type == Image._type:
            cv_img = self._cv.imgmsg_to_cv2(msg, desired_encoding='bgr8')
        else:
            raw_img, err = decode(msg, topic)
            if raw_img is None:
                # An undecodable image cannot be analyzed; pass it through unchanged.
                print(err, file=sys.stderr)
                return topic, msg, stamp, header, tags

            cv_img = self._cv.imgmsg_to_cv2(raw_img, desired_encoding='bgr8')

        # First check if the average brightness is somewhere around the middle (it is white noise, so average color
        # should be grey). Images outside this range are never considered damaged.
        brightness = np.mean(cv_img)
        if not (self._min_brightness <= brightness <= self._max_brightness):
            return topic, msg, stamp, header, tags

        correlation = self.color_correlation(cv_img)
        # The entropy is always computed because it is reported for every damaged image, including the ones that
        # are classified as damaged by their low correlation alone.
        entropy = self.entropy(cv_img)
        is_damaged = correlation < self._ALWAYS_DAMAGED_MAX_CORRELATION or \
            (correlation < self._min_correlation and entropy > self._max_entropy)

        if is_damaged:
            print("Damaged image:", topic, to_str(msg.header.stamp), correlation, entropy, brightness)
            self._damaged_images.append((msg.header.stamp, topic, correlation, entropy, brightness))
            if self._output_folder_resolved:
                self.save_damaged_image(cv_img, msg.header.stamp, topic)
            if self._drop_damaged:
                return None

        return topic, msg, stamp, header, tags

    def color_correlation(self, img):
        """Calculate the correlation between the color channels.

        :param img: The image as a numpy array (OpenCV image).
        :return: The average of the channel 0-1 and channel 1-2 correlation coefficients (1.0 for non-3-channel
                 images).
        """
        if len(img.shape) != 3 or img.shape[-1] != 3:
            return 1.0

        corr_bg = np.corrcoef(img[:, :, 0].flat, img[:, :, 1].flat)[0, 1]
        corr_gr = np.corrcoef(img[:, :, 1].flat, img[:, :, 2].flat)[0, 1]
        return (corr_bg + corr_gr) / 2.0

    def entropy(self, img):
        """Calculates Shannon Entropy of an image (only its first channel for multi-channel images).

        :param img: The image.
        :return: A value between 0 (no information) and 8 (max randomness).
        """
        if len(img.shape) == 3:
            img = img[:, :, 0]

        hist = cv2.calcHist([img], [0], None, [256], [0, 256])
        hist_norm = hist.ravel() / hist.sum()
        hist_norm = hist_norm[hist_norm > 0]

        # Apply Shannon Entropy formula: -sum(p * log2(p))
        entropy = -np.sum(hist_norm * np.log2(hist_norm))

        return entropy

    def save_damaged_image(self, cv_img, stamp, topic):
        """Save the image as a PNG named by its stamp and topic into the output folder."""
        filename = "%s-%s.png" % (to_str(stamp), to_valid_ros_name(topic, base_name=True))
        cv2.imwrite(os.path.join(self._output_folder_resolved, filename), cv_img)

    def reset(self):
        self._output_csv_resolved = None
        self._output_folder_resolved = None
        self._damaged_images = []
        # Reset the parent's state too, like the other filters do.
        super(DetectDamagedBaslerImages, self).reset()

    def _str_params(self):
        parts = []
        parts.append('drop_damaged=' + repr(self._drop_damaged))
        parts.append('min_correlation=' + str(self._min_correlation))
        parts.append('max_entropy=' + str(self._max_entropy))
        parts.append('min_brightness=' + str(self._min_brightness))
        parts.append('max_brightness=' + str(self._max_brightness))
        if self._output_csv:
            parts.append('output_csv=' + self._output_csv)
        if self._output_folder:
            parts.append('output_folder=' + self._output_folder)
        parent_params = self._default_str_params(include_types=False)
        if len(parent_params) > 0:
            parts.append(parent_params)
        return ", ".join(parts)


class DropMessagesFromCSV(RawMessageFilter):
    """Drop messages listed in a CSV file with columns `stamp_sec`, `stamp_nsec` and `topic`.

    A message is dropped if its header stamp (or its receive stamp if it has no header) is listed in the CSV for its
    topic. Leading slashes of topic names are ignored.
    """

    def __init__(self, csv_file, additional_topics=None, fail_if_file_not_found=True, verbose=False, *args, **kwargs):
        """
        :param str csv_file: Path to the CSV.
        :param dict additional_topics: Dictionary mapping topics from CSV to a list of other topics whose messages
                                       should also be dropped if header.stamp is the same.
        :param bool fail_if_file_not_found: If true, the filter will fail if the CSV file is not found.
        :param bool verbose: If True, every dropped message will be logged to console.
        :param args: Standard include/exclude and stamp args.
        :param kwargs: Standard include/exclude and stamp kwargs.
        """
        super(DropMessagesFromCSV, self).__init__(*args, **kwargs)
        self._csv_file = csv_file
        self._additional_topics = additional_topics if additional_topics is not None else {}
        self._fail_if_file_not_found = fail_if_file_not_found
        self._verbose = verbose

        # Topic name (without leading slash) -> set of stamps of the messages to drop.
        self._to_drop = {}
        self._to_drop_topics = TopicSet()
        # Whether the include topics are taken from the CSV (the user did not specify any). Decided on the first
        # on_filtering_start() so that later runs keep following the currently loaded CSV.
        self._include_topics_from_csv = None

    def on_filtering_start(self):
        """Load the CSV and build the lookup tables.

        :raises RuntimeError: If the CSV does not exist and `fail_if_file_not_found` is True.
        """
        csv_file = self.resolve_file(self._csv_file)
        # Start from scratch so that a repeated run does not keep stamps loaded from a previous CSV.
        self._to_drop = {}
        if not os.path.exists(csv_file):
            if self._fail_if_file_not_found:
                raise RuntimeError("File " + csv_file + " does not exist.")
        else:
            with open(csv_file, 'r') as f:
                reader = csv.DictReader(f)
                for row in reader:
                    stamp = rospy.Time(int(row['stamp_sec']), int(row['stamp_nsec']))
                    topic = str(row['topic']).lstrip('/')
                    topics = [topic] + [t.lstrip('/') for t in self._additional_topics.get(topic, [])]
                    for t in topics:
                        self._to_drop.setdefault(t, set()).add(stamp)
            print("Loaded CSV with", sum(map(len, self._to_drop.values())), "messages to drop:", csv_file)

        self._to_drop_topics = TopicSet(self._to_drop.keys())
        if self._include_topics_from_csv is None:
            self._include_topics_from_csv = not self._include_topics
        if self._include_topics_from_csv:
            # Only pass the topics mentioned in the CSV to filter().
            self._include_topics = self._to_drop_topics

    def filter(self, topic, datatype, data, md5sum, pytype, stamp, header, tags):
        if topic not in self._to_drop_topics:
            return topic, datatype, data, md5sum, pytype, stamp, header, tags

        to_drop = self._to_drop.get(topic.lstrip('/'))
        if not to_drop:
            # No stamps stored under the normalized topic name; nothing to drop.
            return topic, datatype, data, md5sum, pytype, stamp, header, tags

        # Only the header is deserialized; messages without a header are matched by their receive stamp.
        msg_header = deserialize_header(data, pytype)
        if msg_header is None:
            msg_stamp = stamp
        else:
            msg_stamp = msg_header.stamp

        if msg_stamp in to_drop:
            if self._verbose:
                print("drop", topic, to_str(msg_stamp))
            return None

        return topic, datatype, data, md5sum, pytype, stamp, header, tags

    def _str_params(self):
        parts = []
        parts.append('csv_file=%s' % (self._csv_file,))
        parts.append('additional_topics=%r' % (self._additional_topics,))
        parts.append('fail_if_file_not_found=%s' % (repr(self._fail_if_file_not_found),))
        parent_params = super(DropMessagesFromCSV, self)._str_params()
        if len(parent_params) > 0:
            parts.append(parent_params)
        return ", ".join(parts)


class ExportMessageInfoToCSV(RawMessageFilter):
    """Export message info in a CSV file with columns `stamp_sec`, `stamp_nsec` and `topic`.

    The rows are collected in memory during filtering and written when filtering ends. Messages are passed through
    unchanged.
    """

    def __init__(self, csv_file, use_header_stamp=False, *args, **kwargs):
        """
        :param str csv_file: Path to the CSV.
        :param bool use_header_stamp: If false, the receive timestamp will be used. Otherwise, header.stamp is used if
                                      the message has a header (otherwise receive stamp is used).
        :param args: Standard include/exclude and stamp args.
        :param kwargs: Standard include/exclude and stamp kwargs.
        """
        super(ExportMessageInfoToCSV, self).__init__(*args, **kwargs)
        self._csv_file = csv_file
        self._use_header_stamp = use_header_stamp

        # Collected CSV rows (dicts keyed by the column names).
        self._message_info = []

    def _store_message_info(self, topic, msg_stamp):
        """Remember one CSV row for the given topic and stamp."""
        self._message_info.append({
            'stamp_sec': msg_stamp.secs,
            'stamp_nsec': msg_stamp.nsecs,
            'topic': topic,
        })

    def on_filtering_start(self):
        self._message_info = []

    def on_filtering_end(self):
        csv_file = self.resolve_file(self._csv_file)
        with open(csv_file, 'w', newline='') as f:
            writer = csv.DictWriter(f, fieldnames=("stamp_sec", "stamp_nsec", "topic"))
            writer.writeheader()
            writer.writerows(self._message_info)

        print("Saved CSV with", len(self._message_info), "rows:", csv_file)

    def consider_message(self, topic, datatype, stamp, header, tags):
        if not super(ExportMessageInfoToCSV, self).consider_message(topic, datatype, stamp, header, tags):
            return False

        if self._use_header_stamp:
            # header.stamp can only be read from the message data, which is available in filter().
            return True

        # Receive stamps are already known here, so record them now and skip filter() entirely for efficiency.
        self._store_message_info(topic, stamp)
        return False

    def filter(self, topic, datatype, data, md5sum, pytype, stamp, header, tags):
        # Only reached when self._use_header_stamp is True (see consider_message()).
        msg_header = deserialize_header(data, pytype)
        self._store_message_info(topic, stamp if msg_header is None else msg_header.stamp)
        return topic, datatype, data, md5sum, pytype, stamp, header, tags

    def _str_params(self):
        parts = [
            'csv_file=%s' % (self._csv_file,),
            'use_header_stamp=%r' % (self._use_header_stamp,),
        ]
        parent_params = super(ExportMessageInfoToCSV, self)._str_params()
        if len(parent_params) > 0:
            parts.append(parent_params)
        return ", ".join(parts)


class StampTwist(DeserializedMessageFilter):
    """Add stamped copies of Twist messages.

    Each message on `source_topic` is passed through unchanged, and a `geometry_msgs/TwistStamped` message with the
    same twist is added on `stamped_topic`. The receive stamp of the original message is used as header.stamp.
    """

    def __init__(self, source_topic, stamped_topic=None, frame_id="base_link", *args, **kwargs):
        # type: (STRING_TYPE, Optional[STRING_TYPE], STRING_TYPE, Any, Any) -> None
        """
        :param source_topic: The Twist topic to stamp.
        :param stamped_topic: The stamped Twist topic to create (defaults to `source_topic` + "_stamped").
        :param frame_id: The frame_id to use in the stamped messages.
        :param args: Standard include/exclude and stamp args.
        :param kwargs: Standard include/exclude and stamp kwargs.
        """
        super(StampTwist, self).__init__(include_topics=[source_topic], *args, **kwargs)

        self._source_topic = source_topic
        self._stamped_topic = stamped_topic if stamped_topic is not None else (source_topic + "_stamped")
        self._frame_id = frame_id

        self._connection_header = create_connection_header(self._stamped_topic, TwistStamped)

    def filter(self, topic, msg, stamp, header, tags):
        stamped_msg = TwistStamped()
        stamped_msg.header.frame_id = self._frame_id
        stamped_msg.header.stamp = stamp
        # Store a copy. If both output messages shared one Twist instance, any in-place change made by a later
        # filter to one of them would silently apply to the other one as well.
        stamped_msg.twist = copy.deepcopy(msg)

        return [
            (topic, msg, stamp, header, tags),
            (self._stamped_topic, stamped_msg, stamp, self._connection_header, tags_for_generated_msg(tags))
        ]

    def _str_params(self):
        parts = []
        parts.append('%s=>%s (frame %s)' % (self._source_topic, self._stamped_topic, self._frame_id))
        parent_params = self._default_str_params(include_types=False)
        if len(parent_params) > 0:
            parts.append(parent_params)
        return ", ".join(parts)


class FixExtrinsicsFromIKalibr(DeserializedMessageFilterWithTF):
    """Apply extrinsic calibrations from IKalibr.

    All TF messages are fed into an internal TF buffer. When a transform whose child frame is one of the configured
    "adjust frames" passes through, it is replaced by a corrected transform. The correction makes the chain from the
    reference IMU frame to the corresponding sensor frame equal to the extrinsic read from the iKalibr result file.
    All other transforms pass through unchanged.
    """

    def __init__(self, param_file, ref_imu_frame="imu", frames=None, add_tags=None, *args, **kwargs):
        # type: (STRING_TYPE, STRING_TYPE, Optional[Dict[STRING_TYPE, Dict[STRING_TYPE, STRING_TYPE]]], Optional[Set[STRING_TYPE]], Any, Any) -> None  # noqa
        """
        :param param_file: Path to ikalibr_param.yaml file - the result of extrinsic calibration.
        :param ref_imu_frame: Frame of the reference IMU towards which everything is calibrated.
        :param frames: The TF frames to fix. Keys are the topic names used in iKalibr. Values are dicts with keys
                       "sensor_frame" and optionally "adjust_frame" (if not specified, "sensor_frame" is used).
                       "adjust_frame" specifies the frame whose transform should be changed, which can be useful
                       if you want to change a transform somewhere in the middle of the TF chain and not the last one.
        :param add_tags: Tags to be added to TF messages in which at least one transform was adjusted.
        :param args: Standard include/exclude and stamp args.
        :param kwargs: Standard include/exclude and stamp kwargs.
        :raises RuntimeError: If two entries of `frames` resolve to the same adjust frame.
        :raises KeyError: If an entry of `frames` lacks "sensor_frame" or names a sensor missing in `param_file`.
        """
        super(FixExtrinsicsFromIKalibr, self).__init__(
            include_topics=["tf", "tf_static"], include_types=(TFMessage._type,), *args, **kwargs)  # noqa

        self._ref_imu_frame = ref_imu_frame
        self._frames = frames if frames is not None else dict()
        self._param_file = param_file
        self._add_tags = add_tags

        # The param file is opened directly (not via resolve_file()), so a relative path is relative to the CWD.
        with open(param_file, 'r') as f:
            self._params = yaml.safe_load(f)

        # EXTRI holds lists of {"key": <iKalibr sensor name>, "value": ...} entries. Lists whose name starts with
        # "POS_" contain translations (3x1 matrices r0c0..r2c0), lists starting with "SO3_" contain rotations
        # (quaternions). Both parts of one sensor are merged into a single Transform.
        # NOTE(review): presumably each transform is the sensor pose in the reference IMU frame (T_imu<-sensor);
        # the correction computed in filter() relies on this interpretation.
        extri = self._params["CalibParam"]["EXTRI"]
        tfs = dict()
        for key, values in extri.items():
            for item in values:
                ikalibr_name = item["key"]
                value = item["value"]
                if ikalibr_name not in tfs:
                    tfs[ikalibr_name] = Transform()
                if key.startswith("POS_"):
                    msg = Vector3(value["r0c0"], value["r1c0"], value["r2c0"])
                    tfs[ikalibr_name].translation = msg
                elif key.startswith("SO3_"):
                    msg = Quaternion(value["qx"], value["qy"], value["qz"], value["qw"])
                    tfs[ikalibr_name].rotation = msg

        print("Read %i transforms from %s." % (len(tfs), param_file))

        # Lookup tables: sensor frame -> adjust frame, adjust frame -> sensor frame, and
        # sensor frame -> calibrated transform from the param file.
        self._sensor_to_adjust_frame = dict()
        self._adjust_frame_to_sensor = dict()
        self._sensor_transforms = dict()
        for ikalibr_name, item in self._frames.items():
            sensor_frame = item["sensor_frame"]
            adjust_frame = item.get("adjust_frame", sensor_frame)

            # Each adjust frame may only be corrected for one sensor, otherwise the corrections would conflict.
            if adjust_frame in self._adjust_frame_to_sensor:
                raise RuntimeError(
                    "Duplicate appearance of adjust frame %s. This is invalid configuration!" % (adjust_frame,))

            self._sensor_transforms[sensor_frame] = tfs[ikalibr_name]
            self._sensor_to_adjust_frame[sensor_frame] = adjust_frame
            self._adjust_frame_to_sensor[adjust_frame] = sensor_frame

        # Buffer of all transforms seen so far; used to look up the TF chains around the adjusted transforms.
        self._tf = BufferCore()

    @staticmethod
    def fix_transform_type(t):
        """Convert from the ad-hoc rosbag type to true Transform type (needed by ros_numpy)."""
        return Transform(
            Vector3(t.translation.x, t.translation.y, t.translation.z),
            Quaternion(t.rotation.x, t.rotation.y, t.rotation.z, t.rotation.w))

    @staticmethod
    def fix_transform_stamped_type(t):
        """Convert from the ad-hoc rosbag type to true TransformStamped type (needed by tf2_py)."""
        return TransformStamped(
            Header(t.header.seq, t.header.stamp, t.header.frame_id),
            t.child_frame_id,
            FixExtrinsicsFromIKalibr.fix_transform_type(t.transform))

    def filter(self, topic, msg, stamp, header, tags):
        # Record the incoming (still unadjusted) transforms so that the chains around the adjust frames can be looked
        # up. _tf_static_topics is presumably provided by the base class.
        if topic in self._tf_static_topics:
            for transform in msg.transforms:
                self._tf.set_transform_static(self.fix_transform_stamped_type(transform), "filter_bag")
        else:
            for transform in msg.transforms:
                self._tf.set_transform(self.fix_transform_stamped_type(transform), "filter_bag")

        # Time(0) requests the latest available transform from the buffer.
        latest = rospy.Time(0)
        for transform in msg.transforms:
            if transform.child_frame_id in self._adjust_frame_to_sensor:
                adjust_frame = transform.child_frame_id
                sensor_frame = self._adjust_frame_to_sensor[adjust_frame]
                parent_frame = transform.header.frame_id

                # Calibrated sensor pose from iKalibr (see the NOTE in __init__): T_imu<-sensor.
                sensor_tf = numpify(self._sensor_transforms[sensor_frame])
                # lookup_transform_core(target, source) yields T_target<-source, so this is T_imu<-parent.
                t_parent_imu = numpify(
                    self._tf.lookup_transform_core(self._ref_imu_frame, parent_frame, latest).transform)
                # T_parent<-adjust: the transform being fixed.
                t_adjust_parent = numpify(self.fix_transform_type(transform.transform))
                # T_adjust<-sensor: the rest of the chain down to the sensor frame.
                t_sensor_adjust = numpify(self._tf.lookup_transform_core(adjust_frame, sensor_frame, latest).transform)

                # Find C such that T_imu<-parent * T_parent<-adjust * C * T_adjust<-sensor == T_imu<-sensor, i.e.
                # C = (T_imu<-parent * T_parent<-adjust)^-1 * T_imu<-sensor * (T_adjust<-sensor)^-1.
                t_correction = \
                    np.matmul(np.matmul(inv(np.matmul(t_parent_imu, t_adjust_parent)), sensor_tf), inv(t_sensor_adjust))
                # Apply the correction on the child side of the adjusted transform.
                t_adjust_parent = np.matmul(t_adjust_parent, t_correction)
                # NOTE: only the outgoing message is changed; the TF buffer keeps the original (unadjusted) value.
                transform.transform = msgify(Transform, t_adjust_parent)

                print("Adjusted transform %s->%s by %.3f m (for sensor %s)." % (
                      parent_frame, adjust_frame, np.linalg.norm(t_correction[:3, 3]), sensor_frame))
                tags.add(MessageTags.CHANGED)
                if self._add_tags:
                    tags = tags.union(self._add_tags)

        return topic, msg, stamp, header, tags

    def reset(self):
        # Forget all buffered transforms before the parent resets its own state.
        self._tf.clear()
        super(FixExtrinsicsFromIKalibr, self).reset()

    def _str_params(self):
        parts = []
        parts.append('param_file=' + self._param_file)
        parts.append('ref_imu_frame=' + self._ref_imu_frame)
        parts.append('frames=' + ",".join(self._frames.keys()))
        parent_params = self._default_str_params(include_types=False)
        if len(parent_params) > 0:
            parts.append(parent_params)
        return ", ".join(parts)


class AddSubtitles(NoMessageFilter):
    """Read subtitles from a text file and add them to the output as std_msgs/String messages.

    Each subtitle is emitted once, at the bag start time + `time_offset` + the subtitle start time.
    """

    def __init__(self, subtitles_file, topic, latch=False, time_offset=0.0):
        """
        :param subtitles_file: The file to read subtitles from (only SRT is supported).
        :param topic: The topic to put the subtitles on.
        :param latch: Whether the topic should be latched.
        :param time_offset: Relative time offset to add to the absolute subtitle time.
        """
        super(AddSubtitles, self).__init__()
        self._subtitles_file = subtitles_file
        self._topic = topic
        self._start_time = None
        self._latch = latch
        self._time_offset = rospy.Duration(time_offset)

        # List of (start_time, end_time, text) tuples; filled in on_filtering_start().
        self._subtitles = None

    def on_filtering_start(self):
        super(AddSubtitles, self).on_filtering_start()
        # Subtitle times are relative to the start of the bag, shifted by time_offset.
        self._start_time = rospy.Time(self._bag.get_start_time()) + self._time_offset
        subtitles_file = self.resolve_file(self._subtitles_file)
        self._subtitles = self._parse_subtitles(subtitles_file)
        print("Read", len(self._subtitles), "subtitles from file", subtitles_file)

    def _parse_subtitles(self, subtitles_file):
        if subtitles_file.endswith(".srt"):
            return self._parse_srt(subtitles_file)
        raise RuntimeError("Unsupported subtitles type. Only SRT is supported so far.")

    def _parse_srt(self, subtitles_file):
        """Parse an SRT file.

        :param subtitles_file: Path to the SRT file.
        :return: List of (start_time, end_time, text) tuples with times as rospy.Duration.
        :raises ValueError: If a counter or time line is malformed.
        """
        class State(Enum):
            INIT = 0
            COUNTER_READ = 1
            TIME_READ = 2

        state = State.INIT
        start_time = None
        end_time = None
        lines = list()

        subtitles = list()

        def parse_time_str(time_str):
            # Format is HH:MM:SS,fff; '.' is accepted as the decimal separator, too.
            h, m, s = time_str.strip().split(':')
            s, ss = s.replace('.', ',').split(',')
            # Right-pad the fraction to 9 digits to get exact nanoseconds (float arithmetic may be off by one).
            nsecs = int((ss + '000000000')[:9], base=10)
            return rospy.Duration(int(h, base=10) * 3600 + int(m, base=10) * 60 + int(s, base=10), nsecs)

        with open(subtitles_file, 'r') as f:
            for line in f:
                line = line.strip()
                if state == State.INIT:
                    # Tolerate extra blank lines between entries or at the end of the file, and a UTF-8 byte order
                    # mark in front of the first counter (both are common in real-world SRT files).
                    line = line.lstrip(u'\ufeff')
                    if len(line) == 0:
                        continue
                    _ = int(line)
                    state = State.COUNTER_READ
                elif state == State.COUNTER_READ:
                    start_time_str, end_time_str = line.split(" --> ")
                    start_time = parse_time_str(start_time_str)
                    end_time = parse_time_str(end_time_str)
                    state = State.TIME_READ
                elif state == State.TIME_READ:
                    if len(line) == 0:
                        subtitles.append((start_time, end_time, "\n".join(lines)))
                        lines = list()
                        state = State.INIT
                    else:
                        lines.append(line)

        # in case the last empty line is missing
        if len(lines) > 0:
            subtitles.append((start_time, end_time, "\n".join(lines)))

        return subtitles

    def extra_initial_messages(self):
        connection_header = create_connection_header(self._topic, String, self._latch)
        for start_time, _end_time, subtitle in self._subtitles:
            msg = String(data=subtitle)
            abs_time = self._start_time + start_time
            yield self._topic, msg, abs_time, connection_header, {MessageTags.GENERATED}

    def _str_params(self):
        params = ["topic=" + self._topic, "subtitles_file=" + self.resolve_file(self._subtitles_file)]
        if self._time_offset != rospy.Duration(0, 0):
            params.append("time_offset=%f" % (self._time_offset.to_sec(),))
        parent_params = super(AddSubtitles, self)._str_params()
        if len(parent_params) > 0:
            params.append(parent_params)
        # Same separator as all other filters use.
        return ", ".join(params)


class DumpRobotModel(NoMessageFilter):
    """Read robot model URDF from ROS params and store it in a file."""

    def __init__(self, urdf_file, param="robot_description", remove_comments=False, pretty_print=False,
                 run_on_start=True):
        """
        :param urdf_file: Path to the URDF file. If relative, it will be resolved relative to the bag set by set_bag.
        :param param: The parameter where robot model should be read.
        :param remove_comments: Whether comments should be removed from the output URDF file.
        :param pretty_print: Whether to pretty-print the URDF file (if False, the original layout is preserved).
        :param run_on_start: If true, the model will be exported before messages are processed. If false, the model will
                             be exported after processing all messages.
        """
        super(DumpRobotModel, self).__init__()
        self._urdf_file = urdf_file
        self._param = param
        self._remove_comments = remove_comments
        self._pretty_print = pretty_print
        self._run_on_start = run_on_start

    def on_filtering_start(self):
        """Export the model now if the filter is configured to run on start."""
        super(DumpRobotModel, self).on_filtering_start()

        if self._run_on_start:
            self.dump_model()

    def on_filtering_end(self):
        """Export the model now if the filter is configured to run after all messages were processed."""
        super(DumpRobotModel, self).on_filtering_end()

        if not self._run_on_start:
            self.dump_model()

    def dump_model(self):
        """Read the URDF from the configured ROS parameter and write it to the output file.

        If the parameter is not set, an error is printed to stderr and nothing is written. When comment removal or
        pretty-printing is requested, the document is re-serialized through lxml with a UTF-8 XML declaration.
        The output file is always written as UTF-8.
        """
        urdf = self._get_param(self._param)
        if urdf is None:
            print('Robot model not found on parameter', self._param, file=sys.stderr)
            return

        if self._remove_comments or self._pretty_print:
            # lxml does not re-indent documents that already contain whitespace-only text nodes, so these have to be
            # dropped while parsing for pretty_print to have any effect. Without pretty-printing, keep the layout.
            parser = etree.XMLParser(remove_comments=self._remove_comments, remove_blank_text=self._pretty_print,
                                     encoding='utf-8')
            tree = etree.fromstring(urdf.encode('utf-8'), parser=parser)
            urdf = etree.tostring(tree, encoding='utf-8', xml_declaration=True,
                                  pretty_print=self._pretty_print).decode('utf-8')

        dest = self.resolve_file(self._urdf_file)
        # Write UTF-8 explicitly: the platform default encoding may differ and non-ASCII content would then either
        # fail to encode or contradict the UTF-8 XML declaration.
        with open(dest, 'w', encoding='utf-8') as f:
            print(urdf, file=f)
        print("Robot model from parameter", self._param, "exported to", dest)

    def _str_params(self):
        """Describe this filter's configuration as a comma-separated string."""
        params = ["param=" + self._param, "urdf_file=" + self.resolve_file(self._urdf_file)]
        parent_params = super(DumpRobotModel, self)._str_params()
        if len(parent_params) > 0:
            params.append(parent_params)
        return ",".join(params)


class UpdateRobotModel(NoMessageFilter):
    """Read robot model from URDF file and update it in the ROS parameters."""

    def __init__(self, urdf_file, param="robot_description"):
        """
        :param urdf_file: Path to the URDF file. If relative, it will be resolved relative to the bag set by set_bag.
        :param param: The parameter where robot model should be stored.
        """
        super(UpdateRobotModel, self).__init__()
        self._param = param
        self._urdf_file = urdf_file

    def on_filtering_start(self):
        """Load the URDF file (as UTF-8) and store its whole text in the configured ROS parameter.

        A missing file is reported on stderr and leaves the parameter untouched.
        """
        super(UpdateRobotModel, self).on_filtering_start()

        urdf_path = self.resolve_file(self._urdf_file)
        if os.path.exists(urdf_path):
            with open(urdf_path, 'r', encoding='utf-8') as urdf_stream:
                contents = urdf_stream.read()
            self._set_param(self._param, contents)
            print('Robot model from %s set to parameter %s' % (urdf_path, self._param))
        else:
            print('Cannot find robot URDF file', urdf_path, file=sys.stderr)

    def _str_params(self):
        """Describe this filter's configuration as a comma-separated string."""
        parts = [
            "param=" + self._param,
            "urdf_file=" + self.resolve_file(self._urdf_file),
        ]
        inherited = super(UpdateRobotModel, self)._str_params()
        if inherited:
            parts.append(inherited)
        return ",".join(parts)


class RemovePasswordsFromParams(NoMessageFilter):
    """Search for passwords in the ROS parameters and remove them."""

    def __init__(self):
        """Create the filter with an empty record of removed parameters."""
        super(RemovePasswordsFromParams, self).__init__()
        # Full names (e.g. "/ns/password" or "/items[0]/pass") of all parameters deleted so far.
        self._removed_passwords = set()

    def on_filtering_start(self):
        """Strip password-like parameters from the bag parameters and report what was removed."""
        super(RemovePasswordsFromParams, self).on_filtering_start()
        self.remove_passwords(self._params)
        if self._removed_passwords:
            print("Removed passwords %r" % (self._removed_passwords,))
        else:
            print("No passwords removed")

    @staticmethod
    def _looks_like_password(key):
        """Tell whether a parameter name starts with "pass" (ignoring case and surrounding whitespace)."""
        return key.lower().strip().startswith("pass")

    def remove_passwords(self, params, prefix=""):
        """Recursively delete password-like keys from a parameter tree.

        :param params: The (sub)tree to clean. Dicts are cleaned in place, lists are descended into, and any other
                       value is left alone.
        :param prefix: Full name of ``params``; used to build the names recorded in ``_removed_passwords``.
        """
        if isinstance(params, list):
            for index, item in enumerate(params):
                self.remove_passwords(item, "%s[%i]" % (prefix, index))
            return
        if not isinstance(params, dict):
            return

        secret_keys = set()
        for name, value in params.items():
            if self._looks_like_password(name):
                secret_keys.add(name)
            else:
                self.remove_passwords(value, "/".join((prefix, name)))
        # Deletion happens after the iteration because a dict must not change size while being iterated.
        for name in secret_keys:
            del params[name]
            self._removed_passwords.add("/".join((prefix, name)))

    def _str_params(self):
        """Describe this filter's state as a comma-separated string."""
        parts = ["removed_passwords=%r" % (self._removed_passwords,)]
        inherited = super(RemovePasswordsFromParams, self)._str_params()
        if inherited:
            parts.append(inherited)
        return ",".join(parts)


class UpdateParams(NoMessageFilter):
    """Update ROS parameters."""

    def __init__(self, update=None, remove=None):
        """
        :param dict update: Nested dictionary of parameters to set. Dictionaries are merged recursively into the
                            existing parameters; any other value replaces the existing one (or adds it if missing).
        :param list remove: Names of parameters to remove. Nested parameters are given as slash-separated paths
                            (e.g. 'ns/param'); a leading slash is ignored.
        """
        super(UpdateParams, self).__init__()
        self._update = update if update is not None else dict()
        self._remove = remove if remove is not None else list()

        # Full names of the parameters touched by this filter.
        self._updated_params = set()
        self._removed_params = set()

    def on_filtering_start(self):
        """Apply the configured updates and removals to the bag parameters and report their counts."""
        super(UpdateParams, self).on_filtering_start()
        self.update_params(self._params, self._update)
        for param in self._remove:
            self.remove_param(self._params, param)

        if len(self._updated_params) > 0:
            print("Updated %i ROS parameters." % (len(self._updated_params),))
        if len(self._removed_params) > 0:
            print("Removed %i ROS parameters." % (len(self._removed_params),))

    def update_params(self, params, new_params, prefix=""):
        """Recursively merge ``new_params`` into ``params`` (in place).

        Where both sides hold a dict, they are merged key by key. Otherwise the value from ``new_params`` is stored,
        replacing any existing value.

        :param params: The parameter (sub)tree to modify. Nothing happens if it is not a dict.
        :param new_params: The values to set. Nothing happens if it is not a dict.
        :param prefix: Full name of ``params``; used to build the names recorded in ``_updated_params``.
        """
        if not isinstance(new_params, dict) or not isinstance(params, dict):
            return
        for k, v in new_params.items():
            name = "/".join((prefix, k))
            if isinstance(params.get(k), dict) and isinstance(v, dict):
                # Both sides are namespaces: merge them instead of replacing the whole subtree.
                self.update_params(params[k], v, name)
            else:
                params[k] = v
                self._updated_params.add(name)

    def remove_param(self, params, param, prefix=""):
        """Remove a (possibly nested, slash-separated) parameter from ``params`` (in place).

        :param params: The parameter (sub)tree. Nothing happens if it is not a dict.
        :param str param: Name of the parameter relative to ``params`` (e.g. 'ns/param'). Leading slashes are ignored.
        :param prefix: Full name of ``params``; used to build the names recorded in ``_removed_params``.
        """
        if not isinstance(params, dict):
            return
        # Keys of the parameter dicts are relative names (full names are built by joining them with "/"), so an
        # absolute name like "/ns/param" has to be made relative before matching.
        param = param.lstrip("/")
        if param in params:
            del params[param]
            self._removed_params.add("/".join((prefix, param)))
        elif "/" in param:
            param_key, param_rest = param.split("/", 1)
            if param_key in params:
                self.remove_param(params[param_key], param_rest, "/".join((prefix, param_key)))

    def _str_params(self):
        """Describe this filter's state as a comma-separated string."""
        params = ["removed_params=" + repr(self._removed_params), "updated_params=" + repr(self._updated_params)]
        parent_params = super(UpdateParams, self)._str_params()
        if len(parent_params) > 0:
            params.append(parent_params)
        return ",".join(params)


class FixSpotCams(RawMessageFilter):
    """Fix a problem with Spot robot cameras that publish a bit weird message header."""

    def __init__(self, *args, **kwargs):
        """
        :param args: Standard include/exclude and stamp args.
        :param kwargs: Standard include/exclude and stamp kwargs.
        """
        super(FixSpotCams, self).__init__(*args, **kwargs)

    def filter(self, topic, datatype, data, md5sum, pytype, stamp, header, tags):
        """Relabel the message as sensor_msgs/CompressedImage without touching its serialized payload.

        The connection header is modified in place (definition, md5sum and type), and the returned datatype, md5sum
        and Python type are replaced by those of CompressedImage.
        """
        fixed_type = CompressedImage
        header["message_definition"] = fixed_type._full_text
        header["md5sum"] = fixed_type._md5sum
        header["type"] = fixed_type._type
        return topic, fixed_type._type, data, fixed_type._md5sum, fixed_type, stamp, header, tags


class MaxMessageSize(RawMessageFilter):
    """Drop messages larger than the specified message size."""

    def __init__(self, size_limit, *args, **kwargs):
        """
        :param int size_limit: Maximum message size [B].
        :param args: Standard include/exclude and stamp args.
        :param kwargs: Standard include/exclude and stamp kwargs.
        """
        super(MaxMessageSize, self).__init__(*args, **kwargs)
        self.size_limit = size_limit

    def filter(self, topic, datatype, data, md5sum, pytype, stamp, header, tags):
        """Drop the message (return None) if its raw data is longer than the limit; pass it unchanged otherwise."""
        if len(data) > self.size_limit:
            return None

        return topic, datatype, data, md5sum, pytype, stamp, header, tags

    def _str_params(self):
        """Describe this filter's configuration as a comma-separated string."""
        params = ["size=%d B" % self.size_limit]
        parent_params = super(MaxMessageSize, self)._str_params()
        if len(parent_params) > 0:
            params.append(parent_params)
        return ",".join(params)

    @staticmethod
    def add_cli_args(parser):
        parser.add_argument('--max-message-size', type=int, help='Remove all messages larger than this size [B]')

    @staticmethod
    def yaml_config_args():
        return 'max_message_size',

    @staticmethod
    def process_cli_args(filters, args):
        """Append a MaxMessageSize filter to ``filters`` if a (truthy) size limit was given in ``args``."""
        # Tolerate argument namespaces that do not define this option at all instead of raising AttributeError.
        size_limit = getattr(args, 'max_message_size', None)
        if size_limit:
            filters.append(MaxMessageSize(size_limit=size_limit))


class MakeLatched(RawMessageFilter):
    """Make topics latched."""

    def __init__(self, *args, **kwargs):
        """
        :param args: Standard include/exclude and stamp args.
        :param kwargs: Standard include/exclude and stamp kwargs.
        """
        super(MakeLatched, self).__init__(*args, **kwargs)

    def filter(self, topic, datatype, data, md5sum, pytype, stamp, header, tags):
        """Set ``latching`` to "1" in the connection header (in place) and pass the message on unchanged."""
        result = (topic, datatype, data, md5sum, pytype, stamp, header, tags)
        header["latching"] = "1"
        return result


class CompressImages(DeserializedMessageFilter):
    """Compress an Image topic to CompressedImage."""

    def __init__(self, include_types=None, only_color=False, only_depth=False, transport=None, transport_params=None,
                 transport_mapping=None, format_mapping=None, *args, **kwargs):
        """
        :param list include_types: Types of messages to work on. The default is sensor_msgs/Image.
        :param bool only_color: If true, only color images will be processed (i.e. 3 or 4 channels, or 1 8-bit channel).
        :param bool only_depth: If true, only depth images will be processed (i.e. 1 16-bit channel).
        :param str transport: What image_transport to use. If not provided, 'compressedDepth' will be used for depth
                              images and 'compressed' for the rest.
        :param dict transport_params: Parameters of image transport(s). Keys are transport names (e.g. 'compressed'),
                                      values are the publisher dynamic reconfigure parameters. The passed dictionaries
                                      are not modified by this filter.
        :param dict transport_mapping: Maps the message's 'encoding' field values to image transports. This overrides
                                       the default transport set by 'transport' arg or the autodetected one.
        :param dict format_mapping: Maps the message's 'encoding' field values to encoder 'format's (i.e. 'jpg', 'png',
                                    'rvl' etc.).
        :param args: Standard include/exclude and stamp args.
        :param kwargs: Standard include/exclude and stamp kwargs.
        """
        super(CompressImages, self).__init__(
            include_types=include_types if include_types is not None else ['sensor_msgs/Image'], *args, **kwargs)
        self._cv = CvBridge()

        self.only_color = only_color
        self.only_depth = only_depth

        self.transport = transport
        self.transport_mapping = transport_mapping if transport_mapping is not None else {}
        self.transport_params = transport_params if transport_params is not None else {}
        # map of the raw images' encoding parameter to output image format
        self.format_mapping = format_mapping if format_mapping is not None else {}

    def filter(self, topic, msg, stamp, header, tags):
        """Compress a raw image into a CompressedImage on the '<topic>/<transport>' topic.

        Images excluded by only_color/only_depth, and images that fail to encode, are passed through unchanged.
        """
        enc = msg.encoding
        is_color = isColor(enc) or isMono(enc) or isBayer(enc) or enc == YUV422
        is_depth = isDepth(enc)

        if (self.only_color and not is_color) or (self.only_depth and not is_depth):
            return topic, msg, stamp, header, tags

        transport = ("compressedDepth" if is_depth else "compressed") if self.transport is None else self.transport
        transport = self.transport_mapping.get(enc, transport)

        compressed_msg, compressed_topic, err = self.get_image_for_transport(msg, topic, transport)
        # If encoding using compressedDepth fails, try with compressed
        if compressed_msg is None and transport == "compressedDepth":
            transport = "compressed"
            compressed_msg, compressed_topic, err = self.get_image_for_transport(msg, topic, transport)

        if compressed_msg is None:
            print('Error converting image: ' + str(err), file=sys.stderr)
            return topic, msg, stamp, header, tags

        return compressed_topic, compressed_msg, stamp, header, tags_for_generated_msg(tags)

    def get_image_for_transport(self, msg, topic, transport):
        """Encode a raw image with the given image_transport.

        :param msg: The raw sensor_msgs/Image message.
        :param str topic: The raw image topic.
        :param str transport: Name of the image_transport to use (e.g. 'compressed' or 'compressedDepth').
        :return: Tuple (compressed message or None, compressed topic, error returned by the encoder).
        """
        compressed_topic = rospy.names.ns_join(topic, transport)

        # Always work on a copy: the format overrides below are per-message and must not leak into
        # self.transport_params (otherwise e.g. a single 16-bit image would switch all later images to PNG).
        config = copy.deepcopy(self.transport_params.get(transport, {}))
        if msg.encoding in self.format_mapping:
            config["format"] = self.format_mapping[msg.encoding]
        # 16-bit images cannot be compressed to JPEG
        if transport == "compressed" and bitDepth(msg.encoding) > 8:
            config["format"] = "png"
        # Melodic doesn't have RVL
        if transport == "compressedDepth" and not has_rvl():
            config["format"] = "png"

        compressed_msg, err = encode(msg, compressed_topic, config)

        return compressed_msg, compressed_topic, err

    def _str_params(self):
        """Describe this filter's non-default configuration as a comma-separated string."""
        parts = []
        if self.only_color:
            parts.append('only_color')
        if self.only_depth:
            parts.append('only_depth')
        if self.transport:
            parts.append('transport=' + self.transport)
        if len(self.transport_params) > 0:
            parts.append('transport_params=%r' % (self.transport_params,))
        if len(self.transport_mapping) > 0:
            parts.append('transport_mapping=%r' % (self.transport_mapping,))
        if len(self.format_mapping) > 0:
            parts.append('format_mapping=%r' % (self.format_mapping,))
        parent_params = super(CompressImages, self)._str_params()
        if len(parent_params) > 0:
            parts.append(parent_params)
        return ", ".join(parts)


class DecompressImages(DeserializedMessageFilter):
    """Decompress images on a CompressedImage topic into Image messages."""

    def __init__(self, include_types=None, desired_encodings=None, transport=None, transport_params=None,
                 *args, **kwargs):
        """
        :param list include_types: Types of messages to work on. The default is sensor_msgs/CompressedImage.
        :param dict desired_encodings: Maps topic names to target 'encoding' values of the decoded Image. By default,
                                       'passthrough' encoding is used which just passes the encoding in which the
                                       compressed image was stored.
        :param str transport: If nonempty, overrides the autodetected image_transport.
        :param dict transport_params: Parameters of image transport(s). Keys are transport names (e.g. 'compressed'),
                                      values are the subscriber dynamic reconfigure parameters.
        :param args: Standard include/exclude and stamp args.
        :param kwargs: Standard include/exclude and stamp kwargs.
        """
        super(DecompressImages, self).__init__(
            include_types=include_types if include_types is not None else ['sensor_msgs/CompressedImage'],
            *args, **kwargs)

        self._cv = CvBridge()

        # map from topic to desired encoding of the raw images (one of the strings in sensor_msgs/image_encodings.h)
        self.desired_encodings = desired_encodings if desired_encodings is not None else {}
        self.transport = transport
        self.transport_params = transport_params if transport_params is not None else {}

    def filter(self, topic, msg, stamp, header, tags):
        """Decode a compressed image and optionally convert it to the desired encoding.

        When the transport is autodetected, it is taken from the last component of the topic, and that component is
        stripped to get the output topic. If decoding fails, the input message is passed through unchanged. If only
        the conversion to the desired encoding fails, the decoded image is output in its original encoding.
        """
        # An empty transport string means "autodetect", the same as None.
        transport = self.transport if self.transport else None
        raw_topic = topic
        if transport is None and "/" in topic:
            raw_topic, transport = topic.rsplit("/", 1)

        params = self.transport_params.get(transport, {}) if transport is not None else {}
        raw_img, err = decode(msg, topic, params)
        if raw_img is None:
            print('Error converting image: ' + str(err), file=sys.stderr)
            return topic, msg, stamp, header, tags

        desired_encoding = self.desired_encodings.get(topic, 'passthrough')
        if desired_encoding == 'passthrough':
            return raw_topic, raw_img, stamp, header, tags_for_generated_msg(tags)

        compressed_fmt, compressed_depth_fmt, err = guess_any_compressed_image_transport_format(msg)
        if compressed_fmt is None and compressed_depth_fmt is None:
            print('Error converting image to desired encoding: ' + str(err), file=sys.stderr)
            return raw_topic, raw_img, stamp, header, tags_for_generated_msg(tags)

        raw_encoding = compressed_fmt.rawEncoding if compressed_fmt is not None else compressed_depth_fmt.rawEncoding
        if desired_encoding == raw_encoding:
            return raw_topic, raw_img, stamp, header, tags_for_generated_msg(tags)

        try:
            cv_img = self._cv.imgmsg_to_cv2(raw_img, desired_encoding)
            converted_img = self._cv.cv2_to_imgmsg(cv_img, desired_encoding, raw_img.header)
        except CvBridgeError as e:
            print('Error converting image to desired encoding: ' + str(e), file=sys.stderr)
            return raw_topic, raw_img, stamp, header, tags_for_generated_msg(tags)
        # Like every other branch, return the full 5-tuple including the tags.
        return raw_topic, converted_img, stamp, header, tags_for_generated_msg(tags)

    def _str_params(self):
        """Describe this filter's non-default configuration as a comma-separated string."""
        parts = []
        if len(self.desired_encodings):
            parts.append('desired_encodings=%r' % (self.desired_encodings,))
        if self.transport:
            parts.append('transport=' + self.transport)
        if len(self.transport_params) > 0:
            parts.append('transport_config=%r' % (self.transport_params,))
        parent_params = super(DecompressImages, self)._str_params()
        if len(parent_params) > 0:
            parts.append(parent_params)
        return ", ".join(parts)

    @staticmethod
    def add_cli_args(parser):
        parser.add_argument('--decompress-images', action='store_true', help='Decompress all images')

    @staticmethod
    def yaml_config_args():
        return 'decompress_images',

    @staticmethod
    def process_cli_args(filters, args):
        if hasattr(args, 'decompress_images') and args.decompress_images:
            filters.append(DecompressImages())


class DepthImagePreview(DeserializedMessageFilter):
    """Convert 16- or 32-bit depth images to 8-bit mono or color images."""

    def __init__(self, include_types=None, new_topic_suffix="_preview", keep_orig_image=True, normalize=False,
                 num_normalization_samples=1, adaptive_normalization=True, fixed_normalization_bounds=None,
                 colormap=None, *args, **kwargs):
        """
        :param list include_types: Types of messages to work on. The default is sensor_msgs/Image.
        :param str new_topic_suffix: Suffix to add to the new image topic.
        :param bool keep_orig_image: Whether the original image topic should be kept or not.
        :param bool normalize: Whether to normalize the images.
        :param float num_normalization_samples: If greater than 1, this filter samples N images from the bag file to
                                                compute a global normalization factor. If lower than 1, it specifies a
                                                percentage of images to take as samples. If equal to 1, each image is
                                                normalized independently and random access to the bag file is not
                                                needed.
        :param bool adaptive_normalization: If num_normalization_samples is 1 (independent normalization), collect
                                            info about normalization range from the so-far seen samples. This means
                                            the normalization range is changing, but it is only growing, never
                                            decreasing.
        :param tuple fixed_normalization_bounds: If not None, normalization is not estimated and is not adapted, but
                                                 these bounds are used all the time. For 16UC1 images, these bounds
                                                 are multiplied by 1000 to comply with the Kinect standards (conversion
                                                 to millimeters). The tuple should contain exactly 2 float elements
                                                 meaning min value and max value.
        :param str colormap: If None, the conversion will be into 8-bit grayscale. If not None, the passed string should
                             name an installed matplotlib colormap (e.g. 'jet') to use and the resulting images will be
                             3-channel BGR8 images.
        :param args: Standard include/exclude and stamp args.
        :param kwargs: Standard include/exclude and stamp kwargs.
        """
        super(DepthImagePreview, self).__init__(
            include_types=include_types if include_types is not None else ['sensor_msgs/Image'],
            *args, **kwargs)

        self._cv = CvBridge()

        self.new_topic_suffix = new_topic_suffix
        self.keep_orig_image = keep_orig_image
        self.normalize = normalize
        self.num_normalization_samples = num_normalization_samples
        self.adaptive_normalization = adaptive_normalization
        self.fixed_normalization_bounds = fixed_normalization_bounds
        if num_normalization_samples != 1 or fixed_normalization_bounds is not None:
            # Sampled or fixed bounds take precedence over the running (adaptive) normalization.
            self.adaptive_normalization = False
        self.colormap = None
        if colormap is not None:
            try:
                self.colormap = cmap.get_cmap(colormap)
            except (KeyError, ValueError):
                # matplotlib's get_cmap() reports unknown colormap names with ValueError; KeyError is caught as well.
                self.colormap = cmap.get_cmap(None)
                print('Colormap {} not found, using default colormap.'.format(colormap), file=sys.stderr)
        # Per-topic normalization bounds, expressed in raw pixel values of the topic's images.
        self.min_values = {}
        self.max_values = {}

    @staticmethod
    def _valid_depths(img, encoding):
        """Return a flat array of the valid pixels of a depth image.

        For 16UC1 images, 0 and 65535 are treated as invalid; for other encodings, non-finite values are invalid.
        """
        if encoding == TYPE_16UC1:
            return img[(img != 0) & (img != 65535)]
        return img[np.isfinite(img)]

    def _extend_bounds(self, topic, data_min, data_max):
        """Grow the stored normalization range of the topic so that it contains [data_min, data_max]."""
        if topic not in self.min_values:
            self.min_values[topic] = data_min
            self.max_values[topic] = data_max
        else:
            self.min_values[topic] = min(self.min_values[topic], data_min)
            self.max_values[topic] = max(self.max_values[topic], data_max)

    def set_bag(self, bag):
        """Estimate per-topic normalization bounds from a sample of the bag's depth images.

        Nothing is estimated when each image is normalized independently (num_normalization_samples == 1) or when
        fixed bounds are configured.
        """
        super(DepthImagePreview, self).set_bag(bag)

        if self.num_normalization_samples == 1 or self.fixed_normalization_bounds is not None:
            return

        def connection_filter(topic, datatype, md5sum, msg_def, header):
            return datatype == "sensor_msgs/Image"

        conns = bag._get_connections(self._include_topics, connection_filter)
        image_topics = set([c.topic for c in conns])
        depth_image_topics = []

        # The encoding of the first message decides whether a topic is a depth topic.
        for topic in image_topics:
            for _, msg, stamp in bag.read_messages([topic]):
                if msg.encoding == TYPE_16UC1 or msg.encoding == TYPE_32FC1:
                    depth_image_topics.append(topic)
                break

        _, topic_info = bag.get_type_and_topic_info(depth_image_topics)
        for topic in depth_image_topics:
            info = topic_info[topic]
            num_messages = info.message_count
            num_samples = int(self.num_normalization_samples) if self.num_normalization_samples > 1 else \
                int(self.num_normalization_samples * num_messages)
            num_samples = min(num_samples, num_messages)
            connections = bag._get_connections([topic], self.connection_filter)
            entries = list(bag._get_entries(connections, self._min_stamp, self._max_stamp))
            if len(entries) < num_samples:
                entries = list(bag._get_entries(connections))
            if num_samples <= 0 or len(entries) == 0:
                # Nothing to sample; the topic falls back to per-image normalization in filter().
                continue
            # Spread the samples evenly over the available entries.
            idx = np.round(np.linspace(0, len(entries) - 1, num_samples)).astype(int)
            for i in idx:
                entry = entries[i]
                _, msg, _ = bag._read_message((entry.chunk_pos, entry.offset))
                msg_data = self._valid_depths(self._cv.imgmsg_to_cv2(msg), msg.encoding)
                if msg_data.size == 0:
                    # A sample without any valid depth tells nothing about the value range (and np.min would fail).
                    continue
                self._extend_bounds(topic, np.min(msg_data), np.max(msg_data))

    def filter(self, topic, msg, stamp, header, tags):
        """Create an 8-bit preview of a 16UC1 or 32FC1 depth image.

        :return: The preview message tuple, or a list [original tuple, preview tuple] if keep_orig_image is set.
                 Non-depth images and images that cannot be converted are passed through unchanged.
        """
        if msg.encoding != TYPE_16UC1 and msg.encoding != TYPE_32FC1:
            # If we encounter a non-depth topic, exclude it from this filter. So only the first message from non-depth
            # topics will be needlessly deserialized because of this filter.
            self._exclude_topics = TopicSet(list(self._exclude_topics) + [topic])
            return topic, msg, stamp, header, tags

        in_bit_depth = 16 if msg.encoding == TYPE_16UC1 else 32
        new_bit_depth = 8

        try:
            img = self._cv.imgmsg_to_cv2(msg)
            orig_img = img

            if self.normalize or msg.encoding == TYPE_32FC1:
                img_valid = self._valid_depths(img, msg.encoding)

                if self.fixed_normalization_bounds is not None:
                    if topic not in self.min_values:
                        min_bound, max_bound = self.fixed_normalization_bounds
                        if msg.encoding == TYPE_16UC1:
                            min_bound = int(min_bound * 1000)
                            max_bound = int(max_bound * 1000)
                        self.min_values[topic] = min_bound
                        self.max_values[topic] = max_bound

                elif self.adaptive_normalization and img_valid.size > 0:
                    self._extend_bounds(topic, np.min(img_valid), np.max(img_valid))

                if topic in self.min_values:
                    im_min, im_max = self.min_values[topic], self.max_values[topic]
                elif img_valid.size > 0:
                    im_min, im_max = np.min(img_valid), np.max(img_valid)
                else:
                    # No known bounds and no valid pixel: every pixel is overwritten by the invalid-value handling
                    # below, so any non-degenerate range will do.
                    im_min, im_max = 0.0, 1.0
                im_min = float(im_min)
                im_range = max(float(im_max) - im_min, 1)
                # Compute in floating point: subtracting from the uint16 image directly would wrap pixels below
                # im_min around to large values, rendering them as if they were at the maximum range.
                img = (img.astype(np.float64) - im_min) / im_range
                img = np.clip(img, 0.0, 1.0)

                if msg.encoding == TYPE_16UC1:
                    img[orig_img == 0] = 0.0
                    img[orig_img == 65535] = 1.0
                else:
                    # +inf maps to 1.0, -inf to 0.0.
                    img[np.isinf(orig_img)] = ~np.signbit(orig_img[np.isinf(orig_img)])
                    img[np.isnan(orig_img)] = 0.0
            else:
                img = img / (np.power(2.0, in_bit_depth) - 1)

            if self.colormap is None:
                img = img * (np.power(2.0, new_bit_depth) - 1)
                img = img.astype(np.uint8)
                desired_encoding = MONO8
            else:
                img = self.colormap(img, bytes=True)
                if self.colormap.is_gray():
                    img = img[:, :, 0]
                    desired_encoding = MONO8
                else:
                    # The colormap yields RGBA; reorder to BGR and drop alpha.
                    img = img[:, :, (2, 1, 0)]
                    desired_encoding = BGR8

            new_topic = topic + self.new_topic_suffix
            new_msg = self._cv.cv2_to_imgmsg(img, desired_encoding, msg.header)
            new_image = new_topic, new_msg, stamp, header, tags_for_generated_msg(tags)
            if not self.keep_orig_image:
                return new_image
            return [(topic, msg, stamp, header, tags), new_image]
        except CvBridgeError as e:
            print('Error converting image to desired encoding: ' + str(e), file=sys.stderr)
            return topic, msg, stamp, header, tags

    def reset(self):
        """Forget all collected normalization bounds."""
        self.min_values.clear()
        self.max_values.clear()
        super(DepthImagePreview, self).reset()

    def _str_params(self):
        """Describe this filter's configuration as a comma-separated string."""
        parts = []
        parts.append('new_topic_suffix=%s' % (self.new_topic_suffix,))
        parts.append('keep_orig_image=%r' % (self.keep_orig_image,))
        parts.append('normalize=%r' % (self.normalize,))
        if self.adaptive_normalization:
            parts.append('adaptive_normalization')
        if self.num_normalization_samples != 1:
            parts.append('num_normalization_samples=%r' % (self.num_normalization_samples,))
        if self.fixed_normalization_bounds is not None:
            parts.append('fixed_normalization_bounds=%r' % (self.fixed_normalization_bounds,))
        if self.colormap is not None:
            # Print the colormap name rather than the Colormap object's default repr (which includes an address).
            parts.append('colormap=%r' % (self.colormap.name,))
        parent_params = super(DepthImagePreview, self)._str_params()
        if len(parent_params) > 0:
            parts.append(parent_params)
        return ", ".join(parts)


class AnnotationsToDetection2DArray(NoMessageFilter):
    """Convert annotations.xml from CVAT for images to Detection2DArray messages.

    The annotations file is parsed when filtering starts. Every annotated image with at least one (label-filtered)
    bounding box yields one Detection2DArray message that is emitted among the extra initial messages. The message
    timestamp is parsed from the image name, which has to contain a ``<secs>.<nsecs>`` string with exactly nine
    fractional digits.
    """

    def __init__(self, annotations_file, topic, frame_id, labels=None):
        """
        :param str annotations_file: The file to read annotations from (CVAT for images XML, version 1.1).
        :param str topic: The topic to put the detections on.
        :param str frame_id: The TF frame of the detection message. Should be the same as the corresponding camera frame.
        :param labels: If not None, only the listed labels will be exported from the annotations file.
        """
        super(AnnotationsToDetection2DArray, self).__init__()
        self._annotations_file = annotations_file
        self._topic = topic
        self._frame_id = frame_id
        self._labels = TopicSet(labels) if labels is not None else None

        # Matches a "<secs>.<nsecs>" stamp (9 fractional digits) anywhere in the image name.
        self._stamp_regex = re.compile(r'([0-9]+\.[0-9]{9})')

        # Dict[rospy.Time, Detection2DArray]; filled in on_filtering_start().
        self._annotations = None

    def on_filtering_start(self):
        super(AnnotationsToDetection2DArray, self).on_filtering_start()
        annotations_file = self.resolve_file(self._annotations_file)
        self._annotations = self._parse_annotations(annotations_file)
        print("Read", len(self._annotations), "annotations from file", annotations_file)

    def _parse_annotations(self, annotations_file):
        """Parse the annotations file.

        :param str annotations_file: Resolved path to the annotations file.
        :return: Mapping from image stamps to the detections annotated in that image.
        :rtype: Dict[rospy.Time, Detection2DArray]
        :raises RuntimeError: If the file format is not supported.
        """
        annotations = None
        if annotations_file.endswith(".xml"):
            annotations = self._parse_xml(annotations_file)

        if annotations is None:
            raise RuntimeError("Unsupported annotations type. Only CVAT XML is supported so far.")
        return annotations

    def _parse_xml(self, annotations_file):
        """Parse an XML annotations file.

        :param str annotations_file: Path to the XML file.
        :return: The parsed annotations, or None if the XML is not in a supported format (CVAT for images 1.1).
        :rtype: Dict[rospy.Time, Detection2DArray] or None
        """
        parser = etree.XMLParser(remove_comments=True, encoding='utf-8')
        tree = etree.parse(annotations_file, parser).getroot()

        # Expected layout: <annotations><version>1.1</version>...<image name="...">...</image>...</annotations>
        if tree.tag == "annotations" and len(tree) > 0:
            if tree[0].tag == "version" and tree[0].text == "1.1":
                return self._parse_cvat_xml(tree)

        return None

    def _parse_cvat_xml(self, xml_root):
        # type: (etree._Element) -> Dict[rospy.Time, Detection2DArray]
        """Convert the <image> elements of a CVAT XML document to Detection2DArray messages.

        Only <box> annotations are supported; other annotation types are skipped (with a single warning).
        Images without any (label-filtered) box are not included in the result.

        :param xml_root: The <annotations> root element.
        :return: Mapping from image stamps to the detections annotated in that image.
        :raises RuntimeError: If the timestamp cannot be parsed from an image name.
        """
        annotations = dict()

        warned = False

        for child in xml_root:
            if child.tag != "image":
                continue
            name = child.get("name", "")
            # search() instead of match(): the stamp does not have to be at the very start of the name
            # (e.g. when the name contains a directory prefix).
            match = self._stamp_regex.search(name)
            if match is None:
                raise RuntimeError("Could not parse timestamp from item name '%s'" % (name,))
            secs, nsecs = match.group(1).split(".")
            stamp = rospy.Time(int(secs, base=10), int(nsecs, base=10))

            msg = Detection2DArray()
            msg.header.frame_id = self._frame_id
            msg.header.stamp = stamp
            for det_elem in child:
                if det_elem.tag != "box":
                    if not warned:
                        warned = True
                        print("Found unsupported annotation type", det_elem.tag, file=sys.stderr)
                    continue
                if self._labels is not None and det_elem.get("label") not in self._labels:
                    continue
                xtl, ytl = float(det_elem.get("xtl")), float(det_elem.get("ytl"))
                xbr, ybr = float(det_elem.get("xbr")), float(det_elem.get("ybr"))
                det = Detection2D()
                det.header = msg.header
                det.bbox.center.x = (xbr + xtl) / 2
                det.bbox.center.y = (ybr + ytl) / 2
                det.bbox.size_x = xbr - xtl
                det.bbox.size_y = ybr - ytl
                msg.detections.append(det)

            if len(msg.detections) > 0:
                annotations[stamp] = msg

        return annotations

    def extra_initial_messages(self):
        """Yield all parsed annotations as generated messages whose receive stamp equals their header stamp."""
        # The connection header has to describe the type of the published messages (Detection2DArray).
        connection_header = create_connection_header(self._topic, Detection2DArray)
        for stamp, msg in self._annotations.items():
            yield self._topic, msg, stamp, connection_header, {MessageTags.GENERATED}

    def _str_params(self):
        params = ["topic=" + self._topic, "annotations_file=" + self.resolve_file(self._annotations_file)]
        if self._labels is not None:
            params.append("labels=%s" % (str(self._labels),))
        parent_params = super(AnnotationsToDetection2DArray, self)._str_params()
        if len(parent_params) > 0:
            params.append(parent_params)
        return ",".join(params)


class BlurDetections(ImageTransportFilter):
    """Blur parts of images that are covered by a detection bounding box. The header timestamp of both image and
    detection messages are used for matching. Exact match is required unless `approximate_sync_threshold` is set."""

    def __init__(self, image_topic, detections_topic, is_raw=True, ellipse=False, blur_factor=2, bbox_scale=1.0,
                 detection_receive_stamp_from_image=False, approximate_sync_threshold=None, add_tags=None,
                 *args, **kwargs):
        """
        :param str image_topic: Topic with images.
        :param str detections_topic: Topic with detections (Detection2DArray).
        :param bool is_raw: Whether the filter works on raw or deserialized messages.
        :param bool ellipse: Whether to blur an inscribed ellipse (otherwise a rectangle is blurred).
        :param float blur_factor: Parameter of the blurring. The number of pixels in each dimension that the detection
                                  will be reduced to.
        :param float bbox_scale: Scale of the detection bounding boxes used for blurring.
        :param bool detection_receive_stamp_from_image: If True, the detection messages' receive timestamp will be set
                                                        to the same stamp the images have (so that they are nicely
                                                        aligned in rqt_bag and such tools).
        :param float approximate_sync_threshold: If not None, defines the maximum difference (in seconds) between
                                                 stamps of a matched detection-image pair. If None, only exact matches
                                                 are considered.
        :param add_tags: Tags to be added to messages with some blurred detections.
        :param args: Standard include/exclude and stamp args.
        :param kwargs: Standard include/exclude and stamp kwargs.
        """
        include_topics = [image_topic]
        include_types = None
        if detection_receive_stamp_from_image or approximate_sync_threshold is not None:
            # The detection messages need to pass through this filter so that their stamps can be adjusted.
            include_topics.append(detections_topic)
            include_types = [Detection2DArray._type]

        super(BlurDetections, self).__init__(is_raw=is_raw, include_topics=include_topics, include_types=include_types,
                                             *args, **kwargs)

        self.image_topic = image_topic
        self.detections_topic = detections_topic
        self.ellipse = ellipse
        self.blur_factor = blur_factor
        self.bbox_scale = bbox_scale
        self.detection_receive_stamp_from_image = detection_receive_stamp_from_image
        self.approximate_sync_threshold = \
            float(approximate_sync_threshold) if approximate_sync_threshold is not None else None
        self.add_tags = set(add_tags) if add_tags else set()

        # header stamp -> (serialized Detection2DArray, its original header stamp)
        self._detections_cache = dict()
        # image header stamp -> image receive stamp
        self._image_stamps_cache = dict()
        # detection header stamp -> closest image header stamp (approximate mode only)
        self._approx_detection_stamps = dict()

        self._cv = CvBridge()

    def on_filtering_start(self):
        super(BlurDetections, self).on_filtering_start()
        bag = self._multibag if self._multibag is not None else self._bag

        topic = self.detections_topic.lstrip('/')
        topics = [topic, '/' + topic]

        if self.detection_receive_stamp_from_image or self.approximate_sync_threshold is not None:
            topic = self.image_topic.lstrip('/')
            topics += [topic, '/' + topic]

        for topic, msg, stamp in bag.read_messages(topics=topics, raw=True):
            datatype = msg[0]
            if datatype == Detection2DArray._type:
                det_data = msg[1]
                header = deserialize_header(det_data, Detection2DArray)
                self._detections_cache[header.stamp] = det_data, header.stamp
            else:
                img_data = msg[1]
                pytype = msg[-1]
                header = deserialize_header(img_data, pytype)
                self._image_stamps_cache[header.stamp] = stamp

        if self.approximate_sync_threshold is not None:
            self._match_approximate_detections()

        print("Preloaded %i detections for topic %s" % (len(self._detections_cache), self.detections_topic))
        if self.detection_receive_stamp_from_image:
            print("Preloaded %i image stamps for topic %s" % (len(self._image_stamps_cache), self.image_topic))
        if len(self._approx_detection_stamps) > 0:
            print("Matched %i approximate timestamps" % (len(self._approx_detection_stamps),))

    def _match_approximate_detections(self):
        """Pair detections that have no exactly matching image with nearby images.

        For each such detection, up to 10 image stamps on each side of it are examined. Only images whose stamp is
        within `approximate_sync_threshold` seconds are considered: they get the detection cached under the image stamp
        (keeping the closest detection if several qualify), and the closest of them is remembered in
        `_approx_detection_stamps` so that the detection's header stamp can later be rewritten to it.
        """
        img_stamps = sorted(self._image_stamps_cache.keys())
        # Iterate over a snapshot because image-stamp keys are added to the cache inside the loop.
        for stamp, det_data_and_stamp in list(self._detections_cache.items()):
            if stamp in self._image_stamps_cache:
                continue  # Exact match, nothing to do.
            idx = bisect_left(img_stamps, stamp)
            # Visit the neighbors of the insertion point, closest indices first.
            for i in sorted(range(-10, 10), key=abs):
                if not 0 <= idx + i < len(img_stamps):
                    continue
                img_stamp = img_stamps[idx + i]
                diff = abs((img_stamp - stamp).to_sec())
                if diff > self.approximate_sync_threshold:
                    continue  # Out of bounds; never pair this image with this detection.

                if stamp not in self._approx_detection_stamps or \
                        abs((self._approx_detection_stamps[stamp] - stamp).to_sec()) > diff:
                    self._approx_detection_stamps[stamp] = img_stamp

                if img_stamp not in self._detections_cache:
                    self._detections_cache[img_stamp] = det_data_and_stamp
                else:
                    _, det_stamp = self._detections_cache[img_stamp]
                    if abs((img_stamp - det_stamp).to_sec()) > diff:
                        self._detections_cache[img_stamp] = det_data_and_stamp

    def filter_raw(self, topic, datatype, data, md5sum, pytype, stamp, conn_header, tags):
        if datatype == Config._type:
            self.process_transport_params(topic, raw_to_msg(datatype, data, md5sum, pytype))
            return topic, datatype, data, md5sum, pytype, stamp, conn_header, tags

        header = deserialize_header(data, pytype)

        if datatype == Detection2DArray._type:
            if header.stamp in self._approx_detection_stamps:
                # Rewrite the header stamp to the paired image stamp. Only the fixed-size stamp changes, so the
                # serialized header keeps its length and can be swapped in place.
                header.stamp = self._approx_detection_stamps[header.stamp]
                _, ser_header, _, _ = msg_to_raw(header)
                data = ser_header + data[len(ser_header):]
            if self.detection_receive_stamp_from_image:
                stamp = self._image_stamps_cache.get(header.stamp, stamp)
            # Detections only pass through; they must never be handled as images to blur.
            return topic, datatype, data, md5sum, pytype, stamp, conn_header, tags

        if header.stamp not in self._detections_cache:
            return topic, datatype, data, md5sum, pytype, stamp, conn_header, tags

        img = raw_to_msg(datatype, data, md5sum, pytype)
        result = self.filter_any_image(topic, img, stamp, conn_header, tags)
        result = normalize_filter_result(result)
        del self._detections_cache[header.stamp]

        raw_result = []
        for _result in result:
            if _result is None:
                raw_result.append(None)
                continue
            _topic, _msg, _stamp, _header, _tags = _result
            _datatype, _data, _md5sum, _pytype = msg_to_raw(_msg)
            raw_result.append((_topic, _datatype, _data, _md5sum, _pytype, _stamp, _header, _tags))

        return raw_result

    def filter_deserialized(self, topic, msg, stamp, conn_header, tags):
        if msg._type == Config._type:
            self.process_transport_params(topic, msg)
            return topic, msg, stamp, conn_header, tags

        if msg._type == Detection2DArray._type:
            header = msg.header
            if header.stamp in self._approx_detection_stamps:
                header.stamp = self._approx_detection_stamps[header.stamp]
            if self.detection_receive_stamp_from_image:
                stamp = self._image_stamps_cache.get(header.stamp, stamp)
            # Detections only pass through; they must never be handled as images to blur.
            return topic, msg, stamp, conn_header, tags

        img_stamp = msg.header.stamp
        if img_stamp not in self._detections_cache:
            return topic, msg, stamp, conn_header, tags

        result = self.filter_any_image(topic, msg, stamp, conn_header, tags)
        del self._detections_cache[img_stamp]
        return result

    def filter_image(self, topic, orig_msg, img_msg, raw_topic, transport, stamp, header, tags):
        """Blur all detections cached for the header stamp of `img_msg` into a copy of the image."""
        det_msg = Detection2DArray().deserialize(self._detections_cache[img_msg.header.stamp][0])
        if len(det_msg.detections) > 0:
            cv_img = self._cv.imgmsg_to_cv2(img_msg).copy()

            for det in det_msg.detections:
                self.blur(cv_img, det)

            img_msg = self._cv.cv2_to_imgmsg(cv_img, encoding=img_msg.encoding, header=img_msg.header)
            tags = tags_for_changed_msg(tags, self.add_tags)

        return topic, orig_msg, img_msg, raw_topic, transport, stamp, header, tags

    def blur(self, cv_img, det):
        """Blur the (scaled) bounding box of a single detection in place.

        :param numpy.ndarray cv_img: The image to modify.
        :param det: The detection whose `bbox` is blurred.
        """
        bf = self.blur_factor

        img_h, img_w = cv_img.shape[:2]
        bbox = det.bbox
        cx = int(bbox.center.x)
        cy = int(bbox.center.y)
        w = int(bbox.size_x * self.bbox_scale)
        h = int(bbox.size_y * self.bbox_scale)
        # Slice ends are exclusive, so clip to the image size (not size - 1) to be able to reach the last row/column.
        x1 = np.clip(cx - w // 2, 0, img_w)
        x2 = np.clip(cx + w // 2, 0, img_w)
        y1 = np.clip(cy - h // 2, 0, img_h)
        y2 = np.clip(cy + h // 2, 0, img_h)

        if x1 == x2 or y1 == y2:
            return
        if x2 < x1:
            x1, x2 = x2, x1
        if y2 < y1:
            y1, y2 = y2, y1

        bfx = min(bf, x2 - x1)
        bfy = min(bf, y2 - y1)
        kernel_size = (int(abs(x2 - x1) // bfx), int(abs(y2 - y1) // bfy))
        blurred_box = cv2.blur(cv_img[y1:y2, x1:x2], kernel_size)

        if self.ellipse and w > 4 and h > 4:
            roibox = cv_img[y1:y2, x1:x2]
            # Ellipse inscribed in the unclipped bbox, in ROI coordinates. When the bbox is clipped by the top/left
            # image border, its center is not the ROI center; `shape` clips the ellipse to the ROI.
            ey, ex = skimage.draw.ellipse(cy - y1, cx - x1, h // 2, w // 2, shape=blurred_box.shape)
            roibox[ey, ex] = blurred_box[ey, ex]
            cv_img[y1:y2, x1:x2] = roibox
        else:
            cv_img[y1:y2, x1:x2] = blurred_box

    def reset(self):
        self._detections_cache = dict()
        self._image_stamps_cache = dict()
        self._approx_detection_stamps = dict()
        super(BlurDetections, self).reset()

    def _str_params(self):
        params = []
        params.append("image_topic=" + self.image_topic)
        params.append("detections_topic=" + self.detections_topic)
        params.append("ellipse=%r" % (self.ellipse,))
        # %s keeps ints unchanged and does not truncate float blur factors.
        params.append("blur_factor=%s" % (self.blur_factor,))
        params.append("bbox_scale=%f" % (self.bbox_scale,))
        if self.detection_receive_stamp_from_image:
            params.append("detection_receive_stamp_from_image")
        if self.approximate_sync_threshold is not None:
            params.append("approximate_sync_threshold=%f" % (self.approximate_sync_threshold,))
        parent_params = super(BlurDetections, self)._str_params()
        if len(parent_params) > 0:
            params.append(parent_params)
        return ",".join(params)


class BlurFaces(DeserializedMessageFilter):
    """Blur faces in color/mono images using 'deface' library."""

    def __init__(self, include_types=None, transport_params=None, ignore_transports=None, threshold=0.2, ellipse=True,
                 scale=(960, 960), mask_scale=1.3, replacewith='blur', replaceimg=None, mosaicsize=20, backend='auto',
                 publish_faces=True, add_tags=None, *args, **kwargs):
        """
        :param list include_types: Types of messages to work on. The default is sensor_msgs/Image and CompressedImage.
        :param dict transport_params: Parameters of image transport(s). Keys are transport names (e.g. 'compressed'),
                                      values are the publisher dynamic reconfigure parameters.
        :param list ignore_transports: List of transport names to ignore (defaults to `['compressedDepth']`).
        :param float threshold: Threshold for face detection.
        :param bool ellipse: Whether to anonymize an ellipse instead of the whole box (passed to deface).
        :param scale: Upper bound on the detector input size, compared with ``(img.shape[0], img.shape[1])``;
                      larger images are downscaled, smaller are kept. None disables the limit.
        :param float mask_scale: Scale factor of the detected face boxes used for anonymization.
        :param str replacewith: Anonymization method passed to deface (e.g. 'blur').
        :param replaceimg: Replacement image passed to deface.
        :param int mosaicsize: Mosaic size passed to deface (only used with deface >= 1.5.0).
        :param str backend: Inference backend of the deface CenterFace detector.
        :param bool publish_faces: If True, a Detection2DArray with the anonymized faces is also output on topic
                                   `<raw image topic>/anonymized_faces`.
        :param add_tags: Tags to be added to messages with some blurred faces.
        :param args: Standard include/exclude and stamp args.
        :param kwargs: Standard include/exclude and stamp kwargs.
        """
        try:
            import deface.centerface
            import deface.deface
        except ImportError:
            cmds = [
                'sudo apt install python3-numpy python3-skimage python3-pil python3-opencv python3-imageio',
                'pip3 install --user --upgrade pip',
            ]
            accel_cmds = list()
            pip_constraints = "'numpy~=1.17.0' 'pillow~=7.0.0' 'imageio~=2.4.0' 'scikit_image~=0.16.0'"
            if os.path.exists('/etc/nv_tegra_release'):
                # from https://elinux.org/Jetson_Zoo#ONNX_Runtime, version compatible with onnx 1.12, Py 3.8 and JP 5
                url = 'https://nvidia.box.com/shared/static/v59xkrnvederwewo2f1jtv6yurl92xso.whl'
                cmds += [
                    "python3 -m pip install --user " + pip_constraints + " 'onnx~=1.12.0'",
                ]
                accel_cmds += [
                    'wget ' + url + ' -O onnxruntime_gpu-1.12.0-cp38-cp38-linux_aarch64.whl',
                    'python3 -m pip install --user onnxruntime_gpu-1.12.0-cp38-cp38-linux_aarch64.whl',
                ]
            else:
                cmds += [
                    "python3 -m pip install --user " + pip_constraints + " 'onnx~=1.12.0'",
                ]
                accel_cmds += [
                    "python3 -m pip install --user 'onnxruntime-openvino~=1.12.0'",
                ]
            cmds += [
                'python3 -m pip install --user --no-deps deface~=1.4.0',
                'python3 -m pip uninstall pip',
            ]

            print('Error importing deface module. Please install it with the following commands:', file=sys.stderr)
            for cmd in cmds:
                print("\t" + cmd, file=sys.stderr)
            print('To add HW acceleration, please run the following commands:', file=sys.stderr)
            for cmd in accel_cmds:
                print("\t" + cmd, file=sys.stderr)

            raise

        default_types = [Image._type, CompressedImage._type, Config._type]

        super(BlurFaces, self).__init__(
            include_types=include_types if include_types is not None else default_types, *args, **kwargs)

        self.transport_params = transport_params if transport_params is not None else {}
        self.received_transport_params = {}
        self.ignore_transports = ['/' + t for t in ignore_transports] if ignore_transports is not None \
            else ['/compressedDepth']
        self.threshold = threshold
        self.ellipse = ellipse
        self.scale = scale
        self.mask_scale = mask_scale
        self.replacewith = replacewith
        self.replaceimg = replaceimg
        self.mosaicsize = mosaicsize
        self.backend = backend
        self.publish_faces = publish_faces
        self.add_tags = add_tags

        self._cv = CvBridge()
        self._centerface = deface.centerface.CenterFace(in_shape=None, backend=self.backend)
        # deface 1.5.0 added the mosaicsize argument of anonymize_frame(). Compare numerically, not as strings
        # (a string comparison would consider "1.10.0" older than "1.5.0").
        self._deface_has_mosaicsize = self._version_tuple(getattr(deface, '__version__', '0')) >= (1, 5)

    @staticmethod
    def _version_tuple(version):
        """Convert a version string (e.g. '1.5.0' or '1.5.0rc1') to a tuple of ints for numeric comparison.

        Parsing stops at the first dot-separated component that does not start with a digit.
        """
        parts = []
        for component in str(version).split('.'):
            digits = re.match(r'[0-9]+', component)
            if digits is None:
                break
            parts.append(int(digits.group(0)))
        return tuple(parts)

    def consider_message(self, topic, datatype, stamp, header, tags):
        """Accept all transport parameter updates; skip images on ignored transports."""
        if datatype == Config._type:
            return True

        if not super(BlurFaces, self).consider_message(topic, datatype, stamp, header, tags):
            return False

        for t in self.ignore_transports:
            if topic.endswith(t):
                return False

        return True

    def process_transport_params(self, topic, msg, stamp, header, tags):
        """Remember decoded parameters published on `<transport topic>/parameter_updates`; pass the message on."""
        if not topic.endswith('/parameter_updates'):
            return topic, msg, stamp, header, tags
        transport_topic, _ = topic.rsplit('/', 1)
        self.received_transport_params[transport_topic] = decode_config(msg)
        return topic, msg, stamp, header, tags

    def filter(self, topic, msg, stamp, header, tags):
        if msg._type == Config._type:
            return self.process_transport_params(topic, msg, stamp, header, tags)

        if msg._type == Image._type:
            raw_msg = msg
            raw_topic = topic
            transport = 'raw'
        else:
            transport = None
            if "/" in topic:
                raw_topic, transport = topic.rsplit("/", 1)
            if transport is None or len(transport) == 0:
                print("Compressed image on a topic without suffix [%s]. Passing message." % (topic,), file=sys.stderr)
                return topic, msg, stamp, header, tags

            raw_msg, err = decode(msg, topic, {})
            if raw_msg is None:
                print('Error converting image: ' + str(err), file=sys.stderr)
                return topic, msg, stamp, header, tags

        enc = raw_msg.encoding
        # Only color/mono/bayer/YUV images are processed; other encodings (e.g. depth) pass unchanged.
        is_color = isColor(enc) or isMono(enc) or isBayer(enc) or enc == YUV422
        if not is_color:
            return topic, msg, stamp, header, tags

        img = self._cv.imgmsg_to_cv2(raw_msg)
        if raw_msg.encoding == RGB8 or numChannels(raw_msg.encoding) == 1:
            rgb_img = img
        else:
            rgb_img = self._cv.imgmsg_to_cv2(raw_msg, RGB8)

        if self.scale is not None:
            scale = min(
                min(1.0, self.scale[0] / img.shape[0]),
                min(1.0, self.scale[1] / img.shape[1]))
            self._centerface.in_shape = (int(img.shape[0] * scale), int(img.shape[1] * scale))
        else:
            self._centerface.in_shape = None
        # NOTE(review): depending on the deface version, CenterFace may expect in_shape as (width, height) and may
        # evaluate it only on its first call - confirm against the installed deface.

        dets, _ = self._centerface(rgb_img, threshold=self.threshold)

        import deface.deface

        # Discard implausible detections whose mask-scaled box spans >= 20 % of the image width or height.
        bad_dets = list()
        for i, det in enumerate(dets):
            x1, y1, x2, y2 = det[:4].astype(int)
            x1, y1, x2, y2 = deface.deface.scale_bb(x1, y1, x2, y2, self.mask_scale)
            w, h = x2 - x1, y2 - y1
            if w >= 0.2 * img.shape[1] or h >= 0.2 * img.shape[0]:
                bad_dets.append(i)
        dets = np.delete(dets, bad_dets, axis=0)

        if len(dets) == 0:
            return topic, msg, stamp, header, tags

        img = img.copy()

        img_tags = copy.deepcopy(tags)
        img_tags.add(MessageTags.CHANGED)
        if self.add_tags:
            img_tags = img_tags.union(self.add_tags)

        dets_tags = tags_for_generated_msg(tags)

        if self._deface_has_mosaicsize:
            deface.deface.anonymize_frame(
                dets, img, self.mask_scale, self.replacewith, self.ellipse, False, self.replaceimg, self.mosaicsize)
        else:
            deface.deface.anonymize_frame(
                dets, img, self.mask_scale, self.replacewith, self.ellipse, False, self.replaceimg)

        raw_msg = self._cv.cv2_to_imgmsg(img, enc, msg.header)
        if msg._type == Image._type:
            msg = raw_msg
        else:
            msg, err = self.get_image_for_transport(msg, raw_msg, topic, transport)
            if msg is None:
                # Never fall back to the original (unblurred) image; drop the message instead.
                print('Error encoding blurred image on topic %s: %s. Dropping the message.' % (topic, str(err)),
                      file=sys.stderr)
                return None

        if not self.publish_faces:
            return topic, msg, stamp, header, img_tags

        dets_msg = self.dets_to_msg(dets, msg)
        dets_header = copy.deepcopy(header)
        dets_header["topic"] = raw_topic + '/anonymized_faces'  # The rest will be fixed by fix_connection_header()

        return [
            (topic, msg, stamp, header, img_tags),
            (dets_header["topic"], dets_msg, stamp, dets_header, dets_tags)
        ]

    def dets_to_msg(self, dets, msg):
        """Convert deface detections to a Detection2DArray with the header of `msg`.

        Each box is enlarged on every side by ``(mask_scale - 1)`` times its original size, and the enlarged box
        (center and size) is published. The detection index is stored as the hypothesis id, its score as the score.
        """
        dets_msg = Detection2DArray()
        dets_msg.header = msg.header
        s = self.mask_scale - 1.0
        for i, det in enumerate(dets):
            x1, y1, x2, y2 = det[:4].astype(float)
            score = det[4]
            h, w = y2 - y1, x2 - x1
            y1 -= h * s
            y2 += h * s
            x1 -= w * s
            x2 += w * s

            det_msg = Detection2D()
            det_msg.header = dets_msg.header
            det_msg.bbox.center.x = (x1 + x2) / 2.0
            det_msg.bbox.center.y = (y1 + y2) / 2.0
            # Publish the size of the enlarged box (the original w/h would describe the unscaled detection).
            det_msg.bbox.size_x = x2 - x1
            det_msg.bbox.size_y = y2 - y1
            hyp = ObjectHypothesisWithPose()
            hyp.id = i
            hyp.score = score
            hyp.pose.pose.orientation.w = 1.0
            det_msg.results.append(hyp)
            dets_msg.detections.append(det_msg)
        return dets_msg

    def get_image_for_transport(self, msg, raw_msg, compressed_topic, transport):
        """Encode `raw_msg` for the transport of `compressed_topic`, trying to keep the format of the original `msg`.

        Explicitly configured `transport_params` take precedence over parameters received from the bag.

        :return: Tuple (encoded message, error). Presumably the message is None on failure (like with `decode()`).
        """
        config = copy.deepcopy(self.received_transport_params.get(compressed_topic, {}))
        config.update(self.transport_params.get(transport, {}))

        compressed_fmt, compressed_depth_fmt, _ = guess_any_compressed_image_transport_format(msg)

        if compressed_fmt is not None and "format" not in config:
            config["format"] = compressed_fmt.format.value
        elif compressed_depth_fmt is not None and "format" not in config:
            config["format"] = compressed_depth_fmt.format.value

        compressed_msg, err = encode(raw_msg, compressed_topic, config)

        return compressed_msg, err

    def _str_params(self):
        parts = []
        parts.append('threshold=%f' % (self.threshold,))
        parts.append('ellipse=%r' % (self.ellipse,))
        parts.append('scale=%r' % (self.scale,))
        parts.append('mask_scale=%f' % (self.mask_scale,))
        parts.append('replacewith=%s' % (self.replacewith,))
        if self.replaceimg is not None:
            parts.append('replaceimg=%s' % (self.replaceimg,))
        parts.append('mosaicsize=%f' % (self.mosaicsize,))
        parts.append('backend=%s' % (self.backend,))
        parts.append('publish_faces=%r' % (self.publish_faces,))
        if len(self.transport_params) > 0:
            parts.append('transport_params=%r' % (self.transport_params,))
        parent_params = super(BlurFaces, self)._str_params()
        if len(parent_params) > 0:
            parts.append(parent_params)
        return ", ".join(parts)


def get_closest_stamp(msg_stamp, q, max_diff):
    """Find the stamp from `q` that is closest to `msg_stamp`.

    :param msg_stamp: The reference stamp (e.g. :class:`rospy.Time`).
    :param q: Iterable of candidate stamps (e.g. a dict keyed by stamps). Subtracting `msg_stamp` from a candidate
              has to give an object with a ``to_sec()`` method.
    :param float max_diff: Maximum allowed absolute difference in seconds.
    :return: The closest stamp, or None if `q` is empty or even the closest stamp is farther than `max_diff`.
             On ties, the smaller stamp wins.
    """
    candidates = [(abs((s - msg_stamp).to_sec()), s) for s in q]
    if not candidates:
        return None
    # min() over (diff, stamp) tuples; no need to sort the whole list first.
    diff, closest_stamp = min(candidates)
    return closest_stamp if diff <= max_diff else None


class NiftiAugmentJointStates(DeserializedMessageFilter):
    """Add velocity and currents to NIFTi robot joint states.

    Messages on topics ``currents`` and ``flippers_vel`` are buffered (keyed by their header stamp) and passed on
    unchanged. Each joint states message is augmented with the buffered flipper velocities (as velocity) and currents
    (as effort) whose header stamps are closest to its own, within 0.1 s. If such data are not available yet, the
    joint states message is requeued with a slightly delayed receive stamp; after the requeue limit, it is passed
    unaugmented.
    """

    def __init__(self, joint_states_topic, *args, **kwargs):
        """
        :param str joint_states_topic: The topic with joint states to augment.
        :param args: Standard include/exclude and stamp args.
        :param kwargs: Standard include/exclude and stamp kwargs.
        """
        super(NiftiAugmentJointStates, self).__init__(
            include_topics=[joint_states_topic, 'currents', 'flippers_vel'], *args, **kwargs)

        # header stamp -> message
        self._currents_queue = {}
        self._flippers_vels_queue = {}

        self._max_wait_time = rospy.Duration(2.0)
        self._requeue_delay = rospy.Duration(0.1)
        self._max_sync_diff = rospy.Duration(0.1).to_sec()
        self._max_requeues = 20

        # joint name -> field of the currents/flippers_vel messages
        self._joint_names_to_fields = {
            'front_left_flipper_j': 'frontLeft',
            'front_right_flipper_j': 'frontRight',
            'rear_left_flipper_j': 'rearLeft',
            'rear_right_flipper_j': 'rearRight',
        }

    def filter(self, topic, msg, stamp, header, tags):
        # Prune stale queue items
        self._currents_queue = {s: m for s, m in self._currents_queue.items() if s + self._max_wait_time >= stamp}
        self._flippers_vels_queue = \
            {s: m for s, m in self._flippers_vels_queue.items() if s + self._max_wait_time >= stamp}

        t = topic.lstrip('/')
        msg_stamp = msg.header.stamp

        if t == 'currents':
            self._currents_queue[msg_stamp] = msg
            return topic, msg, stamp, header, tags
        elif t == 'flippers_vel':
            self._flippers_vels_queue[msg_stamp] = msg
            return topic, msg, stamp, header, tags

        currents_stamp = get_closest_stamp(msg_stamp, self._currents_queue, self._max_sync_diff)
        vel_stamp = get_closest_stamp(msg_stamp, self._flippers_vels_queue, self._max_sync_diff)

        if currents_stamp is not None and vel_stamp is not None:
            currents = self._currents_queue.pop(currents_stamp)
            flippers_vel = self._flippers_vels_queue.pop(vel_stamp)

            # JointState arrays have to be empty or as long as `name`; do not assume a fixed number of joints.
            msg.velocity = [0.0] * len(msg.name)
            msg.effort = [0.0] * len(msg.name)

            for joint, field in self._joint_names_to_fields.items():
                if joint not in msg.name:
                    continue
                idx = msg.name.index(joint)
                msg.velocity[idx] = getattr(flippers_vel, field)
                msg.effort[idx] = getattr(currents, field)

            tags = tags_for_changed_msg(tags)

            if MessageTags.REQUEUE in tags:
                num_requeues = num_used_requeues(tags, self._max_requeues)
                tags = cleanup_requeue_tags(tags)
                # correct the receive stamp back to its original value
                stamp -= self._requeue_delay * num_requeues
        else:
            # requeue the joint states message until we have the other two;
            # pass the unaugmented message if the messages are not available even after the requeues
            if MessageTags.REQUEUE not in tags:
                tags = tags_for_requeuing_msg(tags, max_requeues=self._max_requeues, drop_on_timeout=False)
            stamp += self._requeue_delay

        return topic, msg, stamp, header, tags

    def _str_params(self):
        parts = []
        parent_params = super(NiftiAugmentJointStates, self)._str_params()
        if len(parent_params) > 0:
            parts.append(parent_params)
        return ", ".join(parts)

    def reset(self):
        self._currents_queue = {}
        self._flippers_vels_queue = {}
        # Like the other filters, also reset the state of the parent class.
        super(NiftiAugmentJointStates, self).reset()
