view files.py @ 132:cdd0f970d112

Fixed several small bugs in the files module
author Oleg Oshmyan <chortos@inbox.lv>
date Thu, 19 May 2011 02:55:36 +0100
parents 62a96d51bf94
children a9d2aa6810c7
line wrap: on
line source

# Copyright (c) 2010 Chortos-2 <chortos@inbox.lv>

"""File access routines and classes with support for archives."""

from __future__ import division, with_statement

from compat import *
import contextlib, os, shutil, sys

# Only File is exported by a star import of this module;
# you don't need to know about anything else.
__all__ = 'File',

# Names of the test archives that are looked for, in order of preference,
# and the registry mapping file name extensions to archive classes.
# Both are spelt with full stops regardless of os.extsep; the full stops
# are converted to os.extsep further down in this module.
archives = (
	'tests.tar',
	'tests.zip',
	'tests.tgz',
	'tests.tar.gz',
	'tests.tbz2',
	'tests.tar.bz2',
)
formats = dict()

class Archive(object):
	"""
	Base class of the archive readers.
	
	Calling Archive(path) itself picks the class registered in the
	module-level formats mapping for the file name extension of path;
	subclasses store the opened archive object in the file slot.
	"""
	__slots__ = 'file'
	
	if ABCMeta:
		__metaclass__ = ABCMeta
	
	def __new__(cls, path):
		"""
		Create a new instance of the archive class corresponding
		to the file name in the given path.
		
		Raises LookupError if no suffix of the file name is a key
		of the formats mapping.
		"""
		if cls is not Archive:
			return object.__new__(cls)
		# Do this by hand rather than through os.path.splitext
		# because we support multi-dotted file name extensions
		ext = path.partition(os.path.extsep)[2]
		while ext:
			if ext in formats:
				handler = formats[ext]
				if isinstance(handler, type) and issubclass(handler, Archive):
					# Only allocate the instance here: since it is
					# an instance of Archive, Python itself invokes
					# its __init__ once __new__ returns, and calling
					# handler(path) would open the archive twice
					return handler.__new__(handler, path)
				# Anything else is not re-initialized by Python,
				# so it has to be constructed in full right away
				return handler(path)
			ext = ext.partition(os.path.extsep)[2]
		raise LookupError("unsupported archive file name extension in file name '%s'" % path)
	
	@abstractmethod
	def __init__(self, path):
		"""Open the archive at the given path (implemented by subclasses)."""
		raise NotImplementedError
	
	@abstractmethod
	def extract(self, name, target):
		"""Extract the member called name to target (implemented by subclasses)."""
		raise NotImplementedError
	
	def __del__(self):
		"""Drop the reference to the underlying archive object, if any."""
		try:
			del self.file
		except AttributeError:
			# The slot was never filled in, e.g. because __init__ failed
			pass

try:
	import tarfile
except ImportError:
	TarArchive = None
else:
	class TarArchive(Archive):
		"""Archive implementation on top of the tarfile module."""
		__slots__ = '_namelist'
		
		def __init__(self, path):
			"""Open the tar archive at the given path."""
			self.file = tarfile.open(path)
		
		def extract(self, name, target):
			"""Extract the member called name to the path target."""
			member = self.file.getmember(name)
			# The TarInfo object belongs to the open archive, so the
			# temporary rename must be undone; otherwise the member
			# could no longer be found under its real name afterwards
			original_name = member.name
			member.name = target
			try:
				self.file.extract(member)
			finally:
				member.name = original_name
		
		# TODO: somehow automagically emulate universal line break support
		def open(self, name):
			"""Return a file-like object reading the member called name."""
			return self.file.extractfile(name)
		
		def exists(self, queried_name):
			"""
			Tell whether queried_name is a member of the archive
			or a leading directory part of the name of a member.
			"""
			if not hasattr(self, '_namelist'):
				# Built once and cached: every member name along
				# with all of its '/'-separated parent prefixes
				names = set()
				for name in self.file.getnames():
					cutname = name
					while cutname:
						names.add(cutname)
						cutname = cutname.rpartition('/')[0]
				self._namelist = frozenset(names)
			return queried_name in self._namelist
		
		def __enter__(self):
			"""Enter the context of the underlying TarFile if it has one."""
			if hasattr(self.file, '__enter__'):
				self.file.__enter__()
			return self
		
		def __exit__(self, exc_type, exc_value, traceback):
			"""Close the archive, delegating to TarFile.__exit__ if it exists."""
			if hasattr(self.file, '__exit__'):
				return self.file.__exit__(exc_type, exc_value, traceback)
			elif exc_type is None:
				self.file.close()
			else:
				# This code was shamelessly copied from tarfile.py of Python 2.7
				if not self.file._extfileobj:
					self.file.fileobj.close()
				self.file.closed = True
	
	# All of these file name extensions are handled by the same class
	formats['tar'] = formats['tgz'] = formats['tar.gz'] = formats['tbz2'] = formats['tar.bz2'] = TarArchive

