#!/usr/bin/env python

# SPDX-License-Identifier: BSD-3-Clause
# SPDX-FileCopyrightText: Czech Technical University in Prague

# Test helper for testing message repeater

from functools import partial
import rospy
import rostest
import unittest
import time
import sys

from geometry_msgs.msg import PointStamped
from std_msgs.msg import String

# Name used both for the ROS node created by the test case (rospy.init_node) and as the second argument of
# rostest.run() at the bottom of this file.
NAME = "test_repeat"


class Repeat(unittest.TestCase):
    def __init__(self, *args):
        super(Repeat, self).__init__(*args)
        rospy.init_node(NAME)

        self.start_timeout = float(rospy.get_param("~start_timeout", 5))
        self.test_duration = float(rospy.get_param("~test_duration", 10))

        self.msg_counts = dict()

    def count_msg(self, topic, _):
        if self.start is None or not self.start < time.time() < self.end:
            return
        self.msg_counts[topic] += 1

    def get_freq(self, topic):
        return self.msg_counts[topic] / self.test_duration

    def test_repeat(self):
        topics = (
            "repeat_input/output", "node_in_repeat", "node_out", "repeat_input_lazy/output", "node_in_lazy_repeat",
            "node_out_lazy", "node_out_param", "node_out_invalid_rate", "node_out_fast", "node_out_fast_reset_lower",
            "node_out_fast_reset_higher", "node_out_fast_timer_reset_lower", "node_out_fast_timer_reset_higher"
        )

        subs = list()
        self.start = None

        for t in topics:
            self.msg_counts[t] = 0
            subs.append(rospy.Subscriber(t, PointStamped, partial(self.count_msg, t)))

        try:
            rospy.wait_for_message("node_in", PointStamped, self.start_timeout)
        except rospy.ROSException:
            self.fail("Something did not work, the test received no messages on topic 'node_in'")

        self.start = time.time()
        self.end = self.start + self.test_duration

        # Count messages for the specified duration (and a bit longer, we filter the exact times in callback
        time.sleep(self.test_duration + 1)

        self.assertAlmostEqual(self.get_freq("repeat_input/output"), 13.0, delta=3.0)
        self.assertAlmostEqual(self.get_freq("node_in_repeat"), 13.0, delta=3.0)
        self.assertAlmostEqual(self.get_freq("node_out"), 13.0, delta=3.0)
        self.assertAlmostEqual(self.get_freq("repeat_input_lazy/output"), 13.0, delta=3.0)
        self.assertAlmostEqual(self.get_freq("node_in_lazy_repeat"), 13.0, delta=3.0)
        self.assertAlmostEqual(self.get_freq("node_out_lazy"), 13.0, delta=3.0)
        self.assertAlmostEqual(self.get_freq("node_out_param"), 13.0, delta=3.0)
        # Default rate is 1.0
        self.assertAlmostEqual(self.get_freq("node_out_invalid_rate"), 1.0, delta=0.5)
        self.assertAlmostEqual(self.get_freq("node_out_fast"), 20.0, delta=4.0)
        # Timer should fire each 0.2 s, but each 0.1 s a message arrives and resets it.
        self.assertAlmostEqual(self.get_freq("node_out_fast_reset_lower"), 10.0, delta=3.0)
        # Timer should fire each 0.06 s and each 0.1 s a message arrives and resets it. So we +- sum the rates.
        self.assertAlmostEqual(self.get_freq("node_out_fast_reset_higher"), 22.5, delta=5.5)
        self.assertEqual(self.get_freq("node_out_fast_timer_reset_lower"), 0)
        self.assertAlmostEqual(self.get_freq("node_out_fast_timer_reset_higher"), 90.0, delta=8.0)

    def always_count_msg(self, topic, _):
        self.msg_counts[topic] += 1

    def wait_until_num_msgs(self, num, deadline, topic):
        while self.msg_counts[topic] < num and time.time() < deadline:
            time.sleep(0.01)

    def test_discard_older(self):
        topic = "node_out_discard_older"
        self.msg_counts[topic] = 0

        sub = rospy.Subscriber(topic, rospy.AnyMsg, partial(self.always_count_msg, topic), queue_size=100)
        pub = rospy.Publisher("node_in_discard_older", PointStamped, latch=True, queue_size=1)
        reset_pub = rospy.Publisher("repeat_discard_older/reset", PointStamped, queue_size=1)

        msg = PointStamped()
        msg.header.stamp = rospy.Time(10)
        pub.publish(msg)

        wallclock_timeout = time.time() + self.test_duration
        self.wait_until_num_msgs(1, wallclock_timeout, topic)
        self.assertEqual(self.msg_counts[topic], 1)

        msg.header.stamp = rospy.Time(9)
        pub.publish(msg)

        time.sleep(0.1)
        self.assertEqual(self.msg_counts[topic], 1)

        # Message is newer, should be published

        msg.header.stamp = rospy.Time(11)
        pub.publish(msg)

        self.wait_until_num_msgs(2, min(wallclock_timeout, time.time() + 0.1), topic)
        self.assertEqual(self.msg_counts[topic], 2)

        # Message is newer, should be published

        msg.header.stamp = rospy.Time(11.1)
        pub.publish(msg)

        self.wait_until_num_msgs(3, min(wallclock_timeout, time.time() + 0.1), topic)
        self.assertEqual(self.msg_counts[topic], 3)

        # Message is older, should not be published

        msg.header.stamp = rospy.Time(11)
        pub.publish(msg)

        time.sleep(0.1)
        self.assertEqual(self.msg_counts[topic], 3)

        # Resetting the repeater should discard the previous message, so the new one can be published even if it is
        # older.

        reset_pub.publish(msg)
        time.sleep(0.1)

        msg.header.stamp = rospy.Time(11)
        pub.publish(msg)

        self.wait_until_num_msgs(4, min(wallclock_timeout, time.time() + 0.1), topic)
        self.assertEqual(self.msg_counts[topic], 4)

        # Message is newer, should be published

        msg.header.stamp = rospy.Time(11.1)
        pub.publish(msg)

        self.wait_until_num_msgs(5, min(wallclock_timeout, time.time() + 0.1), topic)
        self.assertEqual(self.msg_counts[topic], 5)

    def test_max_age(self):
        topic = "node_out_max_age"
        self.msg_counts[topic] = 0

        sub = rospy.Subscriber(topic, rospy.AnyMsg, partial(self.always_count_msg, topic), queue_size=100)
        pub = rospy.Publisher("node_in_max_age", PointStamped, latch=True, queue_size=1)

        max_age = 0.1
        min_msgs = 2
        max_msgs = 3

        msg = PointStamped()
        msg.header.stamp = rospy.Time.now()

        # Publish first message to the repeater so that it starts publishing
        pub.publish(msg)

        # Wait until messages start coming
        wallclock_timeout = time.time() + self.test_duration
        self.wait_until_num_msgs(1, wallclock_timeout, topic)

        # Publish a new message with current timestamp

        self.msg_counts[topic] = 0
        msg.header.stamp = rospy.Time.now()
        pub.publish(msg)

        # Wait for some time and count the messages
        time.sleep(max_age * 10)

        self.assertGreaterEqual(self.msg_counts[topic], min_msgs)
        self.assertLessEqual(self.msg_counts[topic], max_msgs)

        # Try to publish a message that is already too old at the time of publication

        self.msg_counts[topic] = 0
        msg.header.stamp = rospy.Time.now()
        time.sleep(max_age * 1.1)
        pub.publish(msg)

        time.sleep(max_age * 10)

        self.assertGreaterEqual(self.msg_counts[topic], 0)
        self.assertLessEqual(self.msg_counts[topic], 0)

        # Publish a new message with current timestamp (again, just to test things did not break before)

        self.msg_counts[topic] = 0
        msg.header.stamp = rospy.Time.now()
        pub.publish(msg)

        # Wait for some time and count the messages
        time.sleep(max_age * 10)

        self.assertGreaterEqual(self.msg_counts[topic], min_msgs)
        self.assertLessEqual(self.msg_counts[topic], max_msgs)

    def test_max_repeats(self):
        topic = "node_out_max_repeats"
        self.msg_counts[topic] = 0

        sub = rospy.Subscriber(topic, rospy.AnyMsg, partial(self.always_count_msg, topic), queue_size=100)
        pub = rospy.Publisher("node_in_max_repeats", String, latch=True, queue_size=1)
        reset_pub = rospy.Publisher("reset", String, queue_size=1)

        min_msgs = 10
        max_msgs = 16
        test_duration = 2.0

        # Publish first message to the repeater so that it starts publishing
        pub.publish(String())

        # Wait until messages start coming
        wallclock_timeout = time.time() + self.test_duration
        self.wait_until_num_msgs(1, wallclock_timeout, topic)

        # Wait for the prescribed time and count the messages
        time.sleep(test_duration)

        self.assertGreaterEqual(self.msg_counts[topic], min_msgs)
        self.assertLessEqual(self.msg_counts[topic], max_msgs)

        # Reset the counter, publish a new message - this should allow the repeater to start sending again

        self.msg_counts[topic] = 0
        pub.publish(String())

        time.sleep(test_duration)

        self.assertGreaterEqual(self.msg_counts[topic], min_msgs)
        self.assertLessEqual(self.msg_counts[topic], max_msgs)

        # Publish to the reset topic should not result in sending data again (the last message is reset, too)

        self.msg_counts[topic] = 0
        reset_pub.publish(String())

        time.sleep(test_duration)

        self.assertEqual(self.msg_counts[topic], 0)

        # Now publish a new message

        self.msg_counts[topic] = 0
        pub.publish(String())

        time.sleep(test_duration)

        self.assertGreaterEqual(self.msg_counts[topic], min_msgs)
        self.assertLessEqual(self.msg_counts[topic], max_msgs)


def _main():
    """Run the Repeat test case through rostest and print a notice once the run is over."""
    try:
        rostest.run('rostest', NAME, Repeat, sys.argv)
    except KeyboardInterrupt:
        # An interrupted run (Ctrl+C) is not treated as an error; just fall through to the final notice.
        pass
    print("exiting")


if __name__ == "__main__":
    _main()
