#!/usr/bin/python -O
# vim: noet :
import os
import re
import signal
import sys
import time
from optparse import OptionParser, OptionValueError

try:
	import portage
except ImportError:
	from os import path as osp
	sys.path.insert(0, osp.join(osp.dirname(osp.dirname(osp.realpath(__file__))), "pym"))
	import portage

import portage.const, portage.exception, portage.output

class WorldHandler(object):
	"""Check the entries of the "world" set and drop the ones that fail.

	Every scan sorts the entries of the world set into four lists kept on
	the instance:

	invalid          -- neither a valid atom nor a set reference
	not_installed    -- valid atom without a match in the vartree dbapi,
	                    or a set reference naming an unknown set
	invalid_category -- category is not in portage.settings.categories
	okay             -- entries that passed every test
	"""

	def name():
		return "world"
	name = staticmethod(name)

	def __init__(self):
		self._reset_results()
		from portage.sets import load_default_config
		setconfig = load_default_config(portage.settings,
			portage.db[portage.settings["ROOT"]])
		self._sets = setconfig.getSets()

	def _reset_results(self):
		"""Start every scan from empty result lists."""
		self.invalid = []
		self.not_installed = []
		self.invalid_category = []
		self.okay = []

	def _check_world(self, onProgress):
		"""Classify every entry of the world set.

		onProgress, when given, is called as onProgress(maxval, curval)
		once before the scan and once after each entry.
		"""
		# The lists are rebuilt on every scan; without this a second call
		# on the same instance would append each entry again and fix()
		# would write duplicated entries.
		self._reset_results()
		categories = set(portage.settings.categories)
		myroot = portage.settings["ROOT"]
		self.world_file = os.path.join(myroot, portage.const.WORLD_FILE)
		self.found = os.access(self.world_file, os.R_OK)
		vardb = portage.db[myroot]["vartree"].dbapi

		from portage.sets import SETPREFIX
		sets = self._sets
		world_atoms = list(sets["world"])
		maxval = len(world_atoms)
		if onProgress:
			onProgress(maxval, 0)
		for i, atom in enumerate(world_atoms):
			if not portage.isvalidatom(atom):
				if atom.startswith(SETPREFIX):
					# A set reference is fine as long as the set is known.
					s = atom[len(SETPREFIX):]
					if s in sets:
						self.okay.append(atom)
					else:
						self.not_installed.append(atom)
				else:
					self.invalid.append(atom)
				if onProgress:
					onProgress(maxval, i+1)
				continue
			cp = portage.dep_getkey(atom)
			okay = True
			# Both tests always run, so one atom can land in both lists.
			if not vardb.match(atom):
				self.not_installed.append(atom)
				okay = False
			if portage.catsplit(cp)[0] not in categories:
				self.invalid_category.append(atom)
				okay = False
			if okay:
				self.okay.append(atom)
			if onProgress:
				onProgress(maxval, i+1)

	def check(self, onProgress=None):
		"""Scan the world set and return a list of problem descriptions.

		When the world file is not readable, the only message returned is
		the one saying so.
		"""
		self._check_world(onProgress)
		errors = []
		if self.found:
			errors.extend(["'%s' is not a valid atom" % x
				for x in self.invalid])
			errors.extend(["'%s' is not installed" % x
				for x in self.not_installed])
			errors.extend(["'%s' has a category that is not listed in /etc/portage/categories" % x
				for x in self.invalid_category])
		else:
			errors.append(self.world_file + " could not be opened for reading")
		return errors

	def fix(self, onProgress=None):
		"""Replace the world set with the entries that passed the scan.

		The set is locked for the whole operation and only rewritten when
		its contents would actually change. Returns a list of error
		messages (empty on success).
		"""
		world_set = self._sets["world"]
		world_set.lock()
		try:
			world_set.load() # maybe it's changed on disk
			before = set(world_set)
			self._check_world(onProgress)
			after = set(self.okay)
			errors = []
			if before != after:
				try:
					world_set.replace(self.okay)
				except portage.exception.PortageException:
					errors.append("%s could not be opened for writing" % \
						self.world_file)
			return errors
		finally:
			world_set.unlock()

