#!/usr/bin/env python
# Copyright 2015 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

import argparse
import atexit
import json
import logging
import os
import random
import re
import signal
import socket
import subprocess
import sys
import time
import urlparse

# TODO(eseidel): This should be BIN_DIR.
# Directory two levels above this script.
PACKAGES_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
SKY_ENGINE_DIR = os.path.join(PACKAGES_DIR, 'sky_engine')
# 'apks' is looked up next to the *resolved* sky_engine directory, so a
# symlinked sky_engine is followed before stepping up one level.
APK_DIR = os.path.join(os.path.realpath(SKY_ENGINE_DIR), os.pardir, 'apks')

# Host-side port the sky server listens on.
SKY_SERVER_PORT = 9888
# Port forwarded from the host to the device for the observatory.
OBSERVATORY_PORT = 8181
# Default adb executable; SkyShellRunner._update_paths may rebind it.
ADB_PATH = 'adb'
APK_NAME = 'SkyShell.apk'
ANDROID_PACKAGE = "org.domokit.sky.shell"
ANDROID_COMPONENT = '%s/%s.SkyActivity' % (ANDROID_PACKAGE, ANDROID_PACKAGE)

# FIXME: Do we need to look in $DART_SDK?
DART_PATH = 'dart'
PUB_PATH = 'pub'

# JSON file used to remember server/port state between invocations.
PID_FILE_PATH = "/tmp/sky_tool.pids"
# The only keys that may be stored in the pid file (enforced by Pids).
PID_FILE_KEYS = frozenset([
    'remote_sky_server_port',
    'sky_server_pid',
    'sky_server_port',
    'sky_server_root',
])


def _port_in_use(port):
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    return sock.connect_ex(('localhost', port)) == 0


def _start_http_server(port, root):
    """Start the sky server with 'pub run' and return the child's pid.

    Args:
        port: Port number passed to the server on its command line.
        root: Directory used as the child process's working directory.

    Returns:
        The process id of the spawned 'pub' process. The process is not
        waited on.
    """
    server_command = [PUB_PATH, 'run', 'sky_tools:sky_server', str(port)]
    return subprocess.Popen(server_command, cwd=root).pid


# This 'strict dictionary' approach is useful for catching typos.
class Pids(object):
    """Dictionary-like store restricted to a fixed set of known keys.

    Any keyed access with a key outside known_keys trips an assertion, so a
    misspelled key fails loudly instead of silently creating a new entry.
    Contents can be loaded from and saved to a JSON file.
    """

    # Keys that must be present before write_to() will produce a file.
    _REQUIRED_KEYS = frozenset(['sky_server_pid', 'sky_server_port'])

    def __init__(self, known_keys, contents=None):
        """Create a store.

        Args:
            known_keys: Container of the keys that may be used.
            contents: Optional dict of initial values; it is used directly,
                not copied.
        """
        self._known_keys = known_keys
        self._dict = contents if contents is not None else {}

    def _check_key(self, key):
        """Assert that key is one of the known keys."""
        assert key in self._known_keys, '%s not in known_keys' % key

    def __len__(self):
        return len(self._dict)

    def get(self, key, default=None):
        """Return the value for key, or default when it is not set."""
        self._check_key(key)
        return self._dict.get(key, default)

    def __getitem__(self, key):
        self._check_key(key)
        return self._dict[key]

    def __setitem__(self, key, value):
        self._check_key(key)
        self._dict[key] = value

    def __delitem__(self, key):
        self._check_key(key)
        del self._dict[key]

    def __iter__(self):
        return iter(self._dict)

    def __contains__(self, key):
        self._check_key(key)
        return key in self._dict

    def clear(self):
        """Drop every stored value."""
        self._dict = {}

    def pop(self, key, default=None):
        """Remove key and return its value, or default when it is not set."""
        self._check_key(key)
        return self._dict.pop(key, default)

    @classmethod
    def read_from(cls, path, known_keys):
        """Build a Pids from the JSON file at path.

        A missing file yields an empty store without comment. An unreadable
        file, invalid JSON, or JSON that is not an object yields an empty
        store and a logged warning.
        """
        contents = {}
        try:
            with open(path, 'r') as pid_file:
                contents = json.load(pid_file)
        except (IOError, OSError, ValueError):
            # ValueError covers malformed JSON. Only warn when the file is
            # actually there; a missing file is the normal first-run case.
            if os.path.exists(path):
                logging.warning('Failed to read pid file: %s' % path)
        if not isinstance(contents, dict):
            # Valid JSON but not an object (e.g. a list): unusable as a store.
            logging.warning('Failed to read pid file: %s' % path)
            contents = {}
        return cls(known_keys, contents)

    def write_to(self, path):
        """Write the contents to path as JSON.

        Nothing is written unless both 'sky_server_pid' and
        'sky_server_port' are set. Write failures are logged, not raised.
        """
        # These keys are required to write a valid file.
        if not self._REQUIRED_KEYS.issubset(self._dict):
            return

        try:
            with open(path, 'w') as pid_file:
                json.dump(self._dict, pid_file, indent=2, sort_keys=True)
        except (IOError, OSError, TypeError, ValueError):
            # TypeError/ValueError: a stored value json cannot serialise.
            logging.warning('Failed to write pid file: %s' % path)


