#!/usr/bin/env python

# SPDX-License-Identifier: BSD-3-Clause
# SPDX-FileCopyrightText: Czech Technical University in Prague

# Test helper for testing message throttle

from functools import partial
import rospy
import rostest
import unittest
import time
import sys

from std_msgs.msg import String

# Passed to rospy.init_node() as the node name and to rostest.run() as its second argument.
NAME = "test_throttle"


class Throttle(unittest.TestCase):
    """Measure the message rate on several throttled topics and compare it with the expected rate.

    All of the measuring happens in the constructor: it subscribes to every topic listed in
    `EXPECTED_RATES`, waits for the first message on topic `node_in`, and then counts incoming
    messages for `~test_duration` seconds. `test_throttle()` afterwards only evaluates the counts.

    ROS parameters read by the constructor:
        ~start_timeout: Seconds to wait for the first message on `node_in` (default 10).
        ~test_duration: Length of the measuring window in seconds (default 10).
    """

    # One entry per measured topic: (topic, expected rate in Hz, allowed absolute deviation in Hz).
    EXPECTED_RATES = (
        ("throttle_input/output", 13, 2),
        ("node_in_throttle", 13, 2),
        ("node_out", 13, 2),
        ("throttle_input_lazy/output", 13, 2),
        ("node_in_lazy_throttle", 13, 2),
        ("node_out_lazy", 13, 2),
        ("node_out_param", 13, 2),
        # Yep, THROTTLE cannot make 20 Hz to 13 Hz correctly.
        ("throttle_method/output", 10, 2),
        # Default is 1 Hz
        ("node_out_invalid_rate", 1, 0.25),
    )

    def __init__(self, *args):
        """Initialize the ROS node, subscribe to the measured topics and run the measurement.

        :param args: Forwarded unchanged to `unittest.TestCase.__init__`.
        """
        super(Throttle, self).__init__(*args)
        rospy.init_node(NAME)

        self.start_timeout = float(rospy.get_param("~start_timeout", 10))
        self.test_duration = float(rospy.get_param("~test_duration", 10))

        self.msg_counts = dict()
        self.subs = list()
        # Bounds of the measuring window (wall-clock seconds). Both stay None until measuring starts, so a
        # callback that fires early always finds both attributes and simply ignores the message.
        self.start = None
        self.end = None
        # Set when the input topic stayed silent. Reported from test_throttle() because an exception raised
        # here would abort construction of the test case instead of being recorded as a test failure.
        self.startup_error = None

        for topic, _, _ in self.EXPECTED_RATES:
            self.msg_counts[topic] = 0
            self.subs.append(rospy.Subscriber(topic, rospy.AnyMsg, partial(self.count_msg, topic)))

        try:
            rospy.wait_for_message("node_in", String, self.start_timeout)
        except rospy.ROSException:
            self.startup_error = "Something did not work, the test received no messages on topic 'node_in'"
            return  # Nothing to measure, so do not waste time sleeping.

        # Callbacks are delivered while this constructor is still running, so `end` has to be in place before
        # `start` switches from None to a timestamp; `start` is what opens the window in count_msg().
        now = time.time()
        self.end = now + self.test_duration
        self.start = now

        # Count messages for the specified duration (and a bit longer, we filter the exact times in callback).
        time.sleep(self.test_duration + 1)

    def count_msg(self, topic, _):
        """Subscriber callback: count one message on `topic` if it arrived inside the measuring window.

        :param topic: Name of the topic the message arrived on (bound via `functools.partial`).
        :param _: The received message; its content is not used.
        """
        # Read both bounds once so that the check works on a consistent pair of values.
        start, end = self.start, self.end
        if start is None or end is None or not start < time.time() < end:
            return
        self.msg_counts[topic] += 1

    def get_freq(self, topic):
        """Return the average rate measured on `topic`.

        :param topic: One of the topics listed in `EXPECTED_RATES`.
        :return: Number of counted messages divided by `test_duration`, i.e. messages per second.
        """
        return self.msg_counts[topic] / self.test_duration

    def test_throttle(self):
        """Check that every measured topic runs at its expected rate within the allowed tolerance."""
        if self.startup_error is not None:
            self.fail(self.startup_error)

        for topic, expected, delta in self.EXPECTED_RATES:
            freq = self.get_freq(topic)
            # The message spells out all values so that it stays informative even where unittest replaces
            # its standard message with the custom one.
            self.assertAlmostEqual(
                freq, expected, delta=delta,
                msg="Topic '%s': measured %.2f Hz, expected %s +- %s Hz" % (topic, freq, expected, delta))


if __name__ == "__main__":
    # Hand the test case over to rostest; the command line is forwarded unchanged.
    try:
        rostest.run("rostest", NAME, Throttle, sys.argv)
    except KeyboardInterrupt:
        # An interrupt is not treated as an error; fall through to the exit message.
        pass
    print("exiting")
