#!/usr/bin/env python
# Copyright 2015 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

import argparse
import atexit
import json
import logging
import os
import random
import re
import signal
import socket
import subprocess
import sys
import time
import urlparse

# Parent of the directory that contains this script.
# TODO(eseidel): This should be BIN_DIR.
PACKAGES_DIR = os.path.dirname(
    os.path.dirname(os.path.abspath(__file__)))
SKY_ENGINE_DIR = os.path.join(PACKAGES_DIR, 'sky_engine')
# 'apks' is looked up next to the symlink-resolved sky_engine directory.
APK_DIR = os.path.join(
    os.path.realpath(SKY_ENGINE_DIR), os.pardir, 'apks')

# Local TCP ports used by the tool.
SKY_SERVER_PORT = 9888
OBSERVATORY_PORT = 8181

# Android-side names and the adb binary used to reach the device.
ADB_PATH = 'adb'
APK_NAME = 'SkyShell.apk'
ANDROID_PACKAGE = 'org.domokit.sky.shell'
ANDROID_COMPONENT = '%s/%s.SkyActivity' % (ANDROID_PACKAGE, ANDROID_PACKAGE)

# Dart SDK binaries, resolved through $PATH.
# FIXME: Do we need to look in $DART_SDK?
DART_PATH = 'dart'
PUB_PATH = 'pub'

# Where tool state is persisted between invocations, and the only keys
# that state may contain.
PID_FILE_PATH = '/tmp/sky_tool.pids'
PID_FILE_KEYS = frozenset((
    'remote_sky_server_port',
    'sky_server_pid',
    'sky_server_port',
    'sky_server_root',
))


def _port_in_use(port):
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    return sock.connect_ex(('localhost', port)) == 0


def _start_http_server(port, root):
    """Spawn `pub run sky_tools:sky_server <port>` with `root` as its cwd.

    The child process is not waited on; only its pid is returned.
    """
    server_process = subprocess.Popen(
        [PUB_PATH, 'run', 'sky_tools:sky_server', str(port)],
        cwd=root)
    return server_process.pid


# This 'strict dictionary' approach is useful for catching typos.
class Pids(object):
    """Dictionary-like store whose keys must come from a fixed set.

    Every keyed accessor asserts that the key is one of `known_keys`, so a
    misspelled key fails loudly instead of silently creating a new entry.
    The store can be loaded from and saved to a JSON file.
    """

    # write_to() refuses to produce a file unless these keys are present.
    _REQUIRED_KEYS = frozenset(['sky_server_pid', 'sky_server_port'])

    def __init__(self, known_keys, contents=None):
        """Wrap `contents` (a dict, or None for empty), allowing `known_keys`."""
        self._known_keys = known_keys
        self._dict = contents if contents is not None else {}

    def _check_key(self, key):
        """Assert that `key` is one of the keys this store accepts."""
        assert key in self._known_keys, '%s not in known_keys' % key

    def __len__(self):
        return len(self._dict)

    def get(self, key, default=None):
        """Return the value stored for `key`, or `default` if absent."""
        self._check_key(key)
        return self._dict.get(key, default)

    def __getitem__(self, key):
        self._check_key(key)
        return self._dict[key]

    def __setitem__(self, key, value):
        self._check_key(key)
        self._dict[key] = value

    def __delitem__(self, key):
        self._check_key(key)
        del self._dict[key]

    def __iter__(self):
        return iter(self._dict)

    def __contains__(self, key):
        self._check_key(key)
        return key in self._dict

    def clear(self):
        """Drop every stored entry."""
        self._dict = {}

    def pop(self, key, default=None):
        """Remove `key` and return its value, or `default` if absent."""
        self._check_key(key)
        return self._dict.pop(key, default)

    @classmethod
    def read_from(cls, path, known_keys):
        """Build a Pids from the JSON file at `path`.

        A missing file yields an empty store without comment. A file that
        exists but cannot be read, is not valid JSON, or does not hold a
        JSON object yields an empty store and a logged warning.
        """
        contents = {}
        try:
            with open(path, 'r') as pid_file:
                contents = json.load(pid_file)
        except (IOError, OSError, ValueError):
            # ValueError covers malformed JSON. Only complain when the file
            # is actually there; a missing file is the normal first run.
            if os.path.exists(path):
                logging.warning('Failed to read pid file: %s' % path)
        if not isinstance(contents, dict):
            # Valid JSON of the wrong shape (e.g. a list) would break every
            # later lookup, so treat it like an unreadable file.
            logging.warning('Failed to read pid file: %s' % path)
            contents = {}
        return cls(known_keys, contents)

    def write_to(self, path):
        """Save the store as JSON at `path`, if it holds the required keys.

        Nothing is written unless both 'sky_server_pid' and
        'sky_server_port' are present. Failures are logged, not raised.
        """
        # These keys are required to write a valid file.
        if not self._REQUIRED_KEYS.issubset(self._dict):
            return

        try:
            # Serialize before opening so that a value json cannot encode
            # does not leave a truncated file behind.
            serialized = json.dumps(self._dict, indent=2, sort_keys=True)
            with open(path, 'w') as pid_file:
                pid_file.write(serialized)
        except (IOError, OSError, TypeError, ValueError):
            logging.warning('Failed to write pid file: %s' % path)