class BinhostHandler(object):
	"""Compare the binary package tree with its "Packages" index file.

	check() reports packages without a usable index entry and index
	entries without a package; fix() injects the former into the tree and
	rewrites the index without the latter.
	"""

	@staticmethod
	def name():
		return "binhost"

	def __init__(self):
		root = portage.settings["ROOT"]
		self._bintree = portage.db[root]["bintree"]
		self._bintree.populate()
		self._pkgindex_file = os.path.join(self._bintree.pkgdir, "Packages")
		self._load_pkgindex()

	def _load_pkgindex(self):
		"""Read the index file into a fresh PackageIndex on self._pkgindex."""
		from portage import getbinpkg
		self._pkgindex = getbinpkg.PackageIndex()
		index_file = open(self._pkgindex_file, 'r')
		try:
			self._pkgindex.read(index_file)
		finally:
			index_file.close()

	def check(self, onProgress=None):
		"""Return messages for missing and stale index entries.

		A package counts as missing when the index has no entry for it or
		the entry has no "MD5" key.
		"""
		tree_cpvs = self._bintree.dbapi.cpv_all()
		tree_cpvs.sort()
		total = len(tree_cpvs)
		if onProgress:
			onProgress(total, 0)
		index = self._pkgindex
		errors = []
		for done, cpv in enumerate(tree_cpvs):
			entry = index.packages.get(cpv)
			if not entry or "MD5" not in entry:
				errors.append("'%s' is not in Packages" % cpv)
			if onProgress:
				onProgress(total, done + 1)
		# Whatever the index lists beyond the tree contents is stale.
		for cpv in set(index.packages).difference(tree_cpvs):
			errors.append("'%s' is not in the repository" % cpv)
		return errors

	def fix(self, onProgress=None):
		"""Inject unindexed packages and drop stale index entries.

		Always returns None.
		"""
		old_tree = self._bintree
		tree_cpvs = self._bintree.dbapi.cpv_all()
		tree_cpvs.sort()
		total = len(tree_cpvs)
		if onProgress:
			onProgress(total, 0)
		index = self._pkgindex
		for done, cpv in enumerate(tree_cpvs):
			entry = index.packages.get(cpv)
			if not entry or "MD5" not in entry:
				old_tree.inject(cpv)
			if onProgress:
				onProgress(total, done + 1)

		if not set(index.packages).difference(tree_cpvs):
			# No stale entries, so the index file needs no rewrite.
			return None

		from portage import locks
		index_lock = locks.lockfile(
			self._pkgindex_file, wantnewlockfile=1)
		try:
			# Drop the old index and reload it from disk while holding
			# the lock.
			del index
			self._load_pkgindex()
			from portage.dbapi.bintree import binarytree
			self._bintree = binarytree(old_tree.root, old_tree.pkgdir,
				settings=old_tree.settings)
			del old_tree
			portage.db[self._bintree.root]["bintree"] = self._bintree
			self._bintree._populate()
			leftovers = set(self._pkgindex.packages).difference(
				self._bintree.dbapi.cpv_all())
			for cpv in leftovers:
				del self._pkgindex.packages[cpv]
			from portage.util import atomic_ofstream
			out = atomic_ofstream(self._pkgindex_file)
			try:
				self._pkgindex.write(out)
			finally:
				out.close()
		finally:
			locks.unlockfile(index_lock)
		return None