def _url_for_path(port, root, path):
    relative_path = os.path.relpath(path, root)
121
    return 'http://localhost:%s/%s' % (port, relative_path)
122 123 124 125 126 127 128


class StartSky(object):
    """Implements the 'start' sub-command.

    Serves the project over HTTP from the host, exposes that server to the
    device with 'adb reverse', and starts the Sky activity on the device
    with a VIEW intent for the project's main.dart URL.
    """

    def add_subparser(self, subparsers):
        """Register the 'start' sub-command and its options on subparsers."""
        start_parser = subparsers.add_parser('start',
            help='launch %s on the device' % APK_NAME)
        start_parser.add_argument('--install', action='store_true')
        start_parser.add_argument('--poke', action='store_true')
        start_parser.add_argument('--checked', action='store_true')
        start_parser.add_argument('--build-path', type=str)
        start_parser.add_argument('project_or_path', nargs='?', type=str,
            default='.')
        start_parser.set_defaults(func=self.run)

    def _is_package_installed(self, package_name):
        """Return True if 'pm path' on the device prints anything for the package."""
        pm_path_cmd = [ADB_PATH, 'shell', 'pm', 'path', package_name]
        return bool(subprocess.check_output(pm_path_cmd).strip())

    def _is_valid_script_path(self):
        """Return True if this script's directory is a direct child of a
        directory named 'packages'."""
        script_dir = os.path.dirname(os.path.abspath(__file__))
        return os.path.basename(os.path.dirname(script_dir)) == 'packages'

    def _resolve_project(self, project_or_path):
        """Return (sky_server_root, main_dart, missing_msg) for an absolute path.

        A directory is treated as a project whose entry point is
        lib/main.dart; anything else is treated as the entry point itself,
        served from its containing directory. missing_msg is the error to
        report if main_dart turns out not to exist.
        """
        if os.path.isdir(project_or_path):
            main_dart = os.path.join(project_or_path, 'lib', 'main.dart')
            missing_msg = "Missing lib/main.dart in project: %s" % project_or_path
            return project_or_path, main_dart, missing_msg

        # FIXME: This assumes the path is at the root of the project!
        # Instead we should walk up looking for a pubspec.yaml
        missing_msg = "%s does not exist." % project_or_path
        return os.path.dirname(project_or_path), project_or_path, missing_msg

    def _install_apk(self, build_path):
        """Install the APK on the device.

        Args:
            build_path: Value of --build-path, or None to use APK_DIR.

        Returns:
            0 on success, 2 if the APK could not be located.

        Raises:
            subprocess.CalledProcessError: if 'adb install' fails.
        """
        if not self._is_valid_script_path():
            logging.error("'%s' must be located in packages/sky. " \
                          "The directory packages/sky_engine must also " \
                          "exist to locate %s." \
                          % (os.path.basename(__file__), APK_NAME))
            return 2

        if build_path is not None:
            apk_path = os.path.join(build_path, 'apks', APK_NAME)
        else:
            apk_path = os.path.join(APK_DIR, APK_NAME)
        if not os.path.exists(apk_path):
            logging.error("'%s' does not exist?" % apk_path)
            return 2

        subprocess.check_call([ADB_PATH, 'install', '-r', apk_path])
        return 0

    def run(self, args, pids):
        """Start (or, with --poke, reload) the app on the device.

        Args:
            args: Parsed arguments for the 'start' sub-command.
            pids: Pids store; server port/pid/root entries are recorded in it.

        Returns:
            2 when the project or APK is invalid; None on success.

        Raises:
            subprocess.CalledProcessError: if a checked adb command fails.
        """
        # --poke reloads into the running setup; otherwise start clean.
        if not args.poke:
            StopSky().run(args, pids)

        project_or_path = os.path.abspath(args.project_or_path)
        sky_server_root, main_dart, missing_msg = self._resolve_project(
            project_or_path)

        if not os.path.isfile(main_dart):
            logging.error(missing_msg)
            return 2

        package_root = os.path.join(sky_server_root, 'packages')
        if not os.path.isdir(package_root):
            logging.error("%s is not a valid packages path." % package_root)
            return 2

        if not self._is_package_installed(ANDROID_PACKAGE):
            logging.info('%s is not on the device. Installing now...' % APK_NAME)
            args.install = True

        if args.install:
            install_status = self._install_apk(args.build_path)
            if install_status != 0:
                return install_status

        # Set up port forwarding for observatory
        observatory_port_string = 'tcp:%s' % OBSERVATORY_PORT
        subprocess.check_call([
            ADB_PATH, 'forward', observatory_port_string, observatory_port_string
        ])

        sky_server_port = SKY_SERVER_PORT
        pids['sky_server_port'] = sky_server_port
        if _port_in_use(sky_server_port):
            logging.info(('Port %s already in use. '
            ' Not starting server for %s') % (sky_server_port, sky_server_root))
        else:
            sky_server_pid = _start_http_server(sky_server_port, sky_server_root)
            pids['sky_server_pid'] = sky_server_pid
            pids['sky_server_root'] = sky_server_root

        # Make the host's server port reachable from the device.
        port_string = 'tcp:%s' % sky_server_port
        subprocess.check_call([
            ADB_PATH, 'reverse', port_string, port_string
        ])
        pids['remote_sky_server_port'] = sky_server_port

        # The load happens on the remote device, use the remote port.
        url = _url_for_path(pids['remote_sky_server_port'], sky_server_root,
            main_dart)
        if args.poke:
            # A random query string makes the URL differ on every poke.
            url += '?rand=%s' % random.random()

        cmd = [
            ADB_PATH, 'shell',
            'am', 'start',
            '-a', 'android.intent.action.VIEW',
            '-d', url,
        ]

        if args.checked:
            cmd += [ '--ez', 'enable-checked-mode', 'true' ]

        cmd += [ ANDROID_COMPONENT ]

        subprocess.check_output(cmd)