def _url_for_path(port, root, path):
    relative_path = os.path.relpath(path, root)
    return 'http://localhost:%s/%s' % (port, relative_path)


class StartSky(object):
    """Implements the `start` sub-command.

    Serves a project directory over HTTP from this machine, makes sure the
    SkyShell APK is installed on the device, and asks the device to open
    the project's main.dart through an adb reverse-forwarded port.
    """

    def add_subparser(self, subparsers):
        """Register the `start` command and its flags on `subparsers`."""
        start_parser = subparsers.add_parser('start',
            help='launch %s on the device' % APK_NAME)
        start_parser.add_argument('--install', action='store_true')
        start_parser.add_argument('--poke', action='store_true')
        start_parser.add_argument('--checked', action='store_true')
        start_parser.add_argument('--build-path', type=str)
        start_parser.add_argument('project_or_path', nargs='?', type=str,
            default='.')
        start_parser.set_defaults(func=self.run)

    def _is_package_installed(self, package_name):
        """Return True if `pm path <package_name>` prints anything."""
        pm_path_cmd = [ADB_PATH, 'shell', 'pm', 'path', package_name]
        # Test for emptiness via truthiness: comparing against '' is always
        # unequal when check_output hands back bytes rather than str.
        return bool(subprocess.check_output(pm_path_cmd).strip())

    def _is_valid_script_path(self):
        """Return True if this script's directory sits inside 'packages'."""
        script_path = os.path.dirname(os.path.abspath(__file__))
        # Split on the platform separator rather than a literal '/'.
        script_dirs = script_path.split(os.sep)
        return len(script_dirs) > 1 and script_dirs[-2] == 'packages'

    def _resolve_project(self, project_or_path):
        """Map an absolute target to (server_root, main_dart, missing_msg).

        A directory is treated as a project whose entry point is
        lib/main.dart; anything else is treated as the entry point itself,
        served from its own directory.
        """
        if os.path.isdir(project_or_path):
            main_dart = os.path.join(project_or_path, 'lib', 'main.dart')
            missing_msg = ("Missing lib/main.dart in project: %s"
                           % project_or_path)
            return project_or_path, main_dart, missing_msg

        # FIXME: This assumes the path is at the root of the project!
        # Instead we should walk up looking for a pubspec.yaml
        missing_msg = "%s does not exist." % project_or_path
        return os.path.dirname(project_or_path), project_or_path, missing_msg

    def _install_apk(self, build_path):
        """Install the APK with `adb install -r`.

        The APK is taken from `<build_path>/apks` when `build_path` is
        given, otherwise from APK_DIR. Returns 2 after logging an error if
        the APK cannot be located, and None once the install succeeded.
        """
        if not self._is_valid_script_path():
            logging.error("'%s' must be located in packages/sky. " \
                          "The directory packages/sky_engine must also " \
                          "exist to locate %s." \
                          % (os.path.basename(__file__), APK_NAME))
            return 2
        if build_path is not None:
            apk_path = os.path.join(build_path, 'apks', APK_NAME)
        else:
            apk_path = os.path.join(APK_DIR, APK_NAME)
        if not os.path.exists(apk_path):
            logging.error("'%s' does not exist?" % apk_path)
            return 2

        subprocess.check_call([ADB_PATH, 'install', '-r', apk_path])
        return None

    def _forward_observatory_port(self):
        """Forward OBSERVATORY_PORT from this machine to the device."""
        observatory_port_string = 'tcp:%s' % OBSERVATORY_PORT
        subprocess.check_call([
            ADB_PATH, 'forward', observatory_port_string, observatory_port_string
        ])

    def _ensure_sky_server(self, pids, sky_server_root):
        """Start the HTTP server unless its port is already taken.

        Records the port in `pids` unconditionally; the server pid and root
        are recorded only when a server was actually started here. Returns
        the port.
        """
        sky_server_port = SKY_SERVER_PORT
        pids['sky_server_port'] = sky_server_port
        if _port_in_use(sky_server_port):
            logging.info(('Port %s already in use. '
            ' Not starting server for %s') % (sky_server_port, sky_server_root))
        else:
            sky_server_pid = _start_http_server(sky_server_port, sky_server_root)
            pids['sky_server_pid'] = sky_server_pid
            pids['sky_server_root'] = sky_server_root
        return sky_server_port

    def _launch_activity(self, url, checked):
        """Send a VIEW intent for `url` to the SkyShell activity."""
        cmd = [
            ADB_PATH, 'shell',
            'am', 'start',
            '-a', 'android.intent.action.VIEW',
            '-d', url,
        ]
        if checked:
            cmd += [ '--ez', 'enable-checked-mode', 'true' ]
        cmd += [ ANDROID_COMPONENT ]
        subprocess.check_call(cmd)

    def run(self, args, pids):
        """Execute `start`.

        Unless --poke was given, a full `stop` is performed first. Returns
        2 after logging an error when the project layout or the APK is not
        usable, and None on success. adb failures surface as
        subprocess.CalledProcessError.
        """
        if not args.poke:
            StopSky().run(args, pids)

        sky_server_root, main_dart, missing_msg = self._resolve_project(
            os.path.abspath(args.project_or_path))

        if not os.path.isfile(main_dart):
            logging.error(missing_msg)
            return 2

        package_root = os.path.join(sky_server_root, 'packages')
        if not os.path.isdir(package_root):
            logging.error("%s is not a valid packages path." % package_root)
            return 2

        if not self._is_package_installed(ANDROID_PACKAGE):
            logging.info('%s is not on the device. Installing now...' % APK_NAME)
            args.install = True

        if args.install:
            install_error = self._install_apk(args.build_path)
            if install_error is not None:
                return install_error

        # Set up port forwarding for observatory
        self._forward_observatory_port()

        sky_server_port = self._ensure_sky_server(pids, sky_server_root)

        # Let the device reach the local server on the same port number.
        port_string = 'tcp:%s' % sky_server_port
        subprocess.check_call([
            ADB_PATH, 'reverse', port_string, port_string
        ])
        pids['remote_sky_server_port'] = sky_server_port

        # The load happens on the remote device, use the remote port.
        url = _url_for_path(pids['remote_sky_server_port'], sky_server_root,
            main_dart)
        if args.poke:
            # A random query string makes each poke URL distinct
            # (presumably to defeat caching -- confirm).
            url += '?rand=%s' % random.random()

        self._launch_activity(url, args.checked)


