view upreckon/problem.py @ 196:67088c1765b4

Regexps now work with test archives. Excuse me while I rewrite files.{File,regexp} almost from scratch...
author Oleg Oshmyan <chortos@inbox.lv>
date Mon, 15 Aug 2011 19:52:58 +0300
parents a76cdc26ba9d
children 166a23999bf7
line wrap: on
line source

# Copyright (c) 2010-2011 Chortos-2 <chortos@inbox.lv>

from __future__ import division, with_statement

from .compat import *
from .exceptions import *
from . import config, files, testcases
from __main__ import options

import os, re, sys

try:
	from collections import deque
except ImportError:
	# Fallback when collections.deque is unavailable: the context stack in
	# Problem.test only needs append(), pop() and indexing, which list has too
	deque = list

try:
	import signal
except ImportError:
	signalnames = ()
else:
	def _collect_signal_names():
		"""Map every signal number available on this platform to its name.
		
		When several SIG* names share one number, a name from the UNIX
		standards (listed in `unixnames`) wins over any other alias.
		The work is done inside a function so that the loop variables do not
		leak into the module namespace.
		"""
		unixnames = frozenset(('HUP', 'INT', 'QUIT', 'ILL', 'ABRT', 'FPE', 'KILL', 'SEGV', 'PIPE', 'ALRM', 'TERM', 'USR1', 'USR2', 'CHLD', 'CONT', 'STOP', 'TSTP', 'TTIN', 'TTOU', 'BUS', 'POLL', 'PROF', 'SYS', 'TRAP', 'URG', 'VTALRM', 'XCPU', 'XFSZ'))
		names = {}
		for attr in dir(signal):
			if re.match('SIG[A-Z]+$', attr):
				value = getattr(signal, attr)
				# Keep the first name seen for a number unless a standard
				# UNIX name for the same number comes along later
				if isinstance(value, int) and (value not in names or attr[3:] in unixnames):
					names[value] = attr
		return names
	signalnames = _collect_signal_names()
	del _collect_signal_names

# Public names exported by `from .problem import *`
__all__ = ('Problem', 'TestContext', 'test_context_end', 'TestGroup')


def strerror(e):
	"""Format an exception's message as a parenthesised suffix.
	
	Prefers the exception's `strerror` attribute and falls back to str(e).
	The first character is lower-cased and the result is prefixed with a
	space, e.g. ' (no such file)'. Returns '' when there is no message.
	"""
	message = getattr(e, 'strerror', None) or str(e)
	if not message:
		return ''
	return ' (%s%s)' % (message[0].lower(), message[1:])


class Cache(object):
	"""Attribute-style access to a dict.
	
	The given dict becomes the instance namespace itself (it is not copied),
	so attribute writes show up in the dict and vice versa.
	"""
	def __init__(self, mydict):
		setattr(self, '__dict__', mydict)


class TestContext(object):
	"""Base class for markers that open a nested scoring context.
	
	Problem.test pushes every TestContext it receives from the test case
	generator onto its context stack; test_context_end pops it again.
	"""
	__slots__ = ()

# Sentinel yielded by the test case generator to close the innermost context
test_context_end = object()

