| # Copyright 2021 The Chromium OS Authors. All rights reserved. |
| # Use of this source code is governed by a BSD-style license that can be |
| # found in the LICENSE file. |
| |
| """Provides help with the conflict resolution feature of forklift""" |
| |
| import re |
| |
class Conflict:
    """Represents a conflict located in a local file.

    Attributes:
        sha: The upstream commit sha involved in the conflict.
        subject: The upstream commit subject.
        head_conflict: List of the head lines involved in the conflict.
        remote_conflict: The lines from the remote which are conflicting.
    """

    # Closing conflict marker, accepted in either of the forms
    #   '>>>>>>> 0123abc... Commit subject'
    #   '>>>>>>> 0123abc (Commit subject)'
    # Group 1 is the sha. The subject is group 2 when it is wrapped in
    # parentheses and group 3 otherwise, so the parentheses never end up in
    # the subject. Matched against the line with trailing whitespace removed,
    # so a marker on a final line without '\n' is still recognized.
    _REMOTE_MARKER_RE = re.compile(
        r'>>>>>>> ([a-f0-9]+)(?:\.{3})? (?:\((.+)\)|(.+))$')

    def __init__(self, head=None, separator=None, remote=None):
        """Initialize the conflict object.

        Args:
            head: The line number of the '<<<<<<< HEAD' sentinel.
            separator: The line number of the '=======' sentinel.
            remote: The line number of the '>>>>>>> <sha>... <subject>'
                sentinel.

        Raises:
            ValueError: If a sentinel is given without the preceding one, or
                the line numbers are not in head < separator < remote order.
        """
        self._head = None
        self._separator = None
        self._remote = None

        self.sha = None
        self.subject = None
        self.head_conflict = []
        self.remote_conflict = []

        self._set_head(head)
        self._set_separator(separator)
        self._set_remote(remote)

    @staticmethod
    def _valid_conflict(head, separator, remote):
        """Returns True if the (possibly unset) sentinels are consistent.

        A None sentinel has simply not been seen yet. A separator requires an
        earlier head, and a remote requires an earlier separator.
        """
        if separator is not None and (head is None or head >= separator):
            return False
        if remote is not None and (separator is None or separator >= remote):
            return False
        return True

    def _set_head(self, head):
        if not self._valid_conflict(head, self._separator, self._remote):
            raise ValueError((f'Conflict {head}/{self._separator}/'
                              f'{self._remote} invalid.'))
        self._head = head

    def _set_separator(self, separator):
        if not self._valid_conflict(self._head, separator, self._remote):
            raise ValueError((f'Conflict {self._head}/{separator}/'
                              f'{self._remote} invalid.'))
        self._separator = separator

    def _set_remote(self, remote):
        if not self._valid_conflict(self._head, self._separator, remote):
            raise ValueError((f'Conflict {self._head}/{self._separator}/'
                              f'{remote} invalid.'))
        self._remote = remote

    def head(self):
        """Returns the line number of '<<<<<<< HEAD' for the conflict."""
        return self._head

    def separator(self):
        """Returns the line number of '=======' for the conflict."""
        return self._separator

    def remote(self):
        """Returns the line number of '>>>>>>>' for the conflict."""
        return self._remote

    def parse(self, line_num, line):
        """Parses the line and adds it to the internal state if applicable.

        '=======' and '>>>>>>>' lines are only treated as sentinels when they
        can legally appear, i.e. after a head (respectively a separator).
        Elsewhere they are ordinary text, such as a heading underline in a
        markdown or reStructuredText file.

        Args:
            line_num: The number of the line being parsed.
            line: The contents of the current line being parsed.

        Returns:
            True if the conflict has been completely parsed, False otherwise.

        Raises:
            ValueError: If the sentinels are out of order, or the closing
                '>>>>>>>' marker does not contain a recognizable sha.
        """
        in_head = self._head is not None and self._separator is None
        in_remote = self._separator is not None and self._remote is None

        if line.startswith('<<<<<<<'):
            self._set_head(line_num)
        elif in_head and line.startswith('======='):
            self._set_separator(line_num)
        elif in_remote and line.startswith('>>>>>>>'):
            marker = line.rstrip()
            m = self._REMOTE_MARKER_RE.match(marker)
            if m is None:
                raise ValueError(
                    f'Unrecognized conflict marker on line {line_num}: '
                    f'{marker!r}')
            self._set_remote(line_num)
            self.sha = m.group(1)
            self.subject = (m.group(2) if m.group(2) is not None
                            else m.group(3))
            return True
        elif in_head:
            self.head_conflict.append(line.rstrip())
        elif in_remote:
            self.remote_conflict.append(line.rstrip())

        return False
