#! /usr/bin/env python
# Copyright 2015 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

import cStringIO
import imp
import inspect
import os
import re
import sys

from telemetry.core import command_line
from telemetry.util import path


# All directories and standalone scripts that depend on Telemetry, found using
# a code search.
# The Telemetry directory itself comes first; every other entry is given as
# path components relative to the Chromium src directory. Entries may be
# directories or individual .py scripts.
BASE_DIRS = (path.GetTelemetryDir(),) + tuple(
    os.path.join(path.GetChromiumSrcDir(), *relative_parts)
    for relative_parts in (
        ('chrome', 'test', 'telemetry'),
        ('content', 'test', 'gpu'),
        ('tools', 'bisect-manual-test.py'),
        ('tools', 'chrome_proxy'),
        ('tools', 'perf'),
        ('tools', 'profile_chrome', 'perf_controller.py'),
        ('tools', 'run-bisect-manual-test.py'),
        ('third_party', 'skia', 'tools', 'skp', 'page_sets'),
        ('third_party', 'trace-viewer'),
    ))


def SortImportGroups(module_path):
  """Rewrite a Python module so that every import group is sorted.

  Adjacent import statements with no non-import line between them form one
  group. Each group is ordered independently, following the Google Python
  Style Guide rule: "lexicographically, ignoring case, according to each
  module's full package path."

  Args:
    module_path: Path of the Python file to rewrite in place.
  """
  _TransformImportGroups(module_path, transformation=_SortImportGroup)


def _SortImportGroup(import_group):
  """Return a sorted copy of an import group.

  Args:
    import_group: A list of imports, each a 5-item sequence of
        (indent, root, module, alias, suffix). root is None or empty for a
        plain "import module" statement.

  Returns:
    A new list ordered lexicographically, ignoring case, by each import's
    full module path ("root.module", or just "module" when there is no root).
    Imports whose paths compare equal keep their relative order. The input
    list is left untouched.
  """
  def _FullModulePath(import_line):
    _, root, module, _, _ = import_line
    return (root + '.' + module if root else module).lower()

  # A key function gives the same stable ordering as the old cmp-based
  # comparator, builds each lowercased path only once, and also works on
  # interpreters where sorted() no longer accepts cmp=.
  return sorted(import_group, key=_FullModulePath)


def _TransformImportGroups(module_path, transformation):
  """Apply a transformation to each group of imports in the given module.

  An import is a list of [indent, root, module, alias, suffix],
  serialized as <indent>from <root> import <module> as <alias><suffix>.
  The "from <root> " and " as <alias>" parts are omitted when root or alias
  is empty or None. A group is a run of consecutive lines that all match the
  import regex below; any other line ends the current group.

  The file is read once, rebuilt in memory, and written back only if the
  rebuilt text differs from what was read.

  Args:
    module_path: The module to apply transformations on.
    transformation: A function that takes in an import group and returns a
        modified import group. An import group is a list of imports.

  Returns:
    True iff the module was modified, and False otherwise.
  """
  def _SerializeImports(import_group):
    # Turn the transformed group back into newline-terminated source lines.
    serialized = []
    for indent, root, module, alias, suffix in transformation(import_group):
      pieces = [indent]
      if root:
        pieces.extend(('from ', root, ' '))
      pieces.extend(('import ', module))
      if alias:
        pieces.extend((' as ', alias))
      pieces.extend((suffix, '\n'))
      serialized.append(''.join(pieces))
    return serialized

  # Read the file once; the same lines are used both for matching and for the
  # final comparison, so the two can never disagree.
  with open(module_path, 'r') as module_file:
    original_lines = module_file.readlines()
  original_file = ''.join(original_lines)

  # Locate imports using regex, group them, and transform each one.
  # This regex produces a tuple of (indent, root, module, alias, suffix).
  regex = (r'(\s*)(?:from ((?:[a-z0-9_]+\.)*[a-z0-9_]+) )?'
      r'import ((?:[a-z0-9_]+\.)*[A-Za-z0-9_]+)(?: as ([A-Za-z0-9_]+))?(.*)')
  pattern = re.compile(regex)

  updated_lines = []
  import_group = []
  for line in original_lines:
    import_match = pattern.match(line)
    if import_match:
      # Lists rather than tuples, so a transformation may edit fields in place.
      import_group.append(list(import_match.groups()))
      continue

    # A non-import line closes the pending group, if any.
    if import_group:
      updated_lines.extend(_SerializeImports(import_group))
      import_group = []
    updated_lines.append(line)

  if import_group:
    trailing_lines = _SerializeImports(import_group)
    # The serializer always appends '\n'. If the file ended on an import line
    # with no final newline, drop the invented one so an otherwise untouched
    # file is not reported (and rewritten) as modified.
    if trailing_lines and not original_file.endswith('\n'):
      trailing_lines[-1] = trailing_lines[-1][:-1]
    updated_lines.extend(trailing_lines)

  updated_file = ''.join(updated_lines)
  if original_file == updated_file:
    return False

  with open(module_path, 'w') as module_file:
    module_file.write(updated_file)
  return True