class MoveHandler(object):
	"""Apply the update commands from <PORTDIR>/profiles/updates to a tree.

	Subclasses supply the tree; its dbapi must provide match, aux_get,
	cpv_all, move_ent, move_slot_ent and update_ents.
	"""

	def __init__(self, tree):
		self._tree = tree
		self._portdir = tree.settings["PORTDIR"]
		# Metadata keys that are searched for outdated package references.
		self._update_keys = ["DEPEND", "RDEPEND", "PDEPEND", "PROVIDE"]

	def _grab_global_updates(self, portdir):
		"""Load and parse every update file below portdir.

		Returns (commands, errors): the parsed commands of all files and
		the parse errors of all files. A missing updates directory yields
		two empty lists.
		"""
		from portage.update import grab_updates, parse_updates
		updpath = os.path.join(portdir, "profiles", "updates")
		try:
			rawupdates = grab_updates(updpath)
		except portage.exception.DirectoryNotFound:
			rawupdates = []
		upd_commands = []
		errors = []
		for mykey, mystat, mycontent in rawupdates:
			# The per-file errors need their own name: reusing "errors"
			# here would rebind the accumulator and lose the errors of
			# every earlier file.
			commands, parse_errors = parse_updates(mycontent)
			upd_commands.extend(commands)
			errors.extend(parse_errors)
		return upd_commands, errors

	def check(self, onProgress=None):
		"""Return messages for pending moves, slot moves and stale metadata.

		Parse errors from the update files are included in the result.
		"""
		updates, errors = self._grab_global_updates(self._portdir)
		# Matching packages and moving them is relatively fast, so the
		# progress bar is updated in indeterminate mode.
		match = self._tree.dbapi.match
		aux_get = self._tree.dbapi.aux_get
		if onProgress:
			onProgress(0, 0)
		for update_cmd in updates:
			if update_cmd[0] == "move":
				origcp, newcp = update_cmd[1:]
				for cpv in match(origcp):
					errors.append("'%s' moved to '%s'" % (cpv, newcp))
			elif update_cmd[0] == "slotmove":
				pkg, origslot, newslot = update_cmd[1:]
				for cpv in match(pkg):
					slot = aux_get(cpv, ["SLOT"])[0]
					if slot == origslot:
						errors.append("'%s' slot moved from '%s' to '%s'" % \
							(cpv, origslot, newslot))
			if onProgress:
				onProgress(0, 0)

		# Searching for updates in all the metadata is relatively slow, so this
		# is where the progress bar comes out of indeterminate mode.
		cpv_all = self._tree.dbapi.cpv_all()
		cpv_all.sort()
		maxval = len(cpv_all)
		update_keys = self._update_keys
		from portage.update import update_dbentries
		if onProgress:
			onProgress(maxval, 0)
		for i, cpv in enumerate(cpv_all):
			metadata = dict(zip(update_keys, aux_get(cpv, update_keys)))
			metadata_updates = update_dbentries(updates, metadata)
			if metadata_updates:
				errors.append("'%s' has outdated metadata" % cpv)
			if onProgress:
				onProgress(maxval, i+1)
		return errors

	def fix(self, onProgress=None):
		"""Apply every move and slot move, then update the tree's metadata.

		Returns the parse errors from the update files.
		"""
		updates, errors = self._grab_global_updates(self._portdir)
		# Matching packages and moving them is relatively fast, so the
		# progress bar is updated in indeterminate mode.
		move = self._tree.dbapi.move_ent
		slotmove = self._tree.dbapi.move_slot_ent
		if onProgress:
			onProgress(0, 0)
		for update_cmd in updates:
			if update_cmd[0] == "move":
				move(update_cmd)
			elif update_cmd[0] == "slotmove":
				slotmove(update_cmd)
			if onProgress:
				onProgress(0, 0)

		# Searching for updates in all the metadata is relatively slow, so this
		# is where the progress bar comes out of indeterminate mode.
		self._tree.dbapi.update_ents(updates, onProgress=onProgress)
		return errors

class MoveInstalled(MoveHandler):
	"""MoveHandler bound to the "vartree" of the configured ROOT."""

	@staticmethod
	def name():
		return "moveinst"

	def __init__(self):
		root = portage.settings["ROOT"]
		vartree = portage.db[root]["vartree"]
		MoveHandler.__init__(self, vartree)

class MoveBinary(MoveHandler):
	"""MoveHandler bound to the "bintree" of the configured ROOT."""

	@staticmethod
	def name():
		return "movebin"

	def __init__(self):
		root = portage.settings["ROOT"]
		bintree = portage.db[root]["bintree"]
		MoveHandler.__init__(self, bintree)