class StopSky(object):
    """Implements the `stop` sub-command."""

    def add_subparser(self, subparsers):
        """Register the `stop` command on `subparsers`."""
        help_text = 'kill all running SkyShell.apk processes'
        parser = subparsers.add_parser('stop', help=help_text)
        parser.set_defaults(func=self.run)

    def _kill_if_exists(self, pids, key, name):
        """Remove `key` from `pids` and SIGTERM the pid it held, if any.

        `name` is only used in debug log messages. A pid that can no
        longer be signalled is logged and otherwise ignored.
        """
        pid = pids.pop(key, None)
        if pid:
            logging.debug('Killing %s (%d).' % (name, pid))
            try:
                os.kill(pid, signal.SIGTERM)
            except OSError:
                logging.debug('%s (%d) already gone.' % (name, pid))
        else:
            logging.debug('No pid for %s, nothing to do.' % name)

    def run(self, args, pids):
        """Stop the local server, undo the reverse forward, stop the app.

        Both adb invocations ignore their exit status. `pids` is emptied
        at the end.
        """
        self._kill_if_exists(pids, 'sky_server_pid', 'sky_server')

        if 'remote_sky_server_port' in pids:
            remote_port = pids['remote_sky_server_port']
            remove_reverse_cmd = [
                ADB_PATH, 'reverse', '--remove', 'tcp:%s' % remote_port]
            subprocess.call(remove_reverse_cmd)

        force_stop_cmd = [
            ADB_PATH, 'shell', 'am', 'force-stop', ANDROID_PACKAGE]
        subprocess.call(force_stop_cmd)

        pids.clear()


