blob: 544e3fd07ad487177c6d32c8610410907c21f06d [file] [log] [blame]
# Copyright 2021 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""Provides help with the conflict resolution feature of forklift"""
import re
class Conflict:
    """Represents a conflict located in a local file.

    A Conflict is built incrementally: feed the lines of a file, in order,
    to parse(). It records the line numbers of the three conflict markers
    and the text found between them.

    Attributes:
        sha: The upstream commit sha involved in the conflict.
        subject: The upstream commit subject.
        head_conflict: The lines between the '<<<<<<<' and '=======' markers
            (the HEAD side), with trailing whitespace stripped.
        remote_conflict: The lines between the '=======' and '>>>>>>>'
            markers (the remote side), with trailing whitespace stripped.
    """

    # Closing marker of a conflict. Accepted forms:
    #   '>>>>>>> <sha>... <subject>'
    #   '>>>>>>> <sha> (<subject>)'
    #   '>>>>>>> <sha> <subject>'
    # Applied with fullmatch() to the line minus its line terminator.
    _REMOTE_RE = re.compile(
        r'>>>>>>> ([a-f0-9]+)(?:\.{3} (.+)| \((.+)\)| (.+))')

    def __init__(self, head=None, separator=None, remote=None):
        """Initialize the conflict object.

        Args:
            head: The line number of the '<<<<<<< HEAD' sentinel.
            separator: The line number of the '=======' sentinel.
            remote: The line number of the '>>>>>>> <sha>...<subject>'
                sentinel.

        Raises:
            ValueError: If the line numbers are inconsistent, e.g. a
                separator without a head, or not in increasing order.
        """
        self._head = None
        self._separator = None
        self._remote = None
        self.sha = None
        self.subject = None
        self.head_conflict = []
        self.remote_conflict = []
        self._set_head(head)
        self._set_separator(separator)
        self._set_remote(remote)

    @staticmethod
    def _valid_conflict(head, separator, remote):
        """Returns True if head < separator < remote for the markers set.

        A separator is only valid after a head, and a remote only after a
        separator. Unset (None) markers impose no constraint.
        """
        if separator and not (head and head < separator):
            return False
        if remote and not (separator and separator < remote):
            return False
        return True

    def _set_head(self, head):
        if not self._valid_conflict(head, self._separator, self._remote):
            raise ValueError((f'Conflict {head}/{self._separator}/'
                              f'{self._remote} invalid.'))
        self._head = head

    def _set_separator(self, separator):
        if not self._valid_conflict(self._head, separator, self._remote):
            raise ValueError((f'Conflict {self._head}/{separator}/'
                              f'{self._remote} invalid.'))
        self._separator = separator

    def _set_remote(self, remote):
        if not self._valid_conflict(self._head, self._separator, remote):
            raise ValueError((f'Conflict {self._head}/{self._separator}/'
                              f'{remote} invalid.'))
        self._remote = remote

    def head(self):
        """Returns the line number of '<<<<<<< HEAD' for the conflict."""
        return self._head

    def separator(self):
        """Returns the line number of '=======' for the conflict."""
        return self._separator

    def remote(self):
        """Returns the line number of '>>>>>>>' for the conflict."""
        return self._remote

    def _parse_remote_marker(self, line_num, line):
        """Stores the sha and subject found on a '>>>>>>>' marker line.

        Raises:
            ValueError: If the marker does not match any accepted form.
        """
        # Strip the terminator here rather than in the regex, so the last
        # line of a file without a trailing newline (or a CRLF file) works.
        text = line.rstrip('\r\n')
        m = self._REMOTE_RE.fullmatch(text)
        if not m:
            raise ValueError(f'Unrecognized conflict marker on line '
                             f'{line_num}: {text!r}')
        self.sha = m.group(1)
        # Exactly one of the three alternatives captured the subject.
        self.subject = m.group(2) or m.group(3) or m.group(4)

    def parse(self, line_num, line):
        """Parses the line and adds it to the internal state if applicable.

        Lines must be fed in file order. Until an opening '<<<<<<<' marker
        has been seen, all other lines are ignored. This includes lines that
        merely look like '=======' or '>>>>>>>' markers, such as a
        reStructuredText heading underline.

        Args:
            line_num: The number of the line being parsed.
            line: The contents of the current line being parsed.

        Returns:
            True if the conflict has been completely parsed, False otherwise.

        Raises:
            ValueError: If a marker appears out of order, or the closing
                marker does not carry a recognizable commit sha.
        """
        if line.startswith('<<<<<<<'):
            self._set_head(line_num)
        elif self._head is None:
            # Not inside a conflict yet; this is ordinary file content.
            pass
        elif line.startswith('======='):
            self._set_separator(line_num)
        elif line.startswith('>>>>>>>'):
            self._set_remote(line_num)
            self._parse_remote_marker(line_num, line)
            return True
        elif self._separator is None:
            self.head_conflict.append(line.rstrip())
        elif self._remote is None:
            self.remote_conflict.append(line.rstrip())
        return False
