#! /usr/bin/env python
# Copyright 2015 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

import cStringIO
import functools
import imp
import inspect
import itertools
import operator
import os
import re
import sys

from catapult_base import refactor
from catapult_base.refactor_util import move

from telemetry.internal.util import command_line
from telemetry.internal.util import path


# All files and folders dependent on Telemetry, found using a code search.
# Each entry is a tuple of path components that gets joined onto the Chromium
# src directory below.
_RELATIVE_BASE_DIRS = (
  ('chrome', 'test', 'telemetry'),
  ('content', 'test', 'gpu'),
  ('tools', 'bisect-manual-test.py'),
  ('tools', 'chrome_proxy'),
  ('tools', 'perf'),
  ('tools', 'profile_chrome', 'perf_controller.py'),
  ('tools', 'run-bisect-manual-test.py'),
  ('third_party', 'skia', 'tools', 'skp', 'page_sets'),
  ('third_party', 'trace-viewer'),
)
# The Telemetry directory itself, followed by every entry above joined onto
# the Chromium src directory.
# Note that this is not the same as the directory that imports are relative to.
BASE_DIRS = [path.GetTelemetryDir()]
BASE_DIRS += [os.path.join(path.GetChromiumSrcDir(), *dir_path)
              for dir_path in _RELATIVE_BASE_DIRS]


def SortImportGroups(module_path):
  """Sort each group of imports in the given Python module.

  A group is a collection of adjacent import statements, with no non-import
  lines in between. Groups are sorted according to the Google Python Style
  Guide: "lexicographically, ignoring case, according to each module's full
  package path."

  Args:
    module_path: Path of the Python file to rewrite in place.

  Returns:
    True iff the module was modified, and False otherwise.
  """
  # Pass the "was modified" flag through instead of discarding it, so callers
  # can tell which files were actually rewritten.
  return _TransformImportGroups(module_path, _SortImportGroup)


def _SortImportGroup(import_group):
  """Return the imports of one group ordered by full package path.

  Args:
    import_group: A list of imports, each a 5-item sequence of
        (indent, root, module, alias, suffix). A falsy root (e.g. None) means
        a plain "import module" statement.

  Returns:
    A new list with the same imports sorted, ignoring case, by "root.module"
    (or just "module" when there is no root). The sort is stable, so imports
    with equal keys keep their relative order. The input list is not
    modified.
  """
  def _FullModuleName(import_):
    _, root, module, _, _ = import_
    return (root + '.' + module if root else module).lower()

  # A key function replaces the old cmp-style comparator: it yields the same
  # ordering and does not depend on the Python 2-only cmp machinery.
  return sorted(import_group, key=_FullModuleName)


def _TransformImportGroups(module_path, transformation):
  """Apply a transformation to each group of imports in the given module.

  An import is a list of [indent, root, module, alias, suffix],
  serialized as <indent>from <root> import <module> as <alias><suffix>.
  The "from <root> " and " as <alias>" parts are omitted when root or alias
  is empty or None. A group is a run of adjacent lines that all match the
  import regex below.

  Args:
    module_path: The module to apply transformations on. The file is
        rewritten in place, and only if its contents would change.
    transformation: A function that takes in an import group and returns a
        modified import group. An import group is a list of imports.

  Returns:
    True iff the module was modified, and False otherwise.
  """
  def _WriteImports(output_lines, import_group):
    # Serialize the transformed group; every import is terminated with '\n'.
    for indent, root, module, alias, suffix in transformation(import_group):
      pieces = [indent]
      if root:
        pieces.extend(('from ', root, ' '))
      pieces.extend(('import ', module))
      if alias:
        pieces.extend((' as ', alias))
      pieces.extend((suffix, '\n'))
      output_lines.append(''.join(pieces))

  # Read the file once. The same lines are used both to find the imports and,
  # joined back together, to detect later whether anything changed.
  with open(module_path, 'r') as module_file:
    original_lines = module_file.readlines()
  original_file = ''.join(original_lines)

  # Locate imports using regex, group them, and transform each one.
  # This regex produces a tuple of (indent, root, module, alias, suffix).
  regex = (r'(\s*)(?:from ((?:[a-z0-9_]+\.)*[a-z0-9_]+) )?'
      r'import ((?:[a-z0-9_]+\.)*[A-Za-z0-9_]+)(?: as ([A-Za-z0-9_]+))?(.*)')
  pattern = re.compile(regex)

  updated_lines = []
  import_group = []
  for line in original_lines:
    import_match = pattern.match(line)
    if import_match:
      import_group.append(list(import_match.groups()))
      continue

    # A non-import line ends the current group (if any); flush it first.
    if import_group:
      _WriteImports(updated_lines, import_group)
      import_group = []
    updated_lines.append(line)

  # The file may end while a group is still open.
  if import_group:
    _WriteImports(updated_lines, import_group)

  updated_file = ''.join(updated_lines)
  if original_file == updated_file:
    return False

  with open(module_path, 'w') as module_file:
    module_file.write(updated_file)
  return True


def _CountInterfaces(module):
  """Return how many Class and Function children the given module has.

  Args:
    module: An object exposing FindChildren(node_type); presumably a parsed
        module from catapult_base.refactor -- confirm against the caller.

  Returns:
    The number of refactor.Class children plus the number of
    refactor.Function children.
  """
  class_count = sum(1 for _ in module.FindChildren(refactor.Class))
  function_count = sum(1 for _ in module.FindChildren(refactor.Function))
  return class_count + function_count


def _IsSourceDir(dir_name):
  """Return True unless the directory is hidden or named 'third_party'."""
  # Names starting with a dot are treated as hidden and skipped.
  if dir_name[0] == '.':
    return False
  return dir_name != 'third_party'