class StartTracing(object):
    """Implements the `start_tracing` sub-command."""

    def add_subparser(self, subparsers):
        """Register the `start_tracing` command on `subparsers`."""
        parser = subparsers.add_parser(
            'start_tracing',
            help='start tracing a running sky instance')
        parser.set_defaults(func=self.run)

    def run(self, args, pids):
        """Broadcast the TRACING_START intent on the device via adb."""
        broadcast_cmd = [
            ADB_PATH, 'shell',
            'am', 'broadcast',
            '-a', 'org.domokit.sky.shell.TRACING_START',
        ]
        subprocess.check_output(broadcast_cmd)


# Patterns matched against `adb logcat -d` output while stopping a trace:
# the first signals that the trace is finished, the second captures the
# path the trace is being saved to.
TRACE_COMPLETE_REGEXP = re.compile(r'Trace complete')
TRACE_FILE_REGEXP = re.compile(
    r'Saving trace to (?P<path>\S+)')


class StopTracing(object):
    """Implements the `stop_tracing` sub-command."""

    def add_subparser(self, subparsers):
        """Register the `stop_tracing` command on `subparsers`."""
        stop_tracing_parser = subparsers.add_parser('stop_tracing',
            help=('stop tracing a running sky instance'))
        stop_tracing_parser.set_defaults(func=self.run)

    def run(self, args, pids):
        """Stop the trace, wait for it to finish, and pull it off the device.

        Clears logcat, broadcasts TRACING_STOP, then polls logcat every
        0.2s until the completion marker shows up (there is no timeout).
        The trace file named in the log is pulled into the current
        directory and deleted from the device. Returns 2 after logging an
        error if the log never named a trace file, None otherwise.
        """
        subprocess.check_output([ADB_PATH, 'logcat', '-c'])
        subprocess.check_output([ADB_PATH, 'shell',
            'am', 'broadcast',
            '-a', 'org.domokit.sky.shell.TRACING_STOP'])
        device_path = None
        is_complete = False
        while not is_complete:
            time.sleep(0.2)
            log = subprocess.check_output([ADB_PATH, 'logcat', '-d'])
            if not isinstance(log, str):
                # check_output may return bytes; the patterns are text.
                log = log.decode('utf-8', 'replace')
            if device_path is None:
                result = TRACE_FILE_REGEXP.search(log)
                if result:
                    device_path = result.group('path')
            is_complete = TRACE_COMPLETE_REGEXP.search(log) is not None

        if not device_path:
            # The completion marker was seen without a "Saving trace to"
            # line, so there is nothing to download.
            logging.error('Trace completed but no trace file was reported.')
            return 2

        logging.info('Downloading trace %s ...' % os.path.basename(device_path))
        subprocess.check_output([ADB_PATH, 'pull', device_path])
        subprocess.check_output([ADB_PATH, 'shell', 'rm', device_path])


