view upreckon/problem.py @ 205:166a23999bf7

Added confvar okexitcodemask; changed the validator protocol: Callable validators now return three-tuples (number granted, bool correct, str comment) instead of two-tuples (number granted, str comment). They are still allowed to return single numbers. Callable validators must now explicitly raise upreckon.exceptions.WrongAnswer if they want the verdict to be Wrong Answer rather than Partly Correct. okexitcodemask specifies a bitmask ANDed with the exit code of the external validator to get a boolean flag showing whether the answer is to be marked as 'OK' rather than 'partly correct'. The bits covered by the bitmask are reset to zeroes before deriving the number of points granted from the resulting number.
author Oleg Oshmyan <chortos@inbox.lv>
date Wed, 17 Aug 2011 20:44:54 +0300
parents 67088c1765b4
children c03a8113685d
line wrap: on
line source

# Copyright (c) 2010-2011 Chortos-2 <chortos@inbox.lv>

from __future__ import division, with_statement

from .compat import *
from .exceptions import *
from . import config, files, testcases
from __main__ import options

import os, re, sys

try:
	from collections import deque
except ImportError:
	deque = list

try:
	import signal
except ImportError:
	# No signal support on this platform: use an empty container so that
	# 'number in signalnames' lookups below simply never match.
	signalnames = ()
else:
	# Construct a cache of all signal names available on the current
	# platform. Prefer names from the UNIX standards over other versions.
	# The cache maps signal numbers to names ('SIGHUP' etc.); several names
	# may alias one number (e.g. SIGABRT/SIGIOT), in which case a name whose
	# suffix appears in unixnames wins over one that does not.
	unixnames = frozenset(('HUP', 'INT', 'QUIT', 'ILL', 'ABRT', 'FPE', 'KILL', 'SEGV', 'PIPE', 'ALRM', 'TERM', 'USR1', 'USR2', 'CHLD', 'CONT', 'STOP', 'TSTP', 'TTIN', 'TTOU', 'BUS', 'POLL', 'PROF', 'SYS', 'TRAP', 'URG', 'VTALRM', 'XCPU', 'XFSZ'))
	signalnames = {}
	for name in dir(signal):
		if re.match('SIG[A-Z]+$', name):
			value = signal.__dict__[name]
			# Skip non-integer attributes (e.g. enum-like helpers) and keep
			# an existing entry unless the new name is the standard one.
			if isinstance(value, int) and (value not in signalnames or name[3:] in unixnames):
				signalnames[value] = name
	del unixnames

# The public API of this module.
__all__ = 'Problem', 'TestContext', 'test_context_end', 'TestGroup'


def strerror(e):
	"""Format an exception as ' (message)' with a lower-cased first letter.
	
	Prefers the exception's strerror attribute (as set on OSError and
	friends) and falls back to str(e); returns '' if no text is available.
	"""
	message = getattr(e, 'strerror', None) or str(e)
	if not message:
		return ''
	return ' (%s%s)' % (message[0].lower(), message[1:])


class Cache(object):
	"""Attribute-style view over a dict: Cache(d).key reads d['key']."""
	def __init__(self, mydict):
		# Adopt the mapping as the instance dict itself, so attribute
		# access and assignment go straight through to the original dict.
		self.__dict__ = mydict


class TestContext(object):
	"""Marker base class for objects that open a nested context (such as a
	test group) when they appear in the test case stream."""
	__slots__ = ()

# Sentinel yielded by load_testcases to close the current test context.
test_context_end = object()

class TestGroup(TestContext):
	"""A test context that aggregates the scores of its test cases.
	
	The group's raw score counts only if every case in it is correct;
	an explicit points total reweights the raw score proportionally.
	"""
	
	__slots__ = 'points', 'case', 'log', 'correct', 'allcorrect', 'real', 'max', 'ntotal', 'nvalued', 'ncorrect', 'ncorrectvalued'
	
	def __init__(self, points=None):
		self.points = points
		self.real = 0
		self.max = 0
		self.ntotal = 0
		self.nvalued = 0
		self.ncorrect = 0
		self.ncorrectvalued = 0
		self.allcorrect = True
		self.log = []
	
	def case_start(self, case):
		# Remember the running case; its verdict arrives via case_correct.
		self.case = case
		self.correct = False
		self.ntotal += 1
		if case.points:
			self.nvalued += 1
	
	def case_correct(self):
		self.correct = True
		self.ncorrect += 1
		if self.case.points:
			self.ncorrectvalued += 1
	
	def case_end(self):
		# Log the (case, verdict) pair so enclosing contexts can replay it.
		self.log.append((self.case, self.correct))
		del self.case
		if not self.correct:
			self.allcorrect = False
	
	def score(self, real, max):
		self.real += real
		self.max += max
	
	def end(self):
		"""Print the group summary and return (points, maximum, log)."""
		if not self.allcorrect:
			self.real = 0
		if self.points is not None and self.points != self.max:
			cap = self.points
			weighted = self.real * self.points / self.max if self.max else 0
			note = ' (%g/%g before weighting)' % (self.real, self.max)
		else:
			cap, weighted = self.max, self.real
			note = ''
		say('Group total: %d/%d tests, %g/%g points%s' % (self.ncorrect, self.ntotal, weighted, cap, note))
		# No real need to flush stdout, as it will anyway be flushed in a moment,
		# when either the problem total or the next test case's ID is printed
		return weighted, cap, self.log