class TestGroup(TestContext):
	"""A group of test cases that is scored all-or-nothing.
	
	If any test case in the group fails, the group's real score is reset to 0
	when the group ends. If `points` is given and differs from the summed
	maximum score of the cases, the result is rescaled so the group is worth
	exactly `points`.
	"""
	__slots__ = 'points', 'case', 'log', 'correct', 'allcorrect', 'real', 'max', 'ntotal', 'nvalued', 'ncorrect', 'ncorrectvalued'
	
	def __init__(self, points=None):
		self.points = points
		# Accumulated real and maximum score
		self.real = 0
		self.max = 0
		# Counters: cases seen, cases with non-zero points, cases passed,
		# and passed cases with non-zero points
		self.ntotal = 0
		self.nvalued = 0
		self.ncorrect = 0
		self.ncorrectvalued = 0
		self.allcorrect = True
		# (case, passed) pairs in the order the cases finished
		self.log = []
	
	def case_start(self, case):
		"""Begin recording `case`; it counts as failed until case_correct()."""
		self.case = case
		self.correct = False
		self.ntotal += 1
		self.nvalued += 1 if case.points else 0
	
	def case_correct(self):
		"""Mark the case currently being recorded as passed."""
		self.correct = True
		self.ncorrect += 1
		self.ncorrectvalued += 1 if self.case.points else 0
	
	def case_end(self):
		"""Finish the current case and append (case, passed) to the log."""
		finished, passed = self.case, self.correct
		self.log.append((finished, passed))
		del self.case
		if not passed:
			self.allcorrect = False
	
	def score(self, real, max):
		"""Add a real score and its maximum to the group's running totals."""
		self.real += real
		self.max += max
	
	def end(self):
		"""Report the group total and return (score, maximum score, log)."""
		if not self.allcorrect:
			# All-or-nothing: a single failure forfeits the whole group
			self.real = 0
		rescale = self.points is not None and self.points != self.max
		if not rescale:
			outof, earned, note = self.max, self.real, ''
		else:
			outof = self.points
			earned = self.real * self.points / self.max if self.max else 0
			note = ' (%g/%g before weighting)' % (self.real, self.max)
		say('Group total: %d/%d tests, %g/%g points%s' % (self.ncorrect, self.ntotal, earned, outof, note))
		# stdout is not flushed here: the problem total or the next test case's
		# ID is printed (and flushed) right afterwards anyway
		return earned, outof, self.log

class DummyTestGroup(TestGroup):
	"""Group of sample (dummy) test cases: reported, but worth 0/0 points."""
	__slots__ = ()
	
	def end(self):
		summary = 'Sample total: %d/%d tests' % (self.ncorrect, self.ntotal)
		say(summary)
		return 0, 0, self.log


