view upreckon/files.py @ 195:c2490e39fd70

Revamped the implementation of files.Archive subclasses. They now normalize and sanitize all paths and provide a listdir method. TarArchive also ignores all files that are not regular files or directories.
author Oleg Oshmyan <chortos@inbox.lv>
date Sun, 14 Aug 2011 01:02:10 +0300
parents 8c30a2c8a09e
children 67088c1765b4
line wrap: on
line source

# Copyright (c) 2010-2011 Chortos-2 <chortos@inbox.lv>

"""File access routines and classes with support for archives."""

from __future__ import division, with_statement

from .compat import *
import contextlib, itertools, os, posixpath, re, shutil, sys

# You don't need to know about anything else.
__all__ = 'File', 'regexp'

# In these two variables, use full stops no matter what os.extsep is;
# all full stops will be converted to os.extsep on the fly
archives = 'tests.tar', 'tests.zip', 'tests.tgz', 'tests.tar.gz', 'tests.tbz2', 'tests.tar.bz2'
formats = {}

class Archive(object):
	"""
	Abstract interface to an archive file and a factory for its
	concrete implementations.
	
	Calling Archive(path) directly picks the implementation registered
	in the module-level formats mapping for the file name extension
	found in path; subclasses are instantiated normally.
	"""
	__slots__ = ()
	
	if ABCMeta:
		__metaclass__ = ABCMeta
	
	def __new__(cls, path):
		"""
		Create a new instance of the archive class corresponding
		to the file name in the given path.
		
		Raises LookupError if no suffix of the file name extension
		is a key of the formats mapping.
		"""
		if cls is not Archive:
			return object.__new__(cls)
		# Do this by hand rather than through os.path.splitext
		# because we support multi-dotted file name extensions
		ext = path.partition(os.path.extsep)[2]
		while ext:
			if ext in formats:
				handler = formats[ext]
				if isinstance(handler, type) and issubclass(handler, Archive):
					# Only allocate the instance here: since it is
					# an instance of Archive, Python itself will call
					# its __init__ once we return it, and constructing
					# it in full here would run __init__ twice
					# (opening the archive file twice)
					return handler.__new__(handler, path)
				# Anything else registered in formats is treated
				# as a plain factory callable
				return handler(path)
			ext = ext.partition(os.path.extsep)[2]
		raise LookupError("unsupported archive file name extension in file name '%s'" % path)
	
	@abstractmethod
	def __init__(self, path):
		"""Open the archive file located at path."""
		raise NotImplementedError
	
	@abstractmethod
	def extract(self, name, target):
		"""Extract the archive member called name to the path target."""
		raise NotImplementedError
	
	@abstractmethod
	def open(self, name):
		"""Return a file-like object for reading the member called name."""
		raise NotImplementedError
	
	@abstractmethod
	def exists(self, name):
		"""Tell whether the archive has a file or directory called name."""
		raise NotImplementedError
	
	@abstractmethod
	def listdir(self, name):
		"""Return the entries of the archive directory called name."""
		raise NotImplementedError

try:
	import tarfile
except ImportError:
	TarArchive = None