class DummyTestGroup(TestGroup):
	"""A group of sample test cases: reported but never worth any points."""
	
	__slots__ = ()
	
	def end(self):
		# Samples carry no points, so the score contribution is always zero.
		summary = 'Sample total: %d/%d tests' % (self.ncorrect, self.ntotal)
		say(summary)
		return 0, 0, self.log


class Problem(object):
	"""A problem under test: its name, configuration and test case stream."""
	
	__slots__ = 'name', 'config', 'cache', 'testcases'
	
	def __init__(prob, name):
		"""Load the configuration and the test case stream of *name*."""
		if not isinstance(name, basestring):
			# This shouldn't happen, of course
			raise TypeError('Problem() argument 1 must be string, not ' + type(name).__name__)
		prob.name = name
		prob.config = config.load_problem(name)
		prob.cache = Cache({'padoutput': 0})
		prob.testcases = load_testcases(prob)
	
	# TODO
	def build(prob):
		raise NotImplementedError
	
	def test(prob):
		"""Run every test case, printing per-case verdicts and totals.
		
		Returns (points granted, maximum points) after task weighting.
		"""
		case = None
		try:
			# Stack of nested test contexts; contexts[0] collects the
			# problem-wide totals.
			contexts = deque((TestGroup(),))
			for case in prob.testcases:
				if case is test_context_end:
					# A group ended: fold its score into the enclosing
					# context and replay its per-case log there too.
					real, maxpts, log = contexts.pop().end()
					for case, correct in log:
						contexts[-1].case_start(case)
						if correct:
							contexts[-1].case_correct()
						contexts[-1].case_end()
					contexts[-1].score(real, maxpts)
					continue
				elif isinstance(case, TestContext):
					# A new group begins.
					contexts.append(case)
					continue
				contexts[-1].case_start(case)
				granted = 0
				label = str(case.id)
				if case.isdummy:
					label = 'sample ' + label
				say('%*s: ' % (prob.cache.padoutput, label), end='')
				sys.stdout.flush()
				try:
					if prob.config.kind != 'outonly':
						# Report the running time as soon as the case
						# finishes, before the verdict is known.
						granted = case(lambda: (say('%7.3f%s s, ' % (case.time_stopped - case.time_started, case.time_limit_string), end=''), sys.stdout.flush()))
					else:
						granted = case(lambda: None)
				except TestCaseSkipped:
					verdict = 'skipped due to skimming mode'
				except CanceledByUser:
					verdict = 'canceled by the user'
				except WallTimeLimitExceeded:
					verdict = 'wall-clock time limit exceeded'
				except CPUTimeLimitExceeded:
					verdict = 'CPU time limit exceeded'
				except MemoryLimitExceeded:
					verdict = 'memory limit exceeded'
				except WrongAnswer:
					e = sys.exc_info()[1]
					if e.comment:
						verdict = 'wrong answer (%s)' % e.comment
					else:
						verdict = 'wrong answer'
				except NonZeroExitCode:
					e = sys.exc_info()[1]
					if e.exitcode < 0:
						if sys.platform == 'win32':
							# Windows reports termination as a negative
							# 32-bit error code.
							verdict = 'terminated with error 0x%X' % (e.exitcode + 0x100000000)
						elif -e.exitcode in signalnames:
							verdict = 'terminated by signal %d (%s)' % (-e.exitcode, signalnames[-e.exitcode])
						else:
							verdict = 'terminated by signal %d' % -e.exitcode
					else:
						verdict = 'non-zero return code %d' % e.exitcode
				except CannotStartTestee:
					verdict = 'cannot launch the program to test%s' % strerror(sys.exc_info()[1].upstream)
				except CannotStartValidator:
					verdict = 'cannot launch the validator%s' % strerror(sys.exc_info()[1].upstream)
				except CannotReadOutputFile:
					verdict = 'cannot read the output file%s' % strerror(sys.exc_info()[1].upstream)
				except CannotReadInputFile:
					verdict = 'cannot read the input file%s' % strerror(sys.exc_info()[1].upstream)
				except CannotReadAnswerFile:
					verdict = 'cannot read the reference output file%s' % strerror(sys.exc_info()[1].upstream)
				except ExceptionWrapper:
					verdict = 'unspecified reason [this may be a bug in Upreckon]%s' % strerror(sys.exc_info()[1].upstream)
				except TestCaseNotPassed:
					verdict = 'unspecified reason [this may be a bug in Upreckon]%s' % strerror(sys.exc_info()[1])
				#except Exception:
				#	verdict = 'unknown error [this may be a bug in Upreckon]%s' % strerror(sys.exc_info()[1])
				else:
					try:
						# New-style validators return (points, correct, comment).
						granted, correct, comment = granted
					except TypeError:
						# Plain numbers are still allowed; full points mean correct.
						comment = ''
						correct = granted >= 1
					else:
						if comment:
							comment = ' (%s)' % comment
					if correct:
						contexts[-1].case_correct()
						# Tell the generator the case passed so that skimming
						# mode does not start skipping subsequent cases.
						prob.testcases.send(True)
						verdict = 'OK' + comment
					else:
						verdict = 'partly correct' + comment
					granted *= case.points
				say('%g/%g, %s' % (granted, case.points, verdict))
				contexts[-1].case_end()
				contexts[-1].score(granted, case.points)
			weighted = contexts[0].real * prob.config.taskweight / contexts[0].max if contexts[0].max else 0
			before_weighting = valued = ''
			if prob.config.taskweight != contexts[0].max:
				before_weighting = ' (%g/%g before weighting)' % (contexts[0].real, contexts[0].max)
			if contexts[0].nvalued != contexts[0].ntotal:
				valued = ' (%d/%d valued)' % (contexts[0].ncorrectvalued, contexts[0].nvalued)
			say('Problem total: %d/%d tests%s, %g/%g points%s' % (contexts[0].ncorrect, contexts[0].ntotal, valued, weighted, prob.config.taskweight, before_weighting))
			sys.stdout.flush()
			return weighted, prob.config.taskweight
		finally:
			# The last value bound to case may be the test_context_end
			# sentinel or a TestGroup rather than an actual test case
			# (in group mode the stream always ends with the sentinel),
			# so probe the attributes defensively instead of letting an
			# AttributeError out of the cleanup code.
			if options.erase and getattr(case, 'has_iofiles', False):
				for var in 'in', 'out':
					name = getattr(prob.config, var + 'name')
					if name:
						try:
							os.remove(name)
						except Exception:
							pass
				if getattr(case, 'has_ansfile', False):
					if prob.config.ansname:
						try:
							os.remove(prob.config.ansname)
						except Exception:
							pass