class Problem(object):
	"""A problem to be tested: its name, configuration, cache and test cases."""
	__slots__ = 'name', 'config', 'cache', 'testcases'
	
	def __init__(prob, name):
		"""Load the configuration and the test case generator of problem `name`.
		
		Raises TypeError if `name` is not a string.
		"""
		if not isinstance(name, basestring):
			# This shouldn't happen, of course
			raise TypeError('Problem() argument 1 must be string, not ' + type(name).__name__)
		prob.name = name
		prob.config = config.load_problem(name)
		# padoutput is the width used to align test labels in the output;
		# load_testcases() raises it before yielding its first item
		prob.cache = Cache({'padoutput': 0})
		prob.testcases = load_testcases(prob)
	
	# TODO
	def build(prob):
		raise NotImplementedError
	
	def test(prob):
		"""Run all test cases of the problem, reporting each verdict and the total.
		
		Scores are accumulated on a stack of TestGroup contexts. The bottom one
		collects the problem total. Every TestContext yielded by
		prob.testcases opens a nested context, which is closed by the next
		test_context_end marker.
		
		Returns a (weighted score, prob.config.taskweight) pair.
		"""
		# The last real test case that was started. The cleanup in `finally`
		# must only ever inspect a test case, never a TestContext (such as a
		# TestGroup) or the test_context_end sentinel, which have no
		# has_iofiles attribute and would raise AttributeError, masking
		# whatever exception is propagating.
		lastcase = None
		try:
			contexts = deque((TestGroup(),))
			for case in prob.testcases:
				if case is test_context_end:
					# Close the innermost context and replay its results into the
					# enclosing one so that its test counters stay complete
					real, groupmax, log = contexts.pop().end()
					for loggedcase, correct in log:
						contexts[-1].case_start(loggedcase)
						if correct:
							contexts[-1].case_correct()
						contexts[-1].case_end()
					contexts[-1].score(real, groupmax)
					continue
				elif isinstance(case, TestContext):
					contexts.append(case)
					continue
				lastcase = case
				contexts[-1].case_start(case)
				granted = 0
				label = str(case.id)
				if case.isdummy:
					label = 'sample ' + label
				say('%*s: ' % (prob.cache.padoutput, label), end='')
				sys.stdout.flush()
				try:
					if prob.config.kind != 'outonly':
						granted = case(lambda: (say('%7.3f%s s, ' % (case.time_stopped - case.time_started, case.time_limit_string), end=''), sys.stdout.flush()))
					else:
						granted = case(lambda: None)
				except TestCaseSkipped:
					verdict = 'skipped due to skimming mode'
				except CanceledByUser:
					verdict = 'canceled by the user'
				except WallTimeLimitExceeded:
					verdict = 'wall-clock time limit exceeded'
				except CPUTimeLimitExceeded:
					verdict = 'CPU time limit exceeded'
				except MemoryLimitExceeded:
					verdict = 'memory limit exceeded'
				except WrongAnswer:
					e = sys.exc_info()[1]
					if e.comment:
						verdict = 'wrong answer (%s)' % e.comment
					else:
						verdict = 'wrong answer'
				except NonZeroExitCode:
					e = sys.exc_info()[1]
					if e.exitcode < 0:
						if sys.platform == 'win32':
							verdict = 'terminated with error 0x%X' % (e.exitcode + 0x100000000)
						elif -e.exitcode in signalnames:
							verdict = 'terminated by signal %d (%s)' % (-e.exitcode, signalnames[-e.exitcode])
						else:
							verdict = 'terminated by signal %d' % -e.exitcode
					else:
						verdict = 'non-zero return code %d' % e.exitcode
				except CannotStartTestee:
					verdict = 'cannot launch the program to test%s' % strerror(sys.exc_info()[1].upstream)
				except CannotStartValidator:
					verdict = 'cannot launch the validator%s' % strerror(sys.exc_info()[1].upstream)
				except CannotReadOutputFile:
					verdict = 'cannot read the output file%s' % strerror(sys.exc_info()[1].upstream)
				except CannotReadInputFile:
					verdict = 'cannot read the input file%s' % strerror(sys.exc_info()[1].upstream)
				except CannotReadAnswerFile:
					verdict = 'cannot read the reference output file%s' % strerror(sys.exc_info()[1].upstream)
				except ExceptionWrapper:
					verdict = 'unspecified reason [this may be a bug in Upreckon]%s' % strerror(sys.exc_info()[1].upstream)
				except TestCaseNotPassed:
					verdict = 'unspecified reason [this may be a bug in Upreckon]%s' % strerror(sys.exc_info()[1])
				#except Exception:
				#	verdict = 'unknown error [this may be a bug in Upreckon]%s' % strerror(sys.exc_info()[1])
				else:
					# A test case may return either a bare score fraction or a
					# (fraction, comment) pair
					try:
						granted, comment = granted
					except TypeError:
						comment = ''
					else:
						if comment:
							comment = ' (%s)' % comment
					if granted >= 1:
						contexts[-1].case_correct()
						# Tell the generator that this case passed (so skimming
						# does not kick in); it answers with a bare yield that
						# this send() consumes
						prob.testcases.send(True)
						verdict = 'OK' + comment
					elif not granted:
						verdict = 'wrong answer' + comment
					else:
						verdict = 'partly correct' + comment
					granted *= case.points
				say('%g/%g, %s' % (granted, case.points, verdict))
				contexts[-1].case_end()
				contexts[-1].score(granted, case.points)
			weighted = contexts[0].real * prob.config.taskweight / contexts[0].max if contexts[0].max else 0
			before_weighting = valued = ''
			if prob.config.taskweight != contexts[0].max:
				before_weighting = ' (%g/%g before weighting)' % (contexts[0].real, contexts[0].max)
			if contexts[0].nvalued != contexts[0].ntotal:
				valued = ' (%d/%d valued)' % (contexts[0].ncorrectvalued, contexts[0].nvalued)
			say('Problem total: %d/%d tests%s, %g/%g points%s' % (contexts[0].ncorrect, contexts[0].ntotal, valued, weighted, prob.config.taskweight, before_weighting))
			sys.stdout.flush()
			return weighted, prob.config.taskweight
		finally:
			# Best-effort removal of the I/O files left by the last test case.
			# Errors are swallowed on purpose: raising here would replace any
			# exception that is already propagating.
			if options.erase and lastcase and lastcase.has_iofiles:
				for var in 'in', 'out':
					name = getattr(prob.config, var + 'name')
					if name:
						try:
							os.remove(name)
						except Exception:
							pass
				if lastcase.has_ansfile:
					if prob.config.ansname:
						try:
							os.remove(prob.config.ansname)
						except Exception:
							pass