class StopSky(object):
    """Implements the 'stop' sub-command: tear down the server, the reverse
    port mapping and the app on the device."""

    def add_subparser(self, subparsers):
        """Register the 'stop' sub-command on subparsers."""
        stop_parser = subparsers.add_parser('stop',
            help=('kill all running SkyShell.apk processes'))
        stop_parser.set_defaults(func=self.run)

    def _run(self, args):
        """Run a command best-effort: output discarded, exit status ignored.

        A command that cannot be launched at all (e.g. the executable is not
        installed) is logged as a warning rather than raised, so the
        remaining cleanup steps still run.
        """
        try:
            with open(os.devnull, 'w') as dev_null:
                subprocess.call(args, stdout=dev_null, stderr=dev_null)
        except OSError as e:
            logging.warning('Could not run %s: %s' % (args, e))

    def run(self, args, pids):
        """Stop everything 'start' set up, then clear pids.

        Args:
            args: Parsed command-line arguments (unused).
            pids: Pids store; read for the remote port, then cleared.
        """
        # Kill whatever process is holding the server port.
        self._run(['fuser', '-k', '%s/tcp' % SKY_SERVER_PORT])

        if 'remote_sky_server_port' in pids:
            port_string = 'tcp:%s' % pids['remote_sky_server_port']
            self._run([ADB_PATH, 'reverse', '--remove', port_string])

        self._run([ADB_PATH, 'shell', 'am', 'force-stop', ANDROID_PACKAGE])

        pids.clear()


