changeset 195:c2490e39fd70

Revamped the implementation of files.Archive subclasses. They now normalize and sanitize all paths and provide a listdir method. TarArchive also ignores all files that are not regular files or directories.
author Oleg Oshmyan <chortos@inbox.lv>
date Sun, 14 Aug 2011 01:02:10 +0300 (2011-08-13)
parents 8c30a2c8a09e
children 67088c1765b4
files upreckon/files.py
diffstat 1 files changed, 95 insertions(+), 45 deletions(-) [+]
line wrap: on
line diff
--- a/upreckon/files.py	Fri Aug 12 17:51:33 2011 +0300
+++ b/upreckon/files.py	Sun Aug 14 01:02:10 2011 +0300
@@ -16,7 +16,7 @@
 formats = {}
 
 class Archive(object):
-	__slots__ = 'file'
+	__slots__ = ()
 	
 	if ABCMeta:
 		__metaclass__ = ABCMeta
@@ -43,6 +43,15 @@
 	
 	@abstractmethod
 	def extract(self, name, target): raise NotImplementedError
+	
+	@abstractmethod
+	def open(self, name): raise NotImplementedError
+	
+	@abstractmethod
+	def exists(self, name): raise NotImplementedError
+	
+	@abstractmethod
+	def listdir(self, name): raise NotImplementedError
 
 try:
 	import tarfile
@@ -50,45 +59,66 @@
 	TarArchive = None
 else:
 	class TarArchive(Archive):
-		__slots__ = '_namelist'
+		__slots__ = '_tarfile', '_files', '_dirs', '_names'
 		
 		def __init__(self, path):
-			self.file = tarfile.open(path)
+			self._tarfile = tarfile.open(path)
+			files, dirs = {}, set()
+			for member in self._tarfile.getmembers():
+				cutname = posixpath.normpath(member.name).lstrip('/')
+				while cutname.startswith('../'):
+					cutname = cutname[3:]
+				if cutname in ('.', '..'):
+					continue
+				if member.isfile():
+					files[cutname] = member
+					cutname = posixpath.dirname(cutname)
+				elif not member.isdir():
+					continue
+				while cutname:
+					dirs.add(cutname)
+					cutname = posixpath.dirname(cutname)
+			self._files = files
+			self._dirs = frozenset(dirs)
+			self._names = self._dirs | frozenset(files)
 		
 		def extract(self, name, target):
-			member = self.file.getmember(name)
+			member = self._files[posixpath.normpath(name)]
 			member.name = target
-			self.file.extract(member)
+			self._tarfile.extract(member)
 		
 		def open(self, name):
-			return self.file.extractfile(name)
+			name = posixpath.normpath(name)
+			return self._tarfile.extractfile(self._files[name])
+		
+		def exists(self, name):
+			return posixpath.normpath(name) in self._names
 		
-		def exists(self, queried_name):
-			if not hasattr(self, '_namelist'):
-				names = set()
-				for name in self.file.getnames():
-					cutname = name
-					while cutname:
-						names.add(cutname)
-						cutname = cutname.rpartition('/')[0]
-				self._namelist = frozenset(names)
-			return queried_name in self._namelist
+		def listdir(self, name):
+			normname = posixpath.normpath(name)
+			if normname not in self._dirs:
+				raise KeyError('No such directory: %r' % name)
+			normname += '/'
+			len_normname = len(normname)
+			return [fname for fname in self._names
+			              if fname.startswith(normname) and
+			                 fname.find('/', len_normname) == -1]
 		
 		def __enter__(self):
-			if hasattr(self.file, '__enter__'):
-				self.file.__enter__()
+			if hasattr(self._tarfile, '__enter__'):
+				self._tarfile.__enter__()
 			return self
 		
 		def __exit__(self, exc_type, exc_value, traceback):