class VdbKeyHandler(object):
	"""Find installed packages lacking key files and rebuild the files.

	A package is recorded in self.missing when its database directory
	contains none of the files named in self.keys. fix() recovers the
	values from the package's environment.bz2.
	"""

	def name():
		return "vdbkeys"
	name = staticmethod(name)

	def __init__(self):
		# List the packages of the same ROOT the directories below are
		# built from, as the other handlers do.
		myroot = portage.settings["ROOT"]
		self.list = portage.db[myroot]["vartree"].dbapi.cpv_all()
		self.missing = []
		self.keys = ["HOMEPAGE", "SRC_URI", "KEYWORDS", "DESCRIPTION"]

		for p in self.list:
			mydir = self._pkg_dir(p)
			ismissing = True
			for k in self.keys:
				if os.path.exists(mydir+k):
					ismissing = False
					break
			if ismissing:
				self.missing.append(p)

	def _pkg_dir(self, p):
		"""Return the database directory of package p, with trailing separator."""
		return os.path.join(os.sep, portage.settings["ROOT"],
			portage.const.VDB_PATH, p)+os.sep

	def check(self, onProgress=None):
		"""Return one message per package found without key files.

		onProgress, when given, is called as onProgress(maxval, curval)
		once before the loop and once after each package.
		"""
		maxval = len(self.missing)
		if onProgress:
			onProgress(maxval, 0)
		errors = []
		for i, x in enumerate(self.missing):
			errors.append("%s has missing keys" % x)
			if onProgress:
				onProgress(maxval, i+1)
		return errors

	def fix(self, onProgress=None):
		"""Recreate the key files of every package in self.missing.

		Returns a list of error messages (empty when nothing went wrong).
		onProgress is called the same way as in check().
		"""
		errors = []
		maxval = len(self.missing)
		if onProgress:
			onProgress(maxval, 0)
		for i, p in enumerate(self.missing):
			errors.extend(self._restore_keys(p))
			if onProgress:
				onProgress(maxval, i+1)
		return errors

	def _restore_keys(self, p):
		"""Write the key files of package p; return the errors encountered."""
		mydir = self._pkg_dir(p)
		envfile = mydir+"environment.bz2"
		if not os.access(envfile, os.R_OK):
			return ["Can't access %s" % envfile]
		if not os.access(mydir, os.W_OK):
			return ["Can't create files in %s" % mydir]

		errors = []
		# NOTE(review): the command line is assembled as a shell string
		# from the package directory path.
		env = os.popen("bzip2 -dcq "+envfile, "r")
		try:
			envlines = env.read().split("\n")
		finally:
			env.close()
		for k in self.keys:
			s = [l for l in envlines if l.startswith(k+"=")]
			if len(s) > 1:
				errors.append("multiple matches for %s found in %senvironment.bz2" % (k, mydir))
				continue
			if not s:
				continue
			# Take the text after the first "=", strip a leading "$" and
			# the surrounding quotes, turn escaped \n, \r and \t into
			# spaces and collapse runs of whitespace.
			value = s[0].split("=",1)[1]
			value = value.lstrip("$").strip("\'\"")
			value = re.sub("(\\\\[nrt])+", " ", value)
			value = " ".join(value.split()).strip()
			if value == "":
				continue
			try:
				# mydir already ends with a separator.
				keyfile = open(mydir+k, "w")
				try:
					keyfile.write(value+"\n")
				finally:
					keyfile.close()
			except (IOError, OSError):
				# Fetched this way so the block does not depend on either
				# spelling of the "except ..., e" / "except ... as e" syntax.
				e = sys.exc_info()[1]
				errors.append("Could not write %s, reason was: %s" % (mydir+k, e))
		return errors

class ProgressHandler(object):
	"""Record progress values and throttle how often they are displayed.

	display() must be supplied by a subclass or assigned on the instance;
	the default raises NotImplementedError.
	"""

	def __init__(self):
		self.curval = self.maxval = self.last_update = 0
		# Minimum number of seconds between two display() calls.
		self.min_display_latency = 0.2

	def onProgress(self, maxval, curval):
		"""Store the new values; redisplay unless the last one was too recent."""
		self.maxval, self.curval = maxval, curval
		now = time.time()
		if now - self.last_update < self.min_display_latency:
			return
		self.last_update = now
		self.display()

	def display(self):
		raise NotImplementedError(self)

