#!/usr/bin/env python

import argparse
import sys

import rosgraph
import rospy
import rostopic


def _index_nodes_by_topic(host_topics, include_rosout):
    """Re-key a per-host topic map by topic name.

    :param host_topics: dict mapping hostname to a list of topic entries;
        each entry is indexed as ``entry[0]`` = topic name and
        ``entry[2]`` = node names (the only fields read here).
    :param include_rosout: when False, entries for ``/rosout`` are dropped.
    :returns: dict ``{topic_name: {hostname: node_names}}``.
    """
    nodes_by_topic = {}
    for hostname, entries in host_topics.items():
        for entry in entries:
            topic_name = entry[0]
            if topic_name == '/rosout' and not include_rosout:
                continue
            nodes_by_topic.setdefault(topic_name, {})[hostname] = entry[2]
    return nodes_by_topic


def _group_cross_host_connections(pub_node_dicts, sub_node_dicts,
                                  by_publisher):
    """Group topic connections whose publisher and subscriber hosts differ.

    :param pub_node_dicts: ``{topic_name: {hostname: publisher node names}}``.
    :param sub_node_dicts: ``{topic_name: {hostname: subscriber node names}}``.
    :param by_publisher: if True the outer key of the result is the publisher
        host and the inner key the subscriber host; otherwise the reverse.
    :returns: ``{outer_host: {inner_host: [(topic_name, pub_node_names,
        sub_node_names), ...]}}``. Hosts without any cross-host connection
        do not appear at all.
    """
    grouped = {}
    # Only topics that have both a publisher and a subscriber can form a
    # connection. Sorting makes the per-pair topic order deterministic
    # instead of depending on set iteration order.
    connected_topics = sorted(set(pub_node_dicts) & set(sub_node_dicts))
    for topic_name in connected_topics:
        pub_node_dict = pub_node_dicts[topic_name]
        sub_node_dict = sub_node_dicts[topic_name]
        for pub_hostname, pub_node_names in pub_node_dict.items():
            for sub_hostname, sub_node_names in sub_node_dict.items():
                # Connections that stay on one host are not reported.
                if pub_hostname == sub_hostname:
                    continue
                if by_publisher:
                    outer, inner = pub_hostname, sub_hostname
                else:
                    outer, inner = sub_hostname, pub_hostname
                grouped.setdefault(outer, {}).setdefault(inner, []).append(
                    (topic_name, pub_node_names, sub_node_names))
    return grouped


def _print_connections(grouped, by_publisher, publisher_host,
                       subscriber_host, show_nodes):
    """Print grouped connections, sorted by outer host then inner host.

    :param grouped: result of ``_group_cross_host_connections``.
    :param by_publisher: must match the flag used to build ``grouped``; it
        selects the section header and which host filter applies to which
        nesting level.
    :param publisher_host: only show this publisher host (None = all).
    :param subscriber_host: only show this subscriber host (None = all).
    :param show_nodes: also list publisher and subscriber node names under
        each connection line.
    """
    if by_publisher:
        header = 'Publisher host: {}'
        outer_filter, inner_filter = publisher_host, subscriber_host
    else:
        header = 'Subscriber host: {}'
        outer_filter, inner_filter = subscriber_host, publisher_host

    for outer in sorted(grouped):
        if outer_filter is not None and outer != outer_filter:
            continue
        # The header is printed even if the inner filter below hides every
        # connection of this host.
        print(header.format(outer))
        peers = grouped[outer]
        for inner in sorted(peers):
            if inner_filter is not None and inner != inner_filter:
                continue
            if by_publisher:
                pub_hostname, sub_hostname = outer, inner
            else:
                pub_hostname, sub_hostname = inner, outer
            for topic_name, pub_node_names, sub_node_names in peers[inner]:
                print('    {} -> {} -> {}'.format(
                    pub_hostname, topic_name, sub_hostname))
                if show_nodes:
                    print('        Publisher nodes:')
                    for pub_node_name in pub_node_names:
                        print('            {}'.format(pub_node_name))
                    print('        Subscriber nodes:')
                    for sub_node_name in sub_node_names:
                        print('            {}'.format(sub_node_name))


def main():
    """List topic connections that cross host boundaries.

    Queries the ROS master for all published and subscribed topics, groups
    them by the host each node runs on, and prints every
    ``publisher host -> topic -> subscriber host`` triple where the two
    hosts differ. Output is grouped by subscriber host by default, or by
    publisher host with ``--publisher-host-sort``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--subscriber-host', '-s',
        default=None,
        help='Subscriber hostname: (default: all)'
    )
    parser.add_argument(
        '--publisher-host', '-p',
        default=None,
        help='Publisher hostname: (default: all)'
    )
    parser.add_argument(
        '--publisher-host-sort', action='store_true',
        help='Sort by publisher host name or not (default: False)'
    )
    parser.add_argument(
        '--include-rosout', action='store_true',
        help='Include /rosout topic or not: (default: False)'
    )
    parser.add_argument(
        '--show-nodes', action='store_true',
        help='Show node names or not: (default: False)'
    )
    # roslaunch appends remapping arguments (e.g. __name:=..., __log:=...)
    # that argparse would reject; rospy.myargv() strips them from sys.argv.
    args = parser.parse_args(rospy.myargv()[1:])

    rospy.init_node('rostopic_connection_list')
    master = rosgraph.Master('/rostopic')
    pubs, subs = rostopic.get_topic_list(master)
    host_pub_topics, host_sub_topics = rostopic._rostopic_list_group_by_host(
        master, pubs, subs)

    pub_node_dicts = _index_nodes_by_topic(
        host_pub_topics, args.include_rosout)
    sub_node_dicts = _index_nodes_by_topic(
        host_sub_topics, args.include_rosout)

    grouped = _group_cross_host_connections(
        pub_node_dicts, sub_node_dicts, args.publisher_host_sort)
    _print_connections(
        grouped, args.publisher_host_sort,
        args.publisher_host, args.subscriber_host, args.show_nodes)


# Script entry point: run the tool, then terminate with exit status 0.
if __name__ == '__main__':
    main()
    # Equivalent to sys.exit(0), which raises SystemExit with this code.
    raise SystemExit(0)