class StartTracing(object):
    """Implements the 'start_tracing' sub-command."""

    def add_subparser(self, subparsers):
        """Register the 'start_tracing' sub-command on subparsers."""
        start_tracing_parser = subparsers.add_parser('start_tracing',
            help=('start tracing a running sky instance'))
        start_tracing_parser.set_defaults(func=self.run)

    def run(self, args, pids):
        """Broadcast the TRACING_START intent on the device via adb.

        Args:
            args: Parsed command-line arguments (unused).
            pids: Pids store (unused).

        Raises:
            subprocess.CalledProcessError: if the adb command fails.
        """
        subprocess.check_output([ADB_PATH, 'shell',
            'am', 'broadcast',
            '-a', 'org.domokit.sky.shell.TRACING_START'])


# Patterns searched for in the device log while waiting for a trace to finish:
# the line naming the trace file, and the line signalling completion.
TRACE_FILE_REGEXP = re.compile(r'Saving trace to (?P<path>\S+)')
TRACE_COMPLETE_REGEXP = re.compile(r'Trace complete')


class StopTracing(object):
    """Implements the 'stop_tracing' sub-command."""

    def add_subparser(self, subparsers):
        """Register the 'stop_tracing' sub-command on subparsers."""
        stop_tracing_parser = subparsers.add_parser('stop_tracing',
            help=('stop tracing a running sky instance'))
        stop_tracing_parser.set_defaults(func=self.run)

    def run(self, args, pids):
        """Stop tracing and pull the resulting trace file from the device.

        Clears the device log, broadcasts TRACING_STOP, then polls the log
        until the completion line appears. If the log also named a trace
        file, that file is pulled and then removed from the device.

        Args:
            args: Parsed command-line arguments (unused).
            pids: Pids store (unused).

        Raises:
            subprocess.CalledProcessError: if an adb command fails.
        """
        # Clear the log first so only lines from this stop request match.
        subprocess.check_output([ADB_PATH, 'logcat', '-c'])
        subprocess.check_output([ADB_PATH, 'shell',
            'am', 'broadcast',
            '-a', 'org.domokit.sky.shell.TRACING_STOP'])
        device_path = None
        is_complete = False
        # NOTE: there is no timeout; this polls until the completion line
        # shows up in the log.
        while not is_complete:
            time.sleep(0.2)
            log = subprocess.check_output([ADB_PATH, 'logcat', '-d'])
            if device_path is None:
                result = TRACE_FILE_REGEXP.search(log)
                if result:
                    device_path = result.group('path')
            is_complete = TRACE_COMPLETE_REGEXP.search(log) is not None

        if not device_path:
            # Completion was logged but no file name was; there is nothing to
            # download, and basename(None) would raise.
            logging.warning('Trace completed but no trace file was reported.')
            return

        logging.info('Downloading trace %s ...' % os.path.basename(device_path))
        subprocess.check_output([ADB_PATH, 'pull', device_path])
        subprocess.check_output([ADB_PATH, 'shell', 'rm', device_path])