-			if hasattr(self.file, '__exit__'):
-				return self.file.__exit__(exc_type, exc_value, traceback)
+			if hasattr(self._tarfile, '__exit__'):
+				return self._tarfile.__exit__(exc_type, exc_value, traceback)
 			elif exc_type is None:
-				self.file.close()
+				self._tarfile.close()
 			else:
 				# This code was shamelessly copied from tarfile.py of Python 2.7
-				if not self.file._extfileobj:
-					self.file.fileobj.close()
-				self.file.closed = True
+				if not self._tarfile._extfileobj:
+					self._tarfile.fileobj.close()
+				self._tarfile.closed = True
 	
 	formats['tar'] = formats['tgz'] = formats['tar.gz'] = formats['tbz2'] = formats['tar.bz2'] = TarArchive
 
@@ -98,47 +128,67 @@
 	ZipArchive = None
 else:
 	class ZipArchive(Archive):
-		__slots__ = '_namelist'
+		__slots__ = '_zipfile', '_files', '_dirs', '_names'
 		
 		def __init__(self, path):
-			self.file = zipfile.ZipFile(path)
+			self._zipfile = zipfile.ZipFile(path)
+			files, dirs = {}, set()
+			for member in self._zipfile.infolist():
+				cutname = posixpath.normpath(member.filename).lstrip('/')
+				while cutname.startswith('../'):
+					cutname = cutname[3:]
+				if cutname in ('.', '..'):
+					continue
+				if not member.filename.endswith('/'):
+					files[cutname] = member
+					cutname = posixpath.dirname(cutname)
+				while cutname:
+					dirs.add(cutname)
+					cutname = posixpath.dirname(cutname)
+			self._files = files
+			self._dirs = frozenset(dirs)
+			self._names = self._dirs | frozenset(files)
 		
 		def extract(self, name, target):
-			member = self.file.getinfo(name)
+			member = self._files[posixpath.normpath(name)]
 			# FIXME: 2.5 lacks ZipFile.extract
 			if os.path.isabs(target):
 				# To my knowledge, this is as portable as it gets
 				path = os.path.join(os.path.splitdrive(target)[0], os.path.sep)
 				member.filename = os.path.relpath(target, path)
-				self.file.extract(member, path)
+				self._zipfile.extract(member, path)
 			else:
 				member.filename = os.path.relpath(target)
-				self.file.extract(member)
+				self._zipfile.extract(member)
 		
 		def open(self, name):
-			return self.file.open(name, 'r')
+			name = posixpath.normpath(name)
+			# FIXME: 2.5 lacks ZipFile.open
+			return self._zipfile.open(self._files[name])
+		
+		def exists(self, name):
+			return posixpath.normpath(name) in self._names
 		
-		def exists(self, queried_name):
-			if not hasattr(self, '_namelist'):
-				names = set()
-				for name in self.file.namelist():
-					cutname = name
-					while cutname:
-						names.add(cutname)
-						cutname = cutname.rpartition('/')[0]
-				self._namelist = frozenset(names)
-			return queried_name in self._namelist
+		def listdir(self, name):
+			normname = posixpath.normpath(name)
+			if normname not in self._dirs:
+				raise KeyError('No such directory: %r' % name)
+			normname += '/'
+			len_normname = len(normname)
+			return [fname for fname in self._names
+			              if fname.startswith(normname) and
+			                 fname.find('/', len_normname) == -1]
 		
 		def __enter__(self):
-			if hasattr(self.file, '__enter__'):
-				self.file.__enter__()
+			if hasattr(self._zipfile, '__enter__'):
+				self._zipfile.__enter__()
 			return self
 		
 		def __exit__(self, exc_type, exc_value, traceback):
-			if hasattr(self.file, '__exit__'):
-				return self.file.__exit__(exc_type, exc_value, traceback)
+			if hasattr(self._zipfile, '__exit__'):
+				return self._zipfile.__exit__(exc_type, exc_value, traceback)
 			else:
-				return self.file.close()
+				return self._zipfile.close()
 	
 	formats['zip'] = ZipArchive