def load_testcases(prob, _types={'batch'  : testcases.BatchTestCase,
                                 'outonly': testcases.OutputOnlyTestCase}):
	"""Generate the test case stream of the problem *prob*.
	
	The stream consists of TestContext objects (group openers), test case
	objects, and test_context_end sentinels (group closers).  After
	receiving a test case, the consumer may send(True) back to report that
	the case passed; the bare yields below absorb that send so that the
	generator stays in step with the consumer.  Mutates prob.config in
	place to normalize the tests/dummies/pointmap settings.
	"""
	# We will need to iterate over these configuration variables twice
	try:
		len(prob.config.dummies)
	except Exception:
		prob.config.dummies = tuple(prob.config.dummies)
	try:
		len(prob.config.tests)
	except Exception:
		prob.config.tests = tuple(prob.config.tests)
	
	if prob.config.match == 're':
		# Regexp matching mode: tests/dummies hold regex fragments that are
		# spliced into the (escaped) input file name patterns; normalize
		# prob.config.tests into a (pattern, group index) pair first.
		if not prob.config.usegroups:
			prob.config.tests = prob.config.tests, None
		elif isinstance(prob.config.tests, basestring):
			# NOTE(review): capture 2 is presumably the conventional group
			# identifier when only a pattern is given -- confirm.
			prob.config.tests = prob.config.tests, 2
		parts = tuple(map(re.escape, prob.config.dummyinname.split('$')))
		probname = re.escape(prob.name) + '/' if prob.name != os.curdir else ''
		path = '%s%s(%s)' % (probname, parts[0], prob.config.dummies)
		path += r'\1'.join(parts[1:])
		prob.config.dummies = regexp(path, None)
		parts = tuple(map(re.escape, prob.config.testcaseinname.split('$')))
		path = '%s%s(%s)' % (probname, parts[0], prob.config.tests[0])
		path += r'\1'.join(parts[1:])
		prob.config.tests = regexp(path, prob.config.tests[1])
	
	if options.legacy:
		# Legacy configurations spell groups as nested sequences of test
		# identifiers; detect whether any real grouping is present.
		prob.config.usegroups = False
		newtests = []
		for name in prob.config.tests:
			# Same here; we'll need to iterate over them twice
			try:
				l = len(name)
			except Exception:
				try:
					name = tuple(name)
				except TypeError:
					name = (name,)
				l = len(name)
			if l > 1:
				prob.config.usegroups = True
			newtests.append(name)
		if prob.config.usegroups:
			prob.config.tests = newtests
		del newtests
	
	# Even if they have duplicate test identifiers, we must honour sequence pointmaps
	if isinstance(prob.config.pointmap, dict):
		# Dict pointmaps are keyed by test identifier, with None as the
		# fallback key and maxexitcode (or 1) as the final default.
		def getpoints(i, j, k=None):
			try:
				return prob.config.pointmap[i]
			except KeyError:
				try:
					return prob.config.pointmap[None]
				except KeyError:
					return prob.config.maxexitcode or 1
	elif prob.config.usegroups:
		# Sequence pointmaps with groups are indexed [group][position].
		def getpoints(i, j, k):
			try:
				return prob.config.pointmap[k][j]
			except LookupError:
				return prob.config.maxexitcode or 1
	else:
		# Flat sequence pointmaps are indexed by position within the run.
		def getpoints(i, j):
			try:
				return prob.config.pointmap[j]
			except LookupError:
				return prob.config.maxexitcode or 1
	
	# First get prob.cache.padoutput right,
	# then yield the actual test cases
	for i in prob.config.dummies:
		s = 'sample ' + str(i).zfill(prob.config.paddummies)
		prob.cache.padoutput = max(prob.cache.padoutput, len(s))
	if prob.config.usegroups:
		if not isinstance(prob.config.groupweight, dict):
			prob.config.groupweight = dict(enumerate(prob.config.groupweight))
		for group in prob.config.tests:
			for i in group:
				s = str(i).zfill(prob.config.padtests)
				prob.cache.padoutput = max(prob.cache.padoutput, len(s))
		if prob.config.dummies:
			yield DummyTestGroup()
			for i in prob.config.dummies:
				s = str(i).zfill(prob.config.paddummies)
				if (yield _types[prob.config.kind](prob, s, True, 0)):
					# The consumer sent True; absorb the send.
					yield
			yield test_context_end
		for k, group in enumerate(prob.config.tests):
			if not group:
				continue
			yield TestGroup(prob.config.groupweight.get(k, prob.config.groupweight.get(None)))
			case_type = _types[prob.config.kind]
			for j, i in enumerate(group):
				s = str(i).zfill(prob.config.padtests)
				if not (yield case_type(prob, s, False, getpoints(i, j, k))):
					# No send arrived, so the case did not pass; in skimming
					# mode the rest of this group is emitted as skipped.
					if options.skim:
						case_type = testcases.SkippedTestCase
				else:
					# The consumer sent True; absorb the send.
					yield
			yield test_context_end
	else:
		for i in prob.config.tests:
			s = str(i).zfill(prob.config.padtests)
			prob.cache.padoutput = max(prob.cache.padoutput, len(s))
		for i in prob.config.dummies:
			s = str(i).zfill(prob.config.paddummies)
			if (yield _types[prob.config.kind](prob, s, True, 0)):
				yield
		for j, i in enumerate(prob.config.tests):
			s = str(i).zfill(prob.config.padtests)
			if (yield _types[prob.config.kind](prob, s, False, getpoints(i, j))):
				yield

def regexp(pattern, group):
	"""List the test identifiers matched by the regular expression *pattern*.
	
	With a false *group*, return a naturally sorted flat list of the first
	capture of every matching file path.  Otherwise *group* names the
	capture that identifies the test group, and the result is a list of
	per-group identifier lists, both levels naturally sorted.
	"""
	compiled = re.compile(pattern, re.UNICODE)
	if not group:
		found = [re.match(compiled, f.virtual_path).group(1)
		         for f in files.regexp(pattern)]
		return natsorted(found)
	else:
		grouped = {}
		for f in files.regexp(pattern):
			m = re.match(compiled, f.virtual_path)
			grouped.setdefault(m.group(group), []).append(m.group(1))
		return [natsorted(grouped[g]) for g in natsorted(keys(grouped))]

def natsorted(l):
	"""Sort strings in natural order, comparing runs of digits numerically
	(so 'test2' sorts before 'test10')."""
	# Raw string: '\d' is an invalid escape sequence in a plain literal.
	return sorted(l, key=lambda s: [int(t) if t.isdigit() else t for t in re.split(r'(\d+)', s)])