| |
class Resolver:
    """Class to assist in resolving a conflict."""

    # Unified diff hunk header, e.g. '@@ -12,7 +12,8 @@ int foo(void)'.
    # A line count is omitted when it is 1 (e.g. '@@ -3 +3 @@'), so both
    # counts are optional. Group 5 is the (possibly empty) text after the
    # closing '@@', typically the enclosing function line.
    _HUNK_HEADER_RE = re.compile(
        r'@@ -([0-9]+)(?:,([0-9]+))? \+([0-9]+)(?:,([0-9]+))? @@(.*)')

    def __init__(self, git, path):
        """Initialize the Resolver class.

        Args:
            git: The Git object to use for git operations.
            path: The path of the file containing the conflicts to resolve.
        """
        self._git = git
        self._path = path

    def get_conflicts(self):
        """Returns a list of conflicts from the file at self._path.

        Parses the file line by line (line numbers start at 1) and pulls out
        every complete conflict into a Conflict object.

        Returns:
            A list of Conflict objects in file order. A trailing conflict
            whose closing '>>>>>>>' marker is never seen is not included.
        """
        conflicts = []
        with open(self._path, mode='r') as f:
            cur_conflict = Conflict()
            for line_num, line in enumerate(f, start=1):
                if cur_conflict.parse(line_num, line):
                    conflicts.append(cur_conflict)
                    cur_conflict = Conflict()

        return conflicts

    @staticmethod
    def _format_conflict_line(line_num, line):
        return f'{line_num:<5} {line}\n'

    def format_conflict(self, conflict, print_head=True, print_remote=True):
        """Formats the conflict in a human-readable format.

        Every output line is prefixed with its line number in the local file.
        The '=======' separator is always included.

        Args:
            conflict: The Conflict object to format.
            print_head: True if output should contain the HEAD portion.
            print_remote: True if output should contain the remote portion.

        Returns:
            The formatted conflict in a string.
        """
        parts = []
        if print_head:
            parts.append(self._format_conflict_line(conflict.head(),
                                                    '<<<<<<< HEAD'))
            for offset, line in enumerate(conflict.head_conflict, start=1):
                parts.append(self._format_conflict_line(
                    conflict.head() + offset, line))

        parts.append(self._format_conflict_line(conflict.separator(),
                                                '======='))

        if print_remote:
            for offset, line in enumerate(conflict.remote_conflict, start=1):
                parts.append(self._format_conflict_line(
                    conflict.separator() + offset, line))
            parts.append(self._format_conflict_line(
                conflict.remote(),
                f'>>>>>>> {conflict.sha}.. {conflict.subject}'))

        return ''.join(parts)

    def blame_head(self, conflict):
        """Fetch the git blame output for the HEAD portion of the conflict.

        The result holds three sections joined by newlines: up to 3 blame
        lines before '<<<<<<<', the blame lines between '<<<<<<<' and
        '=======', and up to 3 blame lines after '>>>>>>>'.

        Args:
            conflict: The Conflict object to assign blame.

        Returns:
            String containing the git blame output for the conflict's HEAD text.
        """
        blame = self._git.blame(self._path).splitlines()

        # Conflict line numbers are 1-based while blame is a 0-based list, so
        # file line N lives at blame[N - 1]. Slices clamp at the list ends.
        head = conflict.head()
        remote = conflict.remote()
        before = blame[max(0, head - 4):head - 1]
        local = blame[head:conflict.separator() - 1]
        after = blame[remote:remote + 3]

        return '\n'.join(['\n'.join(before), '\n'.join(local),
                          '\n'.join(after)])

    @staticmethod
    def _get_diff_chunks(diff):
        """Splits unified diff text into a list of hunk dicts.

        Lines before the first hunk header (e.g. 'diff --git', '---', '+++')
        belong to no hunk and are skipped.

        Args:
            diff: Unified diff text.

        Returns:
            A list of dicts with keys 'old_line', 'old_num', 'new_line',
            'new_num' (ints from the hunk header), 'identifier' (stripped
            header trailer, or 'NA' if empty), 'chunk' (the hunk's lines with
            trailing whitespace removed) and 'score' (0).
        """
        chunks = []
        cur_chunk = None
        for line in diff.splitlines():
            m = Resolver._HUNK_HEADER_RE.match(line)
            if m:
                if cur_chunk is not None:
                    chunks.append(cur_chunk)

                old_num, new_num = m.group(2), m.group(4)
                cur_chunk = {
                    'old_line': int(m.group(1)),
                    'old_num': int(old_num) if old_num is not None else 1,
                    'new_line': int(m.group(3)),
                    'new_num': int(new_num) if new_num is not None else 1,
                    # Strip the space that separates the trailer from '@@'
                    # so it can be compared against lines of the local file.
                    'identifier': m.group(5).strip() or 'NA',
                    'chunk': [],
                    'score': 0,
                }
            elif cur_chunk is not None:
                cur_chunk['chunk'].append(line.rstrip())

        if cur_chunk is not None:
            chunks.append(cur_chunk)

        return chunks

    def _score_chunks_by_identifier(self, conflict, chunk_list):
        # Walk through the local file backwards starting just above the
        # conflict, looking for the nearest line that matches an identifier
        # from the git diff output for the conflicting change. Score one
        # point to every chunk with that identifier.
        identifiers = {c['identifier'] for c in chunk_list
                       if c['identifier'] != 'NA'}
        if not identifiers:
            return

        with open(self._path, 'r') as f:
            lines = f.readlines()

        identifier = None
        for line in reversed(lines[:conflict.head() - 1]):
            candidate = line.strip()
            if candidate in identifiers:
                identifier = candidate
                break

        if identifier is None:
            return

        for c in chunk_list:
            if c['identifier'] == identifier:
                c['score'] += 1

    @staticmethod
    def _score_chunks_by_matching_lines(lines, chunk_list, prefix):
        # Add one point to a chunk for every (diff line, local line) pair in
        # which the diff line starts with `prefix` and, without that prefix,
        # equals the local line.
        for chunk in chunk_list:
            for diff_line in chunk['chunk']:
                if diff_line.startswith(prefix):
                    chunk['score'] += lines.count(diff_line[1:])

    @staticmethod
    def _score_chunks_by_addition(conflict, chunk_list):
        # Try to find the conflicting code by comparing the remote portion of
        # the conflict with the added code in each git chunk. The more lines
        # that match, the better the score.
        Resolver._score_chunks_by_matching_lines(conflict.remote_conflict,
                                                 chunk_list, '+')

    @staticmethod
    def _score_chunks_by_subtraction(conflict, chunk_list):
        # Try to find the conflicting code by comparing the local portion of
        # the conflict with the removed code in each git chunk. The more lines
        # that match, the better the score.
        Resolver._score_chunks_by_matching_lines(conflict.head_conflict,
                                                 chunk_list, '-')

    def blame_remote(self, conflict):
        """Fetch the git blame output for the remote portion of the conflict.

        Splits the conflicting commit's diff for this file into chunks. If
        there is more than one, it scores them to guess which one caused the
        conflict. For every top-scoring chunk, it shows the diff chunk and
        the blame of the lines that chunk replaced in the commit's parent.

        Args:
            conflict: The Conflict object to assign blame.

        Returns:
            String containing the git blame output for the conflict's remote
            text, or an explanatory message if no chunk could be matched.
        """
        diff = self._git.commit_diff(conflict.sha, self._path)
        chunk_list = self._get_diff_chunks(diff)

        # We have to find the chunk in the diff which caused the conflict.
        # This will give us the line number of the change which we can use to
        # narrow down the blame range to the relevant bit.
        if len(chunk_list) == 1:
            # Only one chunk means this is the portion causing the conflict.
            results = chunk_list
        else:
            # This is a bit tricky since there's no sure way to map the local
            # conflict into a diff chunk. For now we'll try a few methods to
            # find the right snippet of code. The chunks with the highest
            # scores (ties are allowed) get displayed to the user.
            self._score_chunks_by_identifier(conflict, chunk_list)
            self._score_chunks_by_addition(conflict, chunk_list)
            self._score_chunks_by_subtraction(conflict, chunk_list)
            results = [c for c in chunk_list if c['score'] > 0]

        if not results:
            return 'Could not map the local conflict to remote blame.'

        results = sorted(results, key=lambda x: x['score'], reverse=True)
        max_score = results[0]['score']

        blame = self._git.blame(self._path, f'{conflict.sha}^').splitlines()
        parts = []
        for i, c in enumerate(results):
            if c['score'] < max_score:
                break

            parts.append(f'>>>>> Possible result {i}, score={c["score"]}\n')
            parts.append('-- diff\n')
            parts.append('\n'.join(c['chunk']))
            parts.append('\n')
            parts.append('-- blame\n')
            # Hunk line numbers are 1-based while blame is a 0-based list;
            # old_line is 0 only for an empty old side (new file).
            c_start = max(0, c['old_line'] - 1)
            parts.append('\n'.join(blame[c_start:c_start + c['old_num']]))
            parts.append('\n')

        return ''.join(parts)