class SkyShellRunner(object):
    """Command-line front end: environment checks plus command dispatch."""

    # Oldest adb release accepted by _is_valid_adb_version.
    _MIN_ADB_VERSION = (1, 0, 32)

    def _update_paths(self):
        """Point ADB_PATH into $ANDROID_HOME when that variable is set."""
        global ADB_PATH
        if 'ANDROID_HOME' in os.environ:
            android_home_dir = os.environ['ANDROID_HOME']
            ADB_PATH = os.path.join(android_home_dir, 'sdk/platform-tools/adb')

    def _is_valid_adb_version(self, adb_version):
        """Return True if `adb_version` names adb 1.0.32 or newer.

        `adb_version` is the output of `adb version`. If no x.y.z version
        can be found in it, a warning is logged and the check passes.
        """
        if not isinstance(adb_version, str):
            # check_output may return bytes; the pattern below is text.
            adb_version = adb_version.decode('utf-8', 'replace')
        # Sample output: "Android Debug Bridge version 1.0.31"
        version_fields = re.search(r'(\d+)\.(\d+)\.(\d+)', adb_version)
        if not version_fields:
            logging.warning('Unrecognized adb version string. Skipping version check.')
            return True
        # Tuples compare element by element, i.e. major, then minor, then
        # patch.
        version = tuple(int(field) for field in version_fields.groups())
        return version >= self._MIN_ADB_VERSION

    def _check_for_adb(self):
        """Return True if adb can be run and is recent enough.

        Logs an error describing the problem before returning False.
        """
        try:
            adb_version = subprocess.check_output([ADB_PATH, 'version'])
        except OSError:
            logging.error("'adb' (from the Android SDK) not in $PATH, can't continue.")
            return False
        except subprocess.CalledProcessError as e:
            logging.error(e)
            return False

        if self._is_valid_adb_version(adb_version):
            return True

        # Report the resolved location of the outdated binary when `which`
        # can provide it; otherwise fall back to the configured path.
        try:
            adb_path = subprocess.check_output(['which', ADB_PATH]).rstrip()
            if not isinstance(adb_path, str):
                adb_path = adb_path.decode('utf-8', 'replace')
        except (OSError, subprocess.CalledProcessError):
            adb_path = ADB_PATH
        logging.error("'%s' is too old. Need 1.0.32 or later. " \
            "Try setting ANDROID_HOME." % adb_path)
        return False

    def _check_for_lollipop_or_later(self):
        """Return True if the device reports SDK version 22 or higher."""
        try:
            # If the server is automatically restarted, then we get irrelevant
            # output lines like this, which we want to ignore:
            #   adb server is out of date.  killing..
            #   * daemon started successfully *

            subprocess.call([ADB_PATH, 'start-server'])
            sdk_version = subprocess.check_output(
                    [ADB_PATH, 'shell', 'getprop', 'ro.build.version.sdk']).rstrip()
        except subprocess.CalledProcessError:
            # adb printed the error, so we print nothing.
            return False

        if not isinstance(sdk_version, str):
            sdk_version = sdk_version.decode('utf-8', 'replace')

        # Sample output: "22"
        if not sdk_version.isdigit():
            logging.error("Unexpected response from getprop: '%s'." % sdk_version)
            return False

        if int(sdk_version) < 22:
            logging.error("Version '%s' of the Android SDK is too old. " \
                          "Need Lollipop (22) or later. " % sdk_version)
            return False

        return True

    def _check_for_dart(self):
        """Return True if `dart --version` can be run successfully."""
        try:
            subprocess.check_output([DART_PATH, '--version'], stderr=subprocess.STDOUT)
        except OSError:
            logging.error("'dart' (from the Dart SDK) not in $PATH, can't continue.")
            return False
        except subprocess.CalledProcessError as e:
            logging.error(e)
            return False
        return True

    def main(self):
        """Check the environment, parse argv and run the chosen command.

        Always leaves through sys.exit: 2 when a prerequisite is missing
        or an adb command fails, otherwise whatever the command returned.
        The pid file is loaded before the command runs and written back
        at interpreter exit.
        """
        logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.INFO)

        self._update_paths()
        if not self._check_for_adb() or not self._check_for_lollipop_or_later():
            sys.exit(2)
        if not self._check_for_dart():
            sys.exit(2)

        parser = argparse.ArgumentParser(description='Sky Demo Runner')
        subparsers = parser.add_subparsers(help='sub-command help')

        for command in [StartSky(), StopSky(), StartTracing(), StopTracing()]:
            command.add_subparser(subparsers)

        args = parser.parse_args()
        if not hasattr(args, 'func'):
            # Some argparse versions accept a missing sub-command; reject
            # it here (usage message, exit status 2) instead of failing on
            # the attribute lookup below.
            parser.error('too few arguments')

        pids = Pids.read_from(PID_FILE_PATH, PID_FILE_KEYS)
        atexit.register(pids.write_to, PID_FILE_PATH)
        exit_code = 0
        try:
            exit_code = args.func(args, pids)
        except subprocess.CalledProcessError as e:
            # Don't print a stack trace if the adb command fails.
            logging.error(e)
            exit_code = 2
        sys.exit(exit_code)


# Script entry point: run the tool only when this file is executed directly,
# not when it is imported.
if __name__ == '__main__':
    SkyShellRunner().main()