try:
	import zipfile
except ImportError:
	ZipArchive = None
else:
	class ZipArchive(Archive):
		"""Archive implementation on top of the zipfile module."""
		__slots__ = '_namelist'
		
		def __init__(self, path):
			"""Open the zip archive at the given path."""
			self.file = zipfile.ZipFile(path)
		
		def extract(self, name, target):
			"""Extract the member called name to the path target."""
			member = self.file.getinfo(name)
			# The ZipInfo object belongs to the open archive, so the
			# temporary rename must be undone; otherwise namelist()
			# and hence exists() would report the target path
			original_filename = member.filename
			try:
				# FIXME: 2.5 lacks ZipFile.extract
				if os.path.isabs(target):
					# To my knowledge, this is as portable as it gets
					path = os.path.join(os.path.splitdrive(target)[0], os.path.sep)
					member.filename = os.path.relpath(target, path)
					self.file.extract(member, path)
				else:
					member.filename = os.path.relpath(target)
					self.file.extract(member)
			finally:
				member.filename = original_filename
		
		def open(self, name):
			"""Return a file-like object reading the member called name."""
			return self.file.open(name, 'rU')
		
		def exists(self, queried_name):
			"""
			Tell whether queried_name is a member of the archive
			or a leading directory part of the name of a member.
			"""
			if not hasattr(self, '_namelist'):
				# Built once and cached: every member name along
				# with all of its '/'-separated parent prefixes
				names = set()
				for name in self.file.namelist():
					cutname = name
					while cutname:
						names.add(cutname)
						cutname = cutname.rpartition('/')[0]
				self._namelist = frozenset(names)
			return queried_name in self._namelist
		
		def __enter__(self):
			"""Enter the context of the underlying ZipFile if it has one."""
			if hasattr(self.file, '__enter__'):
				self.file.__enter__()
			return self
		
		def __exit__(self, exc_type, exc_value, traceback):
			"""Close the archive, delegating to ZipFile.__exit__ if it exists."""
			if hasattr(self.file, '__exit__'):
				return self.file.__exit__(exc_type, exc_value, traceback)
			else:
				return self.file.close()
	
	formats['zip'] = ZipArchive

# Remove unsupported archive formats and replace full stops
# with the platform-dependent file name extension separator
def issupported(filename, formats=formats):
	"""
	Tell whether any extension of filename, from the longest one (all
	that follows the first full stop) down to the shortest, is a key
	of formats.
	"""
	suffix = filename.partition('.')[2]
	while suffix and suffix not in formats:
		# Strip the leading component of the extension and retry
		suffix = suffix.partition('.')[2]
	return bool(suffix)
# Keep only the archive names whose format is supported, spelt with os.extsep
archives = [filename.replace('.', os.path.extsep) for filename in archives if issupported(filename)]
# Re-key the format registry with os.extsep in place of full stops as well
formats = dict((ext.replace('.', os.path.extsep), handler) for ext, handler in items(formats))

# Archive objects opened so far, keyed by the path they were opened with
open_archives = {}