def _IsPythonModule(file_name):
  """Return True iff the file name has exactly the '.py' extension."""
  extension = os.path.splitext(file_name)[1]
  return extension == '.py'


class Count(command_line.Command):
  """Print the number of public modules."""
  # With the optional 'type' argument set to 'interfaces', the command prints
  # the total number of classes and functions in those modules instead.

  @classmethod
  def AddCommandLineArgs(cls, parser):
    # 'type' is optional and defaults to counting modules.
    parser.add_argument('type', nargs='?', choices=('interfaces', 'modules'),
                        default='modules')

  def Run(self, args):
    """Print the requested count and return 0 as the exit code."""
    module_paths = path.ListFiles(
        path.GetTelemetryDir(), self._IsPublicApiDir, self._IsPublicApiFile)

    # A single parenthesized argument prints identically whether print is a
    # statement or a function.
    if args.type == 'modules':
      print(len(module_paths))
    elif args.type == 'interfaces':
      # sum() yields 0 when there are no modules, where reduce() without an
      # initial value would raise TypeError on an empty sequence.
      # NOTE(review): assumes refactor.Transform returns one _CountInterfaces
      # result per module path -- confirm against catapult_base.refactor.
      print(sum(refactor.Transform(_CountInterfaces, module_paths)))

    return 0

  @staticmethod
  def _IsPublicApiDir(dir_name):
    """Return True unless the directory is hidden, private or excluded.

    Directories whose name starts with '.' or '_', and the 'internal' and
    'third_party' directories, are not part of the public API.
    """
    return (dir_name[0] != '.' and dir_name[0] != '_' and
        dir_name != 'internal' and dir_name != 'third_party')

  @staticmethod
  def _IsPublicApiFile(file_name):
    """Return True for non-hidden '.py' files that are not unit tests."""
    root, ext = os.path.splitext(file_name)
    return (file_name[0] != '.' and
        not root.endswith('_unittest') and ext == '.py')


def _TelemetryFiles():
  """Return a sorted list of the Python modules found under all BASE_DIRS.

  Each base dir is listed with path.ListFiles, keeping only directories
  accepted by _IsSourceDir and files accepted by _IsPythonModule.
  """
  files_per_base_dir = (
      path.ListFiles(base_dir,
                     should_include_dir=_IsSourceDir,
                     should_include_file=_IsPythonModule)
      for base_dir in BASE_DIRS)
  return sorted(itertools.chain.from_iterable(files_per_base_dir))


class Mv(command_line.Command):
  """Move modules or packages."""

  @classmethod
  def AddCommandLineArgs(cls, parser):
    # One or more sources followed by a single target, like the mv command.
    parser.add_argument('source', nargs='+')
    parser.add_argument('target')

  @classmethod
  def ProcessCommandLineArgs(cls, parser, args):
    """Validate the source and target paths.

    Every problem is reported through parser.error().

    Args:
      parser: The argument parser, used only for error reporting.
      args: Parsed arguments with a 'source' list and a 'target' string.
    """
    # Check source file paths.
    for source_path in args.source:
      # Ensure source path exists.
      if not os.path.exists(source_path):
        parser.error('"%s" not found.' % source_path)

      # Ensure source path is in one of the BASE_DIRS.
      if not any(path.IsSubpath(source_path, base_dir)
                 for base_dir in BASE_DIRS):
        parser.error(
            '"%s" is not in any of the base dirs.' % source_path)

    # Ensure the target, or at least the directory that would contain it,
    # exists. dirname() of a bare file name is '', which never exists, so fall
    # back to the current directory in that case.
    target_parent = os.path.dirname(args.target) or os.curdir
    if not (os.path.exists(args.target) or os.path.exists(target_parent)):
      parser.error('"%s" not found.' % args.target)

    # Ensure target path is in one of the BASE_DIRS.
    if not any(path.IsSubpath(args.target, base_dir)
               for base_dir in BASE_DIRS):
      parser.error('"%s" is not in any of the base dirs.' % args.target)

    # If there are multiple source paths, ensure target is a directory.
    if len(args.source) > 1 and not os.path.isdir(args.target):
      parser.error('Target "%s" is not a directory.' % args.target)

    # Ensure target is not in any of the source paths.
    for source_path in args.source:
      if path.IsSubpath(args.target, source_path):
        parser.error('Cannot move "%s" to a subdirectory of itself, "%s".' %
            (source_path, args.target))

  def Run(self, args):
    """Move the sources to the target, then re-sort imports everywhere."""
    move.Run(args.source, args.target, _TelemetryFiles())
    # List the files again: the move may have changed which modules exist.
    for module_path in _TelemetryFiles():
      SortImportGroups(module_path)
    return 0


class Sort(command_line.Command):
  """Sort imports."""

  @classmethod
  def AddCommandLineArgs(cls, parser):
    # Zero or more targets; with none given, Run() falls back to BASE_DIRS.
    parser.add_argument('target', nargs='*')

  @classmethod
  def ProcessCommandLineArgs(cls, parser, args):
    # Report every target path that does not exist.
    for target_path in args.target:
      if os.path.exists(target_path):
        continue
      parser.error('"%s" not found.' % target_path)

  def Run(self, args):
    # Sort the imports of every Python module found under each search root.
    search_roots = args.target or BASE_DIRS
    for search_root in search_roots:
      module_paths = path.ListFiles(
          search_root, _IsSourceDir, _IsPythonModule)
      for module_path in module_paths:
        SortImportGroups(module_path)
    return 0


class RefactorCommand(command_line.SubcommandCommand):
  # The subcommands this tool provides.
  commands = (
      Count,
      Mv,
      Sort,
  )


if __name__ == '__main__':
  # Use the value returned by the command as the process exit status.
  exit_code = RefactorCommand.main()
  sys.exit(exit_code)