def load_testcases(prob, _types={'batch'  : testcases.BatchTestCase,
                                 'outonly': testcases.OutputOnlyTestCase}):
	"""Generate the test cases of `prob`, normalizing its configuration first.
	
	Before the first item is yielded, this rewrites parts of prob.config in
	place (dummies, tests, and possibly usegroups and groupweight). It also
	raises prob.cache.padoutput to at least the width of every test label.
	As with any generator, none of this runs until the first item is
	requested.
	
	Items yielded are:
	- test case objects, built from `_types` according to prob.config.kind;
	- when prob.config.usegroups is true, a DummyTestGroup/TestGroup before
	  each group and test_context_end after it.
	
	Protocol with the consumer: after a test case is yielded, sending a true
	value (meaning the case passed) makes the generator yield a bare None in
	reply. Plain iteration resumes it with None instead. In group mode with
	options.skim set, that turns the remaining cases of the current group
	into testcases.SkippedTestCase instances.
	
	`_types` is a lookup table from config kind to test case class, bound
	once at definition time; presumably callers never pass it.
	"""
	# We will need to iterate over these configuration variables twice
	try:
		len(prob.config.dummies)
	except Exception:
		prob.config.dummies = tuple(prob.config.dummies)
	try:
		len(prob.config.tests)
	except Exception:
		prob.config.tests = tuple(prob.config.tests)
	
	# In regex mode, dummies/tests are patterns for the test ID; expand them
	# into concrete ID lists by matching file paths. `tests` is first turned
	# into a (pattern, regex group number) pair. The group number is None
	# without grouping, and defaults to 2 for a bare pattern with usegroups.
	if prob.config.match == 're':
		if not prob.config.usegroups:
			prob.config.tests = prob.config.tests, None
		elif isinstance(prob.config.tests, basestring):
			prob.config.tests = prob.config.tests, 2
		# Build the path pattern from the (escaped) input filename template:
		# the first '$' becomes the ID pattern, captured as group 1; later
		# '$'s become backreferences to it. Without any '$', the ID pattern
		# is simply appended to the template.
		parts = tuple(map(re.escape, prob.config.dummyinname.split('$')))
		probname = re.escape(prob.name) + '/' if prob.name != os.curdir else ''
		path = '%s%s(%s)' % (probname, parts[0], prob.config.dummies)
		path += r'\1'.join(parts[1:])
		prob.config.dummies = regexp(path, None)
		parts = tuple(map(re.escape, prob.config.testcaseinname.split('$')))
		path = '%s%s(%s)' % (probname, parts[0], prob.config.tests[0])
		path += r'\1'.join(parts[1:])
		prob.config.tests = regexp(path, prob.config.tests[1])
	
	# Legacy configs may mix single test IDs and groups in `tests`. Switch to
	# group mode as soon as any entry has more than one element, and then
	# wrap single IDs into one-element groups.
	# NOTE(review): strings are sized, so a multi-character string ID such
	# as '10' counts as a group of its characters here; confirm whether
	# legacy configs (or regex-mode ID lists) can contain string IDs.
	if options.legacy:
		prob.config.usegroups = False
		newtests = []
		for name in prob.config.tests:
			# Same here; we'll need to iterate over them twice
			try:
				l = len(name)
			except Exception:
				try:
					name = tuple(name)
				except TypeError:
					name = (name,)
				l = len(name)
			if l > 1:
				prob.config.usegroups = True
			newtests.append(name)
		if prob.config.usegroups:
			prob.config.tests = newtests
		del newtests
	
	# getpoints(i, j[, k]) returns the points for test ID i at position j
	# (within group k in group mode). A dict pointmap is keyed by test ID, with
	# key None as the fallback. A sequence pointmap is indexed by position.
	# Missing entries are worth `maxexitcode or 1`.
	# Even if they have duplicate test identifiers, we must honour sequence pointmaps
	if isinstance(prob.config.pointmap, dict):
		def getpoints(i, j, k=None):
			try:
				return prob.config.pointmap[i]
			except KeyError:
				try:
					return prob.config.pointmap[None]
				except KeyError:
					return prob.config.maxexitcode or 1
	elif prob.config.usegroups:
		def getpoints(i, j, k):
			try:
				return prob.config.pointmap[k][j]
			except LookupError:
				return prob.config.maxexitcode or 1
	else:
		def getpoints(i, j):
			try:
				return prob.config.pointmap[j]
			except LookupError:
				return prob.config.maxexitcode or 1
	
	# First get prob.cache.padoutput right,
	# then yield the actual test cases
	for i in prob.config.dummies:
		s = 'sample ' + str(i).zfill(prob.config.paddummies)
		prob.cache.padoutput = max(prob.cache.padoutput, len(s))
	if prob.config.usegroups:
		# groupweight may be given as a sequence; index it by group number
		if not isinstance(prob.config.groupweight, dict):
			prob.config.groupweight = dict(enumerate(prob.config.groupweight))
		for group in prob.config.tests:
			for i in group:
				s = str(i).zfill(prob.config.padtests)
				prob.cache.padoutput = max(prob.cache.padoutput, len(s))
		# Samples form their own group, worth 0 points
		if prob.config.dummies:
			yield DummyTestGroup()
			for i in prob.config.dummies:
				s = str(i).zfill(prob.config.paddummies)
				if (yield _types[prob.config.kind](prob, s, True, 0)):
					yield
			yield test_context_end
		for k, group in enumerate(prob.config.tests):
			# Empty groups are skipped, but k still counts them so that
			# groupweight/pointmap indices stay aligned with the config
			if not group:
				continue
			yield TestGroup(prob.config.groupweight.get(k, prob.config.groupweight.get(None)))
			case_type = _types[prob.config.kind]
			for j, i in enumerate(group):
				s = str(i).zfill(prob.config.padtests)
				# Resumed with None (not sent True) means the case did not pass
				if not (yield case_type(prob, s, False, getpoints(i, j, k))):
					# A group is all-or-nothing, so in skim mode the rest of
					# it need not really be run
					if options.skim:
						case_type = testcases.SkippedTestCase
				else:
					yield
			yield test_context_end
	else:
		for i in prob.config.tests:
			s = str(i).zfill(prob.config.padtests)
			prob.cache.padoutput = max(prob.cache.padoutput, len(s))
		for i in prob.config.dummies:
			s = str(i).zfill(prob.config.paddummies)
			if (yield _types[prob.config.kind](prob, s, True, 0)):
				yield
		for j, i in enumerate(prob.config.tests):
			s = str(i).zfill(prob.config.padtests)
			if (yield _types[prob.config.kind](prob, s, False, getpoints(i, j))):
				yield