def open_archive(path):
	"""
	Return the Archive object for path, opening the archive only
	the first time a given path is requested.
	"""
	if path not in open_archives:
		open_archives[path] = Archive(path)
	return open_archives[path]

class File(object):
	"""
	A file addressed by a '/'-separated virtual path and located either
	directly on disk or inside one of the known test archives.
	
	After construction, real_path is the path on disk or the member
	name inside the archive, full_real_path additionally includes the
	path of the archive, and archive is the Archive object or None.
	"""
	__slots__ = 'virtual_path', 'real_path', 'full_real_path', 'archive'
	
	def __init__(self, virtpath, allow_root=False, msg='test data'):
		"""
		Locate the file called virtpath; raise IOError mentioning msg
		if it cannot be found.
		"""
		self.virtual_path = virtpath
		self.archive = None
		# Full stops in the virtual path stand for os.extsep
		components = tuple(part.replace('.', os.path.extsep) for part in virtpath.split('/'))
		found = self.realize_path('', components, allow_root)
		if not found:
			raise IOError("%s file '%s' could not be found" % (msg, virtpath))
	
	def _use_real_path(self, path):
		"""Record path as the location of the file if it exists on disk."""
		if not os.path.exists(path):
			return False
		self.full_real_path = self.real_path = path
		return True
	
	def _realize_in_archives(self, root, virtpath):
		"""Look for virtpath in each known archive present in root."""
		for archive_name in archives:
			candidate = os.path.join(root, archive_name)
			if not os.path.exists(candidate):
				continue
			if self.realize_path_archive(open_archive(candidate), '', virtpath, candidate):
				return True
		return False
	
	def realize_path(self, root, virtpath, allow_root=False, hastests=False):
		"""
		Search the directory root for the components of virtpath and
		return whether the file was found.
		
		hastests tells that a 'tests' directory is already part of root;
		allow_root permits finding the file outside of any 'tests'
		directory or archive.
		"""
		if root and not os.path.exists(root):
			return False
		if len(virtpath) > 1:
			rest = virtpath[1:]
			# First try the first component as a real subdirectory
			if self.realize_path(os.path.join(root, virtpath[0]), rest, allow_root, hastests):
				return True
			if not hastests:
				# Then a 'tests' subdirectory and the test archives
				if self.realize_path(os.path.join(root, 'tests'), virtpath, allow_root, True):
					return True
				if self._realize_in_archives(root, virtpath):
					return True
			# Finally retry with the first component dropped
			if self.realize_path(root, rest, allow_root, hastests):
				return True
			return False
		name = virtpath[0]
		if not hastests:
			if self._use_real_path(os.path.join(root, 'tests', name)):
				return True
			if self._realize_in_archives(root, virtpath):
				return True
		if hastests or allow_root:
			if self._use_real_path(os.path.join(root, name)):
				return True
		return False
	
	def realize_path_archive(self, archive, root, virtpath, archpath):
		"""
		Search the directory root of the archive (opened from archpath)
		for the components of virtpath and return whether the file
		was found.
		"""
		if root and not archive.exists(root):
			return False
		if root:
			member = root + '/' + virtpath[0]
		else:
			member = virtpath[0]
		if len(virtpath) > 1:
			rest = virtpath[1:]
			# Descend into the first component, or else drop it
			if self.realize_path_archive(archive, member, rest, archpath):
				return True
			if self.realize_path_archive(archive, root, rest, archpath):
				return True
			return False
		if not archive.exists(member):
			return False
		self.archive = archive
		self.real_path = member
		self.full_real_path = os.path.join(archpath, *member.split('/'))
		return True
	
	def open(self):
		"""Return a context manager yielding the file opened for reading."""
		if not self.archive:
			return open(self.real_path, 'rU')
		stream = self.archive.open(self.real_path)
		if hasattr(stream, '__exit__'):
			return stream
		# Not a context manager by itself, so make it one
		return contextlib.closing(stream)
	
	def copy(self, target):
		"""Copy the file, or extract it from its archive, to target."""
		if not self.archive:
			shutil.copy(self.real_path, target)
		else:
			self.archive.extract(self.real_path, target)