class CleanResume(object):
	"""Report on, or delete, the resume lists stored in portage.mtimedb."""

	@staticmethod
	def name():
		return "cleanresume"

	def check(self, onProgress=None):
		"""Return one message per resume list present in the mtimedb.

		Progress is reported after every key, including keys that are
		skipped or that raise.
		"""
		mtimedb = portage.mtimedb
		resume_keys = ("resume", "resume_backup")
		total = len(resume_keys)
		messages = []
		if onProgress:
			onProgress(total, 0)
		for position, key in enumerate(resume_keys):
			try:
				entry = mtimedb.get(key)
				if entry is None:
					continue
				# Anything that is not a dict carrying a sized
				# "mergelist" is reported as unrecognized.
				mergelist = None
				if isinstance(entry, dict):
					mergelist = entry.get("mergelist")
				if mergelist is None or not hasattr(mergelist, "__len__"):
					messages.append("unrecognized resume list: '%s'" % key)
				else:
					messages.append("resume list '%s' contains %d packages" % \
						(key, len(mergelist)))
			finally:
				if onProgress:
					onProgress(total, position + 1)
		return messages

	def fix(self, onProgress=None):
		"""Remove both resume lists; commit the mtimedb if any was removed."""
		mtimedb = portage.mtimedb
		resume_keys = ("resume", "resume_backup")
		total = len(resume_keys)
		removed = 0
		if onProgress:
			onProgress(total, 0)
		for position, key in enumerate(resume_keys):
			try:
				if mtimedb.pop(key, None) is not None:
					removed += 1
			finally:
				if onProgress:
					onProgress(total, position + 1)
		if removed:
			mtimedb.commit()

def emaint_main(myargv):

	# Similar to emerge, emaint needs a default umask so that created
	# files (such as the world file) have sane permissions.
	os.umask(022)

	# TODO: Create a system that allows external modules to be added without
	#       the need for hard coding.
	modules = {
		"world" : WorldHandler,
		"binhost":BinhostHandler,
		"moveinst":MoveInstalled,
		"movebin":MoveBinary,
		"cleanresume":CleanResume
	}

	module_names = modules.keys()
	module_names.sort()
	module_names.insert(0, "all")

	def exclusive(option, *args, **kw):
		var = kw.get("var", None)
		if var is None:
			raise ValueError("var not specified to exclusive()")
		if getattr(parser, var, ""):
			raise OptionValueError("%s and %s are exclusive options" % (getattr(parser, var), option))
		setattr(parser, var, str(option))


	usage = "usage: emaint [options] " + " | ".join(module_names)

	usage+= "\n\nCurrently emaint can only check and fix problems with one's world\n"
	usage+= "file.  Future versions will integrate other portage check-and-fix\n"
	usage+= "tools and provide a single interface to system health checks."


	parser = OptionParser(usage=usage, version=portage.VERSION)
	parser.add_option("-c", "--check", help="check for problems",
		action="callback", callback=exclusive, callback_kwargs={"var":"action"})
	parser.add_option("-f", "--fix", help="attempt to fix problems",
		action="callback", callback=exclusive, callback_kwargs={"var":"action"})
	parser.action = None


	(options, args) = parser.parse_args(args=myargv)
	if len(args) != 1:
		parser.error("Incorrect number of arguments")
	if args[0] not in module_names:
		parser.error("%s target is not a known target" % args[0])

	if parser.action:
		action = parser.action
	else:
		print "Defaulting to --check"
		action = "-c/--check"

	if args[0] == "all":
		tasks = modules.values()
	else:
		tasks = [modules[args[0]]]


	if action == "-c/--check":
		status = "Checking %s for problems"
		func = "check"
	else:
		status = "Attempting to fix %s"
		func = "fix"

	isatty = sys.stdout.isatty()
	for task in tasks:
		print status % task.name()
		inst = task()
		onProgress = None
		if isatty:
			progressBar = portage.output.TermProgressBar()
			progressHandler = ProgressHandler()
			onProgress = progressHandler.onProgress
			def display():
				progressBar.set(progressHandler.curval, progressHandler.maxval)
			progressHandler.display = display
			def sigwinch_handler(signum, frame):
				lines, progressBar.term_columns = \
					portage.output.get_term_size()
			signal.signal(signal.SIGWINCH, sigwinch_handler)
		result = getattr(inst, func)(onProgress=onProgress)
		if isatty:
			# make sure the final progress is displayed
			progressHandler.display()
			print
			signal.signal(signal.SIGWINCH, signal.SIG_DFL)
		if result:
			print
			print "\n".join(result)
			print "\n"

	print "Finished"

if __name__ == "__main__":
	emaint_main(sys.argv[1:])
