#!/usr/bin/env python3
import argparse
import os
try:
    import rospy
    import_rospy = True
except ImportError:
    import_rospy = False
import subprocess
import sys

# Docker image name (without tag) for each model selectable on the command line.
CONTAINERS = {
    "ofa": "jsk-ofa-server",
    "clip": "jsk-clip-server",
}
# Values accepted by --ofa_model_scale.
OFA_MODEL_SCALES = ["base", "large", "huge"]

# Command-line interface: one required positional (the model to run) plus
# optional settings. The --ofa_* options are only used when the model is "ofa".
parser = argparse.ArgumentParser(
    description="JSK Vision and Language API runner",
)
parser.add_argument(
    "model",
    type=str,
    choices=CONTAINERS.keys(),
    help="Vision and Language model you want to run",
)
parser.add_argument(
    "-p", "--port",
    type=int,
    default=8888,
    help="Port of the API server",
)
parser.add_argument(
    "--ofa_task",
    default="caption",
    help="Tasks to be loaded in OFA. OFA loads all tasks by default and requires large amount of GPU RAM",
)
parser.add_argument(
    "--ofa_model_scale",
    choices=OFA_MODEL_SCALES,
    default="huge",
    help="Scale of parameters to be read in OFA. Use huge by default",
)
# Parsed at module level so `args` / `unknown` remain module attributes.
# parse_known_args() collects unrecognised arguments in `unknown` instead of
# aborting on them.
args, unknown = parser.parse_known_args()

if __name__ == "__main__":
    # rospy.myargv() is argv with ROS remapping arguments stripped, so a length
    # difference means such arguments were passed (e.g. by roslaunch); only
    # then register as a ROS node.
    if import_rospy and len(sys.argv) != len(rospy.myargv()):
        rospy.init_node("jsk_vil_docker_manager")

    # Publish the container's port 8080 on the requested host port.
    port_option = "{}:8080".format(args.port)

    # Resolve the ROS home directory. The fallback is computed only when
    # ROS_HOME is absent, and via expanduser() rather than os.environ["HOME"],
    # so a missing HOME variable no longer raises KeyError.
    _ros_home = os.getenv("ROS_HOME")
    if _ros_home is None:
        _ros_home = os.path.join(os.path.expanduser("~"), ".ros")
    params_dir = os.path.join(_ros_home, "data", "jsk_perception", "vil_params")
    mount_option = "{}:/var/mount/params".format(params_dir)
    docker_image = "{}:latest".format(CONTAINERS[args.model])

    # docker rejects "-i -t" with "the input device is not a TTY" when stdin is
    # not a terminal (cron, CI, redirected input), so request a pseudo-TTY only
    # when one is actually available. From a terminal this is the same "-it".
    tty_flags = ["-i"]
    if sys.stdin is not None and sys.stdin.isatty():
        tty_flags.append("-t")

    cmd = ["docker", "run", "--rm"] + tty_flags + [
        "--gpus", "all",
        "-p", port_option,
        "-v", mount_option,
    ]

    # OFA: pass the task and model scale to the container as environment
    # variables.
    if args.model == "ofa":
        cmd += ["-e", "OFA_TASK={}".format(args.ofa_task),
                "-e", "OFA_MODEL_SCALE={}".format(args.ofa_model_scale)]

    # Execute. The image name must come after all `docker run` options.
    print("Run {} container on port {}".format(docker_image, args.port))
    cmd.append(docker_image)
    # Propagate docker's exit status so callers can detect a failed or crashed
    # container instead of always seeing success.
    sys.exit(subprocess.run(cmd).returncode)