def _ListFiles(base_directory, should_include_dir, should_include_file):
  """Collect the paths of matching files beneath a directory.

  Args:
    base_directory: Directory at which the walk starts.
    should_include_dir: Predicate called with a bare directory name; the
        walk only descends into directories for which it returns true.
    should_include_file: Predicate called with a bare file name; only files
        for which it returns true are collected.

  Returns:
    A sorted list of paths, each formed by joining the walked directory with
    the file name.
  """
  collected = []
  for parent_dir, subdir_names, file_names in os.walk(base_directory):
    # Prune in place: os.walk consults this same list to decide where to
    # descend next.
    subdir_names[:] = [
        name for name in subdir_names if should_include_dir(name)]
    for name in file_names:
      if should_include_file(name):
        collected.append(os.path.join(parent_dir, name))
  collected.sort()
  return collected


class Count(command_line.Command):
  """Print the number of public modules."""

  # NOTE(review): the class docstring above is left untouched because the
  # command framework may surface it as help text - confirm in command_line.

  def Run(self, args):
    """Print how many public-API modules exist under the Telemetry directory.

    Args:
      args: Parsed command-line arguments; not used by this command.

    Returns:
      0, always.
    """
    public_modules = _ListFiles(
        path.GetTelemetryDir(), self._IsPublicApiDir, self._IsPublicApiFile)
    # A single parenthesized argument prints the same under the Python 2
    # print statement.
    print(len(public_modules))
    return 0

  @staticmethod
  def _IsPublicApiDir(dir_name):
    """Return True unless the directory is hidden, private or non-public.

    Excluded: names starting with '.' or '_', names starting with
    'internal', and the name 'third_party'.
    """
    if dir_name[0] in ('.', '_'):
      return False
    if dir_name.startswith('internal'):
      return False
    return dir_name != 'third_party'

  @staticmethod
  def _IsPublicApiFile(file_name):
    """Return True for non-hidden .py files that are not unit tests."""
    root, ext = os.path.splitext(file_name)
    if file_name[0] == '.':
      return False
    if root.endswith('_unittest'):
      return False
    return ext == '.py'