else:
	class TarArchive(Archive):
		"""
		Archive implementation backed by the tarfile module.
		
		Only regular files and directories are visible; member names
		are normalized and stripped of leading slashes and leading
		'../' components.
		"""
		__slots__ = '_tarfile', '_files', '_dirs', '_names'
		
		def __init__(self, path):
			"""Open the tar archive at path and index its members."""
			self._tarfile = tarfile.open(path)
			files, dirs = {}, set()
			for member in self._tarfile.getmembers():
				# Sanitize the name: no absolute paths and no escaping
				# from the archive root through leading '..' components
				cutname = posixpath.normpath(member.name).lstrip('/')
				while cutname.startswith('../'):
					cutname = cutname[3:]
				if cutname in ('.', '..'):
					continue
				if member.isfile():
					files[cutname] = member
					cutname = posixpath.dirname(cutname)
				elif not member.isdir():
					# Neither a regular file nor a directory
					continue
				# Register the directory itself (or the parent directory
				# of a file) together with all of its ancestors
				while cutname:
					dirs.add(cutname)
					cutname = posixpath.dirname(cutname)
			self._files = files
			self._dirs = frozenset(dirs)
			self._names = self._dirs | frozenset(files)
		
		def extract(self, name, target):
			"""
			Extract the regular file called name to the path target.
			
			Raises KeyError if there is no such regular file.
			"""
			import copy
			# Rename a copy rather than the indexed TarInfo, which is
			# shared with the underlying TarFile and must keep the name
			# it has inside the archive
			member = copy.copy(self._files[posixpath.normpath(name)])
			member.name = target
			self._tarfile.extract(member)
		
		def open(self, name):
			"""
			Return a file-like object for reading the regular file
			called name; raises KeyError if there is no such file.
			"""
			name = posixpath.normpath(name)
			return self._tarfile.extractfile(self._files[name])
		
		def exists(self, name):
			"""Tell whether name is an indexed file or directory."""
			return posixpath.normpath(name) in self._names
		
		def listdir(self, name):
			"""
			Return a list of the entries located directly inside
			the directory called name.
			
			The entries are given as paths relative to the archive root
			(that is, including the directory prefix), in no particular
			order.  Raises KeyError if name is not an indexed directory;
			the archive root itself is never indexed as a directory.
			"""
			normname = posixpath.normpath(name)
			if normname not in self._dirs:
				raise KeyError('No such directory: %r' % name)
			normname += '/'
			len_normname = len(normname)
			# Direct children have the prefix and no further slashes
			return [fname for fname in self._names
			              if fname.startswith(normname) and
			                 fname.find('/', len_normname) == -1]
		
		def __enter__(self):
			"""Enter the underlying TarFile context if it has one."""
			if hasattr(self._tarfile, '__enter__'):
				self._tarfile.__enter__()
			return self
		
		def __exit__(self, exc_type, exc_value, traceback):
			"""
			Close the archive, delegating to TarFile.__exit__
			where it is available.
			"""
			if hasattr(self._tarfile, '__exit__'):
				return self._tarfile.__exit__(exc_type, exc_value, traceback)
			elif exc_type is None:
				self._tarfile.close()
			else:
				# This code was shamelessly copied from tarfile.py of Python 2.7
				if not self._tarfile._extfileobj:
					self._tarfile.fileobj.close()
				self._tarfile.closed = True
	
	formats['tar'] = formats['tgz'] = formats['tar.gz'] = formats['tbz2'] = formats['tar.bz2'] = TarArchive

try:
	import zipfile
except ImportError:
	ZipArchive = None
else:
	class ZipArchive(Archive):
		"""
		Archive implementation backed by the zipfile module.
		
		Member names are normalized and stripped of leading slashes
		and leading '../' components.
		"""
		__slots__ = '_zipfile', '_files', '_dirs', '_names'
		
		def __init__(self, path):
			"""Open the zip archive at path and index its members."""
			self._zipfile = zipfile.ZipFile(path)
			files, dirs = {}, set()
			for member in self._zipfile.infolist():
				# Sanitize the name: no absolute paths and no escaping
				# from the archive root through leading '..' components
				cutname = posixpath.normpath(member.filename).lstrip('/')
				while cutname.startswith('../'):
					cutname = cutname[3:]
				if cutname in ('.', '..'):
					continue
				# Names ending in a slash are treated as directories
				if not member.filename.endswith('/'):
					files[cutname] = member
					cutname = posixpath.dirname(cutname)
				# Register the directory itself (or the parent directory
				# of a file) together with all of its ancestors
				while cutname:
					dirs.add(cutname)
					cutname = posixpath.dirname(cutname)
			self._files = files
			self._dirs = frozenset(dirs)
			self._names = self._dirs | frozenset(files)
		
		def extract(self, name, target):
			"""
			Extract the file called name to the path target.
			
			Raises KeyError if there is no such file.
			"""
			import copy
			# Rename a copy rather than the indexed ZipInfo, which is
			# shared with the underlying ZipFile and must keep the name
			# it has inside the archive
			member = copy.copy(self._files[posixpath.normpath(name)])
			# FIXME: 2.5 lacks ZipFile.extract
			if os.path.isabs(target):
				# To my knowledge, this is as portable as it gets
				# (extract relative to the root directory of the drive)
				path = os.path.join(os.path.splitdrive(target)[0], os.path.sep)
				member.filename = os.path.relpath(target, path)
				self._zipfile.extract(member, path)
			else:
				member.filename = os.path.relpath(target)
				self._zipfile.extract(member)
		
		def open(self, name):
			"""
			Return a file-like object for reading the file called name;
			raises KeyError if there is no such file.
			"""
			name = posixpath.normpath(name)
			# FIXME: 2.5 lacks ZipFile.open
			return self._zipfile.open(self._files[name])
		
		def exists(self, name):
			"""Tell whether name is an indexed file or directory."""
			return posixpath.normpath(name) in self._names
		
		def listdir(self, name):
			"""
			Return a list of the entries located directly inside
			the directory called name.
			
			The entries are given as paths relative to the archive root
			(that is, including the directory prefix), in no particular
			order.  Raises KeyError if name is not an indexed directory;
			the archive root itself is never indexed as a directory.
			"""
			normname = posixpath.normpath(name)
			if normname not in self._dirs:
				raise KeyError('No such directory: %r' % name)
			normname += '/'
			len_normname = len(normname)
			# Direct children have the prefix and no further slashes
			return [fname for fname in self._names
			              if fname.startswith(normname) and
			                 fname.find('/', len_normname) == -1]
		
		def __enter__(self):
			"""Enter the underlying ZipFile context if it has one."""
			if hasattr(self._zipfile, '__enter__'):
				self._zipfile.__enter__()
			return self
		
		def __exit__(self, exc_type, exc_value, traceback):
			"""
			Close the archive, delegating to ZipFile.__exit__
			where it is available.
			"""
			if hasattr(self._zipfile, '__exit__'):
				return self._zipfile.__exit__(exc_type, exc_value, traceback)
			else:
				return self._zipfile.close()
	
	formats['zip'] = ZipArchive

