#!/usr/bin/env python

# SPDX-License-Identifier: BSD-3-Clause
# SPDX-FileCopyrightText: Czech Technical University in Prague

# Test helper for testing change_header

import rospy
import rostest
import unittest
import time
import sys

from geometry_msgs.msg import PointStamped
from std_msgs.msg import String

# Name passed to rospy.init_node() and to rostest.run() below.
NAME = "test_change_header"


# Reference message whose fields the tests compare received messages against.
orig_msg = PointStamped()
orig_msg.header.frame_id = "test"
orig_msg.header.stamp.secs = 11
orig_msg.point.x, orig_msg.point.y, orig_msg.point.z = 1.0, 2.0, 3.0


class ChangeHeader(unittest.TestCase):
    """Checks the messages received on the change_header test topics.

    Each test waits for a single message on one topic and compares its header
    and point either with the module-level ``orig_msg`` or with fixed expected
    values. The point coordinates are always expected to equal ``orig_msg``.
    """

    def __init__(self, *args):
        """Initialise the ROS node and read the message wait timeout.

        :param args: Positional arguments forwarded to ``unittest.TestCase``.
        """
        super(ChangeHeader, self).__init__(*args)
        rospy.init_node(NAME)

        # Seconds to wait for each message (private parameter, default 5).
        self.timeout = float(rospy.get_param("~timeout", 5))

    # ------------------------------------------------------------------
    # Shared helpers (names do not start with "test", so unittest skips them)
    # ------------------------------------------------------------------

    def _receive(self, topic):
        """Wait for one PointStamped message on ``topic`` and return it.

        ``assertIsInstance`` is used instead of a bare ``assert`` so the check
        still runs under ``python -O`` and reports a descriptive failure.

        :param str topic: Topic to wait on.
        :return: The received message.
        :raises rospy.ROSException: If nothing arrives within ``self.timeout``.
        """
        msg = rospy.wait_for_message(topic, PointStamped, self.timeout)
        self.assertIsInstance(msg, PointStamped)
        return msg

    def _assert_frame_and_point(self, msg, frame_id):
        """Assert ``msg`` has the given frame_id and the point of ``orig_msg``.

        :param msg: Received PointStamped message.
        :param str frame_id: Expected ``header.frame_id``.
        """
        self.assertEqual(msg.header.frame_id, frame_id)
        self.assertEqual(msg.point.x, orig_msg.point.x)
        self.assertEqual(msg.point.y, orig_msg.point.y)
        self.assertEqual(msg.point.z, orig_msg.point.z)

    def _check_stamp_kept(self, topic, frame_id):
        """Expect the stamp of ``orig_msg`` and the given frame_id on ``topic``.

        :param str topic: Topic to wait on.
        :param str frame_id: Expected ``header.frame_id``.
        """
        msg = self._receive(topic)
        self.assertEqual(msg.header.stamp, orig_msg.header.stamp)
        self._assert_frame_and_point(msg, frame_id)

    def _check_stamp_parts(self, topic, secs, nsecs, frame_id):
        """Expect an exact secs/nsecs stamp and the given frame_id on ``topic``.

        :param str topic: Topic to wait on.
        :param int secs: Expected ``header.stamp.secs``.
        :param int nsecs: Expected ``header.stamp.nsecs``.
        :param str frame_id: Expected ``header.frame_id``.
        """
        msg = self._receive(topic)
        self.assertEqual(msg.header.stamp.secs, secs)
        self.assertEqual(msg.header.stamp.nsecs, nsecs)
        self._assert_frame_and_point(msg, frame_id)

    def _check_stamp_near(self, topic, now, offset=0):
        """Expect a stamp within 5 s of ``now() + offset`` and frame_id "test".

        ``now`` is a callable so that the clock is read only after the message
        has been received; reading it before the wait would add the waiting
        time to the measured difference.

        :param str topic: Topic to wait on.
        :param now: Zero-argument callable returning the current time in seconds.
        :param offset: Seconds added to ``now()`` to get the expected stamp.
        """
        msg = self._receive(topic)
        self.assertAlmostEqual(msg.header.stamp.to_sec(), now() + offset, delta=5)
        self._assert_frame_and_point(msg, "test")

    @staticmethod
    def _ros_now():
        """Return the current ROS time in seconds."""
        return rospy.Time.now().to_sec()

    # ------------------------------------------------------------------
    # frame_id replaced by "abcd", stamp unchanged
    # ------------------------------------------------------------------

    def test_input(self):
        """change_header_input/output: stamp kept, frame_id "abcd"."""
        self._check_stamp_kept("change_header_input/output", "abcd")

    def test_non_lazy(self):
        """node_in_change_header: stamp kept, frame_id "abcd"."""
        self._check_stamp_kept("node_in_change_header", "abcd")

    def test_explicit(self):
        """node_out: stamp kept, frame_id "abcd"."""
        self._check_stamp_kept("node_out", "abcd")

    def test_input_lazy(self):
        """change_header_input_lazy/output: stamp kept, frame_id "abcd"."""
        self._check_stamp_kept("change_header_input_lazy/output", "abcd")

    def test_lazy(self):
        """node_in_lazy_change_header: stamp kept, frame_id "abcd"."""
        self._check_stamp_kept("node_in_lazy_change_header", "abcd")

    def test_explicit_lazy(self):
        """node_out_lazy: stamp kept, frame_id "abcd"."""
        self._check_stamp_kept("node_out_lazy", "abcd")

    # ------------------------------------------------------------------
    # Stamp changes
    # ------------------------------------------------------------------

    def test_header_absolute(self):
        """node_out_absolute: stamp is exactly 10 s, frame_id "abcd"."""
        self._check_stamp_parts("node_out_absolute", 10, 0, "abcd")

    def test_stamp_ros_time(self):
        """node_out_ros_time: stamp within 5 s of the current ROS time."""
        self._check_stamp_near("node_out_ros_time", self._ros_now)

    def test_stamp_ros_time_relative(self):
        """node_out_ros_time_rel: stamp within 5 s of ROS time + 10 s."""
        self._check_stamp_near("node_out_ros_time_rel", self._ros_now, 10)

    def test_stamp_wall_time(self):
        """node_out_wall_time: stamp within 5 s of the current wall time."""
        self._check_stamp_near("node_out_wall_time", time.time)

    def test_stamp_wall_time_relative(self):
        """node_out_wall_time_rel: stamp within 5 s of wall time + 10 s."""
        self._check_stamp_near("node_out_wall_time_rel", time.time, 10)

    # ------------------------------------------------------------------
    # frame_id edits (prefix / suffix / replace)
    # ------------------------------------------------------------------

    def test_header_prefix(self):
        """node_out_prefix: stamp 10 s earlier, frame_id prefixed by "abcd"."""
        self._check_stamp_parts(
            "node_out_prefix",
            orig_msg.header.stamp.secs - 10,
            orig_msg.header.stamp.nsecs,
            "abcd" + orig_msg.header.frame_id)

    def test_header_suffix(self):
        """node_out_suffix: stamp 10 s later, frame_id suffixed by "abcd"."""
        self._check_stamp_parts(
            "node_out_suffix",
            orig_msg.header.stamp.secs + 10,
            orig_msg.header.stamp.nsecs,
            orig_msg.header.frame_id + "abcd")

    def test_replace_header_start(self):
        """node_out_replace_start: stamp kept, frame_id "asdest"."""
        self._check_stamp_kept("node_out_replace_start", "asdest")

    def test_replace_header_end(self):
        """node_out_replace_end: stamp kept, frame_id "tesasd"."""
        self._check_stamp_kept("node_out_replace_end", "tesasd")

    def test_replace_header(self):
        """node_out_replace: stamp kept, frame_id "asdesasd"."""
        self._check_stamp_kept("node_out_replace", "asdesasd")

    # ------------------------------------------------------------------
    # Corner cases
    # ------------------------------------------------------------------

    def test_nothing(self):
        """node_out_nothing: the header is identical to ``orig_msg``."""
        self._check_stamp_kept("node_out_nothing", orig_msg.header.frame_id)

    def test_no_header(self):
        """node_out_no_header: no String message may arrive within the timeout."""
        try:
            rospy.wait_for_message("node_out_no_header", String, self.timeout)
        except rospy.ROSInterruptException:
            # A shutdown interrupted the wait. This subclass of ROSException
            # says nothing about whether a message would have come, so it must
            # not be counted as a pass.
            raise
        except rospy.ROSException:
            # Timed out without receiving anything, which is the expected outcome.
            return
        self.fail("No message was expected, but something was received")

    def test_negative_stamp(self):
        """node_out_negative_stamp: stamp is 0 s 1 ns, frame_id unchanged."""
        self._check_stamp_parts(
            "node_out_negative_stamp", 0, 1, orig_msg.header.frame_id)

    def test_all(self):
        """node_out_all: stamp is exactly 10 s, frame_id "cras"."""
        self._check_stamp_parts("node_out_all", 10, 0, "cras")

    def test_all_relative(self):
        """node_out_all_relative: stamp is exactly 9 s, frame_id "cbcteststs"."""
        self._check_stamp_parts("node_out_all_relative", 9, 0, "cbcteststs")


if __name__ == "__main__":
    # Hand the test case over to rostest under the name NAME, forwarding the
    # command-line arguments unchanged.
    try:
        rostest.run("rostest", NAME, ChangeHeader, sys.argv)
    except KeyboardInterrupt:
        # Ctrl-C is treated as a normal way to stop the run.
        pass
    print("exiting")