class Resolver:
    """Class to assist in resolving a conflict."""

    # Unified diff hunk header, e.g. '@@ -12,7 +12,8 @@ int foo(void)'.
    # A ',count' part is omitted by diff when the count is 1. The optional
    # text after '@@ ' (one separator space, then the text verbatim) is
    # kept as the chunk 'identifier'.
    _CHUNK_RE = re.compile(
        r'@@ -([0-9]+)(?:,([0-9]+))? \+([0-9]+)(?:,([0-9]+))? @@(?: (.*))?')

    def __init__(self, git, path):
        """Initialize the Resolver class.

        Args:
            git: The Git object to use for git operations. It is used here
                via blame(path[, rev]) and commit_diff(sha, path).
            path: The path of the file containing the conflicts to resolve.
        """
        self._git = git
        self._path = path

    def get_conflicts(self):
        """Returns a list of conflicts from the file at self._path.

        Parses the file at self._path and pulls out all the conflicts into
        Conflict objects. A conflict whose closing '>>>>>>>' marker is never
        reached is not included.

        Returns:
            A list of Conflict objects, in file order.

        Raises:
            ValueError: If the conflict markers in the file are malformed.
        """
        conflicts = []
        cur_conflict = Conflict()
        with open(self._path, mode='r') as f:
            for line_num, line in enumerate(f, start=1):
                if cur_conflict.parse(line_num, line):
                    conflicts.append(cur_conflict)
                    cur_conflict = Conflict()
        return conflicts

    @staticmethod
    def _format_conflict_line(line_num, line):
        return f'{line_num:<5} {line}\n'

    def format_conflict(self, conflict, print_head=True, print_remote=True):
        """Formats the conflict in a human-readable format.

        Every output line is prefixed with its line number in the file.

        Args:
            conflict: The Conflict object to format.
            print_head: True if output should contain the HEAD portion.
            print_remote: True if output should contain the remote portion.

        Returns:
            The formatted conflict in a string.
        """
        fmt = self._format_conflict_line
        out = []
        if print_head:
            out.append(fmt(conflict.head(), '<<<<<<< HEAD'))
            out.extend(fmt(conflict.head() + i, line)
                       for i, line in enumerate(conflict.head_conflict,
                                                start=1))
        out.append(fmt(conflict.separator(), '======='))
        if print_remote:
            out.extend(fmt(conflict.separator() + i, line)
                       for i, line in enumerate(conflict.remote_conflict,
                                                start=1))
        out.append(fmt(conflict.remote(),
                       f'>>>>>>> {conflict.sha}.. {conflict.subject}'))
        return ''.join(out)

    def blame_head(self, conflict):
        """Fetch the git blame output for the HEAD portion of the conflict.

        The output holds three parts, separated by newlines:
        - up to 3 lines of context before the '<<<<<<<' marker;
        - the HEAD side of the conflict;
        - up to 3 lines after the '>>>>>>>' marker.
        The marker lines themselves are left out.

        Args:
            conflict: The Conflict object to assign blame.

        Returns:
            String containing the git blame output for the conflict's HEAD
            text.
        """
        blame = self._git.blame(self._path).splitlines()
        head = conflict.head()
        separator = conflict.separator()
        remote = conflict.remote()
        # Assumes one blame line per file line, so blame[i] is line i + 1.
        before = blame[max(0, head - 4):head - 1]
        ours = blame[head:separator - 1]
        after = blame[min(len(blame), remote):min(len(blame), remote + 3)]
        return '\n'.join(['\n'.join(before), '\n'.join(ours),
                          '\n'.join(after)])

    @classmethod
    def _get_diff_chunks(cls, diff):
        """Splits unified diff text into a list of hunk ("chunk") dicts.

        Each dict has the keys:
            old_line, old_num: Start line and line count on the old side.
            new_line, new_num: Start line and line count on the new side.
            identifier: Text after the hunk header's closing '@@ ', or 'NA'
                when there is none.
            chunk: The hunk's body lines, with trailing whitespace stripped.
            score: 0; updated later by the _score_chunks_* helpers.

        Lines before the first hunk header, such as 'diff --git', '---' or
        '+++' file headers, are ignored.
        """
        chunks = []
        cur_chunk = None
        for line in diff.splitlines():
            m = cls._CHUNK_RE.match(line)
            if m:
                old_line, old_num, new_line, new_num, ident = m.groups()
                cur_chunk = {
                    'old_line': int(old_line),
                    # An omitted count means the hunk spans one line.
                    'old_num': int(old_num) if old_num is not None else 1,
                    'new_line': int(new_line),
                    'new_num': int(new_num) if new_num is not None else 1,
                    'identifier': ident if ident else 'NA',
                    'chunk': [],
                    'score': 0,
                }
                chunks.append(cur_chunk)
            elif cur_chunk is not None:
                cur_chunk['chunk'].append(line.rstrip())
        return chunks

    def _score_chunks_by_identifier(self, conflict, chunk_list):
        """Scores chunks whose identifier is nearest above the conflict.

        Walks through the local file backwards from the conflict. It stops
        at the first line that equals an identifier from the git diff of the
        conflicting change. Every chunk with that identifier gets one point.
        """
        # 'NA' is the placeholder for "no identifier"; it must not match a
        # literal 'NA' line in the file.
        identifiers = {c['identifier'] for c in chunk_list
                       if c['identifier'] != 'NA'}
        if not identifiers:
            return
        with open(self._path, 'r') as f:
            lines = f.readlines()
        identifier = None
        for line in reversed(lines[:conflict.head()]):
            candidate = line.rstrip()
            if candidate in identifiers:
                identifier = candidate
                break
        if identifier is None:
            return
        for c in chunk_list:
            if c['identifier'] == identifier:
                c['score'] += 1

    @staticmethod
    def _score_chunks_by_matching_lines(lines, chunk_list, prefix):
        """Adds to each chunk one point per matching (chunk line, line) pair.

        A pair matches when the chunk line starts with `prefix` and the rest
        of it equals the line.
        """
        counts = {}
        for line in lines:
            counts[line] = counts.get(line, 0) + 1
        for chunk in chunk_list:
            for cl in chunk['chunk']:
                if cl.startswith(prefix):
                    chunk['score'] += counts.get(cl[len(prefix):], 0)

    @classmethod
    def _score_chunks_by_addition(cls, conflict, chunk_list):
        """Scores chunks whose added ('+') lines match the remote side.

        This compares the remote portion of the conflict with the code each
        git chunk adds. The more lines that match, the better the score.
        """
        cls._score_chunks_by_matching_lines(conflict.remote_conflict,
                                            chunk_list, '+')

    @classmethod
    def _score_chunks_by_subtraction(cls, conflict, chunk_list):
        """Scores chunks whose removed ('-') lines match the HEAD side.

        This compares the local portion of the conflict with the code each
        git chunk removes. The more lines that match, the better the score.
        """
        cls._score_chunks_by_matching_lines(conflict.head_conflict,
                                            chunk_list, '-')

    def blame_remote(self, conflict):
        """Fetch the git blame output for the remote portion of the conflict.

        Splits the conflicting commit's diff of self._path into chunks and
        picks the chunk(s) most likely to have caused the conflict. For each
        pick, it returns the diff text plus the blame, at the commit's parent
        ('<sha>^'), of the chunk's old-side line range.

        Args:
            conflict: The Conflict object to assign blame.

        Returns:
            String containing the git blame output for the conflict's remote
            text, or a message saying the conflict could not be mapped.
        """
        diff = self._git.commit_diff(conflict.sha, self._path)
        chunk_list = self._get_diff_chunks(diff)
        # We have to find the chunk in the diff which caused the conflict.
        # Its line numbers let us narrow the blame to the relevant part.
        if len(chunk_list) == 1:
            # Only one chunk means this is the portion causing the conflict.
            results = chunk_list
        else:
            # There is no sure way to map the local conflict to a diff
            # chunk, so try a few heuristics. The chunks with the highest
            # score (ties are allowed) are shown to the user.
            self._score_chunks_by_identifier(conflict, chunk_list)
            self._score_chunks_by_addition(conflict, chunk_list)
            self._score_chunks_by_subtraction(conflict, chunk_list)
            results = [c for c in chunk_list if c['score'] > 0]
        if not results:
            return 'Could not map the local conflict to remote blame.'
        max_score = max(c['score'] for c in results)
        best = [c for c in results if c['score'] == max_score]
        blame = self._git.blame(self._path, f'{conflict.sha}^').splitlines()
        out = []
        for i, c in enumerate(best):
            # old_line is 1-based while blame[0] describes line 1.
            start = max(0, c['old_line'] - 1)
            end = min(len(blame), start + c['old_num'])
            out.append(f'>>>>> Possible result {i}, score={c["score"]}\n')
            out.append('-- diff\n')
            out.append('\n'.join(c['chunk']))
            out.append('\n')
            out.append('-- blame\n')
            out.append('\n'.join(blame[start:end]))
            out.append('\n')
        return ''.join(out)