# Remove unsupported archive formats and replace full stops
# with the platform-dependent file name extension separator
def issupported(filename, formats=formats):
	"""
	Tell whether some suffix of filename that follows a full stop
	(starting with everything after the first one) is a key of formats.
	"""
	pieces = filename.split('.')
	for start in range(1, len(pieces)):
		suffix = '.'.join(pieces[start:])
		# An empty suffix (file name ending in a full stop) never counts
		if suffix and suffix in formats:
			return True
	return False
# Keep only the archive names whose format is supported, spelled
# with the platform's extension separator
archives = [filename.replace('.', os.path.extsep) for filename in archives if issupported(filename)]
# Re-key the format table with the platform's extension separator too
formats = dict((ext.replace('.', os.path.extsep), cls) for ext, cls in items(formats))

# Archive objects opened so far, keyed by the path they were opened with
open_archives = {}

def open_archive(path):
	"""
	Return the Archive object for path, opening the archive
	only the first time a given path is requested.
	"""
	if path not in open_archives:
		open_archives[path] = Archive(path)
	return open_archives[path]

class File(object):
	"""
	A file located by its virtual path, either on the file system
	or inside one of the test archives.
	
	Attributes:
	virtual_path   -- the virtual path as passed to the constructor;
	real_path      -- the path found on the file system or, for a file
	                  inside an archive, its path within the archive;
	full_real_path -- same as real_path for a file system file; for
	                  a file inside an archive, the archive path joined
	                  with the components of the path within it;
	archive        -- the Archive object containing the file, or None.
	"""
	__slots__ = 'virtual_path', 'real_path', 'full_real_path', 'archive'
	
	def __init__(self, virtpath, allow_root=False, msg='test data'):
		"""
		Locate the file with the slash-separated virtual path virtpath.
		
		allow_root is passed on to realize_path; msg names the kind
		of file in the error message.  Raises IOError if the file
		cannot be found.
		"""
		self.virtual_path = virtpath
		self.archive = None
		# Full stops in every component stand for os.path.extsep
		if not self.realize_path('', tuple(comp.replace('.', os.path.extsep) for comp in virtpath.split('/')), allow_root):
			raise IOError("%s file '%s' could not be found" % (msg, virtpath))
	
	def realize_path(self, root, virtpath, allow_root=False, hastests=False):
		"""
		Search the file system directory root for the file given by
		the tuple of path components virtpath.
		
		On success, set real_path and full_real_path (and archive
		if the file was found inside one) and return True; otherwise
		return False.  hastests tells that a 'tests' directory has
		already been entered on the way to root, in which case neither
		further 'tests' directories nor archives are tried.  With only
		one component left, the file is looked up directly in root
		only if hastests or allow_root is true.
		"""
		if root and not os.path.exists(root):
			return False
		if len(virtpath) > 1:
			# First try to descend into the directory named
			# by the leading component
			if self.realize_path(os.path.join(root, virtpath[0]), virtpath[1:], allow_root, hastests):
				return True
			elif not hastests:
				# Then look for the whole remaining path inside
				# the 'tests' directory and the test archives
				if self.realize_path(os.path.join(root, 'tests'), virtpath, allow_root, True):
					return True
				for archive in archives:
					path = os.path.join(root, archive)
					if os.path.exists(path):
						if self.realize_path_archive(open_archive(path), '', virtpath, path):
							return True
			# Finally retry in the same directory without
			# the leading component
			if self.realize_path(root, virtpath[1:], allow_root, hastests):
				return True
		else:
			if not hastests:
				path = os.path.join(root, 'tests', virtpath[0])
				if os.path.exists(path):
					self.full_real_path = self.real_path = path
					return True
				for archive in archives:
					path = os.path.join(root, archive)
					if os.path.exists(path):
						if self.realize_path_archive(open_archive(path), '', virtpath, path):
							return True
			if hastests or allow_root:
				path = os.path.join(root, virtpath[0])
				if os.path.exists(path):
					self.full_real_path = self.real_path = path
					return True
		return False
	
	def realize_path_archive(self, archive, root, virtpath, archpath, hastests=False):
		"""
		Search the directory root of the Archive object archive for
		the file given by the tuple of path components virtpath.
		
		archpath is the file system path of the archive; it is only
		used to build full_real_path.  On success, set archive,
		real_path and full_real_path and return True; otherwise return
		False.  Unless hastests is true, the 'tests' directory inside
		root is searched as a last resort.
		"""
		if root and not archive.exists(root):
			return False
		path = posixpath.join(root, virtpath[0])
		if len(virtpath) > 1:
			# Descend into the leading component or, failing that,
			# retry in the same directory without it
			if self.realize_path_archive(archive, path, virtpath[1:], archpath):
				return True
			elif self.realize_path_archive(archive, root, virtpath[1:], archpath):
				return True
		else:
			if archive.exists(path):
				self.archive = archive
				self.real_path = path
				self.full_real_path = os.path.join(archpath, *path.split('/'))
				return True
		if not hastests:
			if self.realize_path_archive(archive, posixpath.join(root, 'tests'), virtpath, archpath, True):
				return True
		return False
	
	def open(self):
		"""
		Open the file for reading and return an object usable
		in a with statement.
		"""
		if self.archive:
			file = self.archive.open(self.real_path)
			if hasattr(file, '__exit__'):
				return file
			else:
				# Make the object usable in a with statement anyway
				return contextlib.closing(file)
		else:
			return open(self.real_path, 'rb')
	
	def copy(self, target):
		"""
		Copy the file to the path target, extracting it from
		its archive if it is located inside one.
		"""
		if self.archive:
			self.archive.extract(self.real_path, target)
		else:
			shutil.copy(self.real_path, target)

def regexp(pattern, hastests=False, droplast=False):
	"""
	Generate (real path, virtual path) pairs for the existing files
	whose virtual paths match the regular expression pattern.
	
	The directory part of the pattern is resolved recursively, one
	slash-separated component at a time; unless hastests is true,
	a 'tests' subdirectory of each matched directory is searched too
	without showing up in the virtual paths.  If droplast is true,
	the virtual path of the containing directory is yielded instead
	of that of the matched entry.
	"""
	# FIXME: add test archives
	if not pattern:
		yield os.curdir, ''
		return
	parent = posixpath.split(pattern)[0]
	if hastests:
		candidates = regexp(parent, True)
	else:
		in_tests = regexp(posixpath.join(parent, 'tests'), True, True)
		candidates = itertools.chain(regexp(parent), in_tests)
	matcher = re.compile(pattern + '$', re.UNICODE)
	for realdir, virtdir in candidates:
		try:
			entries = os.listdir(realdir)
		except OSError:
			# Not a directory we can list; nothing to match in it
			continue
		for entry in entries:
			virtual = posixpath.join(virtdir, entry)
			if not matcher.match(virtual):
				continue
			real = os.path.join(realdir, entry)
			if droplast:
				yield real, virtdir
			else:
				yield real, virtual