def regexp(pattern, group):
	"""Collect test identifiers from the files whose virtual paths match `pattern`.
	
	Regex group 1 of `pattern` must capture the test identifier.
	Without `group`, return a naturally sorted list of identifiers.
	Otherwise, bucket the identifiers by the value of regex group `group`
	and return the buckets, each naturally sorted, ordered naturally by that
	value.
	"""
	matcher = re.compile(pattern, re.UNICODE)
	found = files.regexp(pattern)
	if group:
		buckets = {}
		for f in found:
			m = matcher.match(f.virtual_path)
			buckets.setdefault(m.group(group), []).append(m.group(1))
		buckets = dict((g, natsorted(members)) for g, members in buckets.items())
		return [buckets[g] for g in natsorted(keys(buckets))]
	return natsorted([matcher.match(f.virtual_path).group(1) for f in found])

def natsorted(l):
	"""Return the strings in `l` sorted in natural order.
	
	Each string is split into runs of decimal digits and other text. Digit
	runs compare numerically, so 'test2' sorts before 'test10'.
	"""
	def natural_key(s):
		# re.split() with a capturing group always places the captured digit
		# runs at the odd indices, so convert by position. Testing chunks with
		# str.isdigit() is wrong: it also accepts characters such as '²' that
		# \d does not match and int() rejects with ValueError.
		return [int(t) if i % 2 else t for i, t in enumerate(re.split(r'(\d+)', s))]
	return sorted(l, key=natural_key)