class Mv(command_line.Command):
  """Move modules or packages."""

  @classmethod
  def AddCommandLineArgs(cls, parser):
    """Register one or more 'source' paths followed by a 'destination'."""
    parser.add_argument('source', nargs='+')
    parser.add_argument('destination')

  @classmethod
  def ProcessCommandLineArgs(cls, parser, args):
    """Validate the source and destination paths before anything is moved.

    Every problem is reported through parser.error(). Checks performed:
      * each source exists and lies inside one of BASE_DIRS;
      * the destination exists and lies inside one of BASE_DIRS;
      * with several sources, the destination is a directory;
      * the destination is not inside any of the sources.

    Args:
      parser: The argument parser, used only for error reporting.
      args: Parsed arguments with 'source' (a list) and 'destination'.
    """
    # NOTE(review): path.IsSubpath(a, b) is presumably "a is inside b";
    # confirm against telemetry.util.path.
    for source in args.source:
      # Ensure source path exists.
      if not os.path.exists(source):
        parser.error('"%s" not found.' % source)

      # Ensure source path is in one of the BASE_DIRS.
      for base_dir in BASE_DIRS:
        if path.IsSubpath(source, base_dir):
          break
      else:
        parser.error(
            'Source path "%s" is not in any of the base dirs.' % source)

    # Ensure destination path exists.
    if not os.path.exists(args.destination):
      parser.error('"%s" not found.' % args.destination)

    # Ensure destination path is in one of the BASE_DIRS.
    for base_dir in BASE_DIRS:
      if path.IsSubpath(args.destination, base_dir):
        break
    else:
      parser.error('Destination path "%s" is not in any of the base dirs.' %
          args.destination)

    # If there are multiple source paths, ensure destination is a directory.
    if len(args.source) > 1 and not os.path.isdir(args.destination):
      parser.error('Target "%s" is not a directory.' % args.destination)

    # Ensure destination is not in any of the source paths.
    for source in args.source:
      if path.IsSubpath(args.destination, source):
        parser.error('Cannot move "%s" to a subdirectory of itself, "%s".' %
            (source, args.destination))

  def Run(self, args):
    """Move the sources, then rewrite imports that referred to them.

    Steps:
      1. Build a map from each moved module's old dotted name to its new one.
         Names are taken relative to the BASE_DIRS entry containing the path.
      2. Rename the files or directories on disk.
      3. Rewrite matching import lines in every Python file reachable from
         BASE_DIRS, re-sorting any import group that was touched.

    Args:
      args: Parsed arguments, already checked by ProcessCommandLineArgs.

    Returns:
      0, always.
    """
    # ProcessCommandLineArgs guarantees that one of these loops breaks, so
    # dest_base_dir / source_base_dir end up as the containing base dir.
    for dest_base_dir in BASE_DIRS:
      if path.IsSubpath(args.destination, dest_base_dir):
        break

    # Decided once, before anything is renamed, so the module-name mapping and
    # the actual move agree on how the destination is interpreted.
    destination_is_dir = os.path.isdir(args.destination)

    # Get a list of old and new module names for renaming imports.
    moved_modules = {}
    for source in args.source:
      for source_base_dir in BASE_DIRS:
        if path.IsSubpath(source, source_base_dir):
          break

      source_dir = os.path.dirname(os.path.normpath(source))

      if os.path.isdir(source):
        source_files = _ListFiles(source,
                                  self._IsSourceDir, self._IsPythonModule)
      else:
        source_files = (source,)

      for source_file_path in source_files:
        source_rel_path = os.path.relpath(source_file_path, source_base_dir)
        source_module_name = os.path.splitext(
            source_rel_path)[0].replace(os.sep, '.')

        if destination_is_dir:
          # Moved into the directory, keeping its path relative to the
          # source's parent.
          source_tree = os.path.relpath(source_file_path, source_dir)
          dest_path = os.path.join(args.destination, source_tree)
        else:
          # The destination is itself the new file, so it names the new
          # module directly instead of acting as a parent directory.
          dest_path = args.destination
        dest_rel_path = os.path.relpath(dest_path, dest_base_dir)
        dest_module_name = os.path.splitext(
            dest_rel_path)[0].replace(os.sep, '.')

        moved_modules[source_module_name] = dest_module_name

    # Move things!
    if destination_is_dir:
      for source in args.source:
        destination_path = os.path.join(
            args.destination, os.path.split(os.path.normpath(source))[1])
        os.rename(source, destination_path)
    else:
      assert len(args.source) == 1
      # Index instead of pop() so the caller's args are not mutated.
      os.rename(args.source[0], args.destination)

    # Update imports!
    def _UpdateImportGroup(import_group):
      # Rewrites, in place, each import whose full module name was moved, and
      # re-sorts the group only if something in it changed.
      modified = False
      for import_line in import_group:
        _, root, module, _, _ = import_line
        full_module = root + '.' + module if root else module

        if full_module not in moved_modules:
          continue

        modified = True

        # Update import line.
        new_root, _, new_module = moved_modules[full_module].rpartition('.')
        import_line[1] = new_root
        import_line[2] = new_module

      if modified:
        return _SortImportGroup(import_group)
      else:
        return import_group

    for base_dir in BASE_DIRS:
      if os.path.isfile(base_dir):
        # Some BASE_DIRS entries are single scripts. os.walk yields nothing
        # for a non-directory, so such an entry is handled directly.
        module_paths = [base_dir]
      else:
        module_paths = _ListFiles(base_dir,
                                  self._IsSourceDir, self._IsPythonModule)
      for module_path in module_paths:
        if not _TransformImportGroups(module_path, _UpdateImportGroup):
          continue

        # TODO(dtu): Update occurrences.

    # A single parenthesized argument prints the same under the Python 2
    # print statement.
    print(moved_modules)

    return 0

  @staticmethod
  def _IsSourceDir(dir_name):
    """Return True for directories that are neither hidden nor third_party."""
    return dir_name[0] != '.' and dir_name != 'third_party'

  @staticmethod
  def _IsPythonModule(file_name):
    """Return True if the file name has a .py extension."""
    _, ext = os.path.splitext(file_name)
    return ext == '.py'


class RefactorCommand(command_line.SubcommandCommand):
  # Subcommands offered by this tool. How they are dispatched is decided by
  # command_line.SubcommandCommand, which is not shown in this file.
  commands = (Count, Mv)


# When run as a script, hand control to RefactorCommand and use whatever its
# main() returns as the process exit status.
if __name__ == '__main__':
  sys.exit(RefactorCommand.main())
