#!/usr/bin/env python

# SPDX-License-Identifier: BSD-3-Clause
# SPDX-FileCopyrightText: Czech Technical University in Prague

# Test helper for testing message filter

from functools import partial
import rospy
import rostest
import unittest
import time
import sys

from geometry_msgs.msg import Vector3Stamped
from std_msgs.msg import String

NAME = "test_filter"


class Filter(unittest.TestCase):
    """Count messages on the filter output topics and check the counts.

    The ROS fixture (node initialisation, subscribers and the measurement
    window) is created once for the whole class in :meth:`setUpClass`.
    unittest constructs a separate TestCase instance for every test method,
    so doing this work in ``__init__`` would repeat the blocking setup and
    duplicate all subscribers for each test, and a failure raised there would
    abort building the test suite instead of being reported as a test result.

    Class attributes set by :meth:`setUpClass`:

    * ``start_timeout`` - seconds to wait for the first message on ``in``
      (private parameter ``~start_timeout``, default 5).
    * ``test_duration`` - seconds during which messages are counted
      (private parameter ``~test_duration``, default 5).
    * ``msg_counts`` - dict mapping topic name to the number of messages
      received on it.
    * ``subs`` - list of the created subscribers; it keeps them referenced.
    """

    @classmethod
    def setUpClass(cls):
        """Initialise the node, subscribe to all outputs and collect messages.

        Blocks for up to ``start_timeout`` seconds waiting for the first
        message on ``in`` and then for ``test_duration`` seconds while the
        message counters are filled by the subscriber callbacks.

        :raises AssertionError: If no message arrives on topic ``in`` within
                                ``start_timeout`` seconds.
        """
        super(Filter, cls).setUpClass()
        rospy.init_node(NAME)

        cls.start_timeout = float(rospy.get_param("~start_timeout", 5))
        cls.test_duration = float(rospy.get_param("~test_duration", 5))

        topics = (
            "node_out", "all_in_out", "all_inout_out", "none_out", "half_out", "import_out", "exception_out",
            "syntax_out", "return_out", "header_out", "header_wrong_out", "raw_out", "wrong_out"
        )
        # Zero every counter before the first subscriber exists so that a callback never sees a missing key.
        cls.msg_counts = dict.fromkeys(topics, 0)
        cls.subs = [rospy.Subscriber(t, rospy.AnyMsg, partial(cls.count_msg, t)) for t in topics]

        try:
            rospy.wait_for_message("in", Vector3Stamped, cls.start_timeout)
        except rospy.ROSException:
            raise cls.failureException("Something did not work, the test received no messages on topic 'in'")

        # Measurement window: the subscriber callbacks keep counting while this thread sleeps.
        time.sleep(cls.test_duration)

        # Subscribe late to be sure that the publishing has already happened some time in the past
        cls.msg_counts["latch_out"] = 0
        cls.subs.append(rospy.Subscriber("latch_out", rospy.AnyMsg, partial(cls.count_msg, "latch_out")))

    @classmethod
    def count_msg(cls, topic, _):
        """Subscriber callback: increase the message counter of a topic.

        :param str topic: Key into ``msg_counts`` (bound via ``functools.partial``).
        :param _: The received message; its content is ignored.
        """
        cls.msg_counts[topic] += 1

    def test_filter(self):
        """Check the per-topic message counts collected during the measurement window."""
        # NOTE(review): "syntax_out" is subscribed and counted, but no assertion checks it - confirm this is intended.
        self.assertGreater(self.msg_counts["node_out"], 1)
        self.assertGreater(self.msg_counts["all_in_out"], 1)
        self.assertGreater(self.msg_counts["all_inout_out"], 1)
        self.assertEqual(self.msg_counts["none_out"], 0)
        # It should be a half, but we need to give some slack to the check as the subs are not synchronized
        self.assertGreater(self.msg_counts["half_out"], self.msg_counts["node_out"] / 3)
        self.assertGreater(self.msg_counts["import_out"], self.msg_counts["node_out"] / 3)
        self.assertEqual(self.msg_counts["exception_out"], 0)
        self.assertEqual(self.msg_counts["return_out"], 0)
        self.assertGreater(self.msg_counts["header_out"], self.msg_counts["node_out"] / 3)
        self.assertEqual(self.msg_counts["header_wrong_out"], 0)
        self.assertGreater(self.msg_counts["raw_out"], 1)
        self.assertEqual(self.msg_counts["wrong_out"], 0)

    def test_latch(self):
        """Check that the late subscriber to ``latch_out`` received exactly one message."""
        try:
            rospy.wait_for_message("latch_out", rospy.AnyMsg, timeout=self.test_duration)
        except rospy.ROSException:
            self.fail("Did not receive anything on the latched topic")
        # Give the counting subscriber's callback a moment to run before reading its counter.
        time.sleep(0.1)
        self.assertEqual(self.msg_counts["latch_out"], 1)


if __name__ == "__main__":
    # Arguments for the rostest runner: package name, test name, test case class and the command line.
    run_args = ("rostest", NAME, Filter, sys.argv)
    try:
        rostest.run(*run_args)
    except KeyboardInterrupt:
        # Ctrl-C is deliberately swallowed so that the script still reaches the final message.
        pass
    print("exiting")