class SkyShellRunner(object):
    """Command-line entry point: checks the environment, parses arguments
    and dispatches to the selected sub-command."""

    def _update_paths(self):
        """Point ADB_PATH at the adb under $ANDROID_HOME, when that is set.

        Both '<ANDROID_HOME>/sdk/platform-tools/adb' and
        '<ANDROID_HOME>/platform-tools/adb' are tried, in that order; the
        first that exists wins. If neither exists the first is used anyway,
        so the later adb check reports the problem.
        """
        global ADB_PATH
        if 'ANDROID_HOME' not in os.environ:
            return
        android_home_dir = os.environ['ANDROID_HOME']
        candidates = [
            os.path.join(android_home_dir, 'sdk/platform-tools/adb'),
            os.path.join(android_home_dir, 'platform-tools/adb'),
        ]
        for candidate in candidates:
            if os.path.isfile(candidate):
                ADB_PATH = candidate
                return
        ADB_PATH = candidates[0]

    def _is_valid_adb_version(self, adb_version):
        """Return True if the adb version string is 1.0.32 or later.

        A string with no recognisable x.y.z version logs a warning and is
        accepted.
        """
        # Sample output: "Android Debug Bridge version 1.0.31"
        version_fields = re.search(r'(\d+)\.(\d+)\.(\d+)', adb_version)
        if not version_fields:
            logging.warning('Unrecognized adb version string. Skipping version check.')
            return True
        version = tuple(int(field) for field in version_fields.groups())
        return version >= (1, 0, 32)

    def _check_for_adb(self):
        """Return True if adb can be run and is recent enough; log why not."""
        try:
            adb_version = subprocess.check_output([ADB_PATH, 'version'])
        except OSError:
            logging.error("'adb' (from the Android SDK) not in $PATH, can't continue.")
            return False

        if self._is_valid_adb_version(adb_version):
            return True

        # Resolve the full path only to make the error message more useful;
        # fall back to the configured name if 'which' itself fails.
        try:
            adb_path = subprocess.check_output(['which', ADB_PATH]).rstrip()
        except (OSError, subprocess.CalledProcessError):
            adb_path = ADB_PATH
        logging.error("'%s' is too old. Need 1.0.32 or later. " \
            "Try setting ANDROID_HOME." % adb_path)
        return False

    def _check_for_lollipop_or_later(self):
        """Return True if the device reports an SDK version of 22 or higher."""
        try:
            # If the server is automatically restarted, then we get irrelevant
            # output lines like this, which we want to ignore:
            #   adb server is out of date.  killing..
            #   * daemon started successfully *

            subprocess.call([ADB_PATH, 'start-server'])
            sdk_version = subprocess.check_output(
                    [ADB_PATH, 'shell', 'getprop', 'ro.build.version.sdk']).rstrip()
            # Sample output: "22"
            if not sdk_version.isdigit():
                logging.error("Unexpected response from getprop: '%s'." % sdk_version)
                return False

            if int(sdk_version) < 22:
                logging.error("Version '%s' of the Android SDK is too old. " \
                              "Need Lollipop (22) or later. " % sdk_version)
                return False

        except subprocess.CalledProcessError:
            # adb printed the error, so we print nothing.
            return False
        return True

    def _check_for_dart(self):
        """Return True if the dart executable can be launched."""
        try:
            subprocess.check_output([DART_PATH, '--version'], stderr=subprocess.STDOUT)
        except OSError:
            logging.error("'dart' (from the Dart SDK) not in $PATH, can't continue.")
            return False
        return True

    def main(self):
        """Run the tool and exit the process with the sub-command's status.

        Exits with status 2 if the adb, device or dart checks fail, or if a
        checked subprocess call made by the sub-command fails.
        """
        logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.INFO)

        self._update_paths()
        if not self._check_for_adb() or not self._check_for_lollipop_or_later():
            sys.exit(2)
        if not self._check_for_dart():
            sys.exit(2)

        parser = argparse.ArgumentParser(description='Sky Demo Runner')
        subparsers = parser.add_subparsers(help='sub-command help')

        for command in [StartSky(), StopSky(), StartTracing(), StopTracing()]:
            command.add_subparser(subparsers)

        args = parser.parse_args()
        pids = Pids.read_from(PID_FILE_PATH, PID_FILE_KEYS)
        # Persist whatever the sub-command recorded, however we exit.
        atexit.register(pids.write_to, PID_FILE_PATH)
        exit_code = 0
        try:
            exit_code = args.func(args, pids)
        except subprocess.CalledProcessError as e:
            # Don't print a stack trace if the adb command fails.
            logging.error(e)
            exit_code = 2
        sys.exit(exit_code)


# Run the tool only when this file is executed as a script, not on import.
if __name__ == '__main__':
    SkyShellRunner().main()