diff zipfiles/zipfile31.py @ 170:b993d9257400

Updated zipfiles
author Oleg Oshmyan <chortos@inbox.lv>
date Thu, 16 Jun 2011 01:24:10 +0100
parents 45d4a9dc707b
children
line wrap: on
line diff
--- a/zipfiles/zipfile31.py	Wed Jun 15 14:34:48 2011 +0100
+++ b/zipfiles/zipfile31.py	Thu Jun 16 01:24:10 2011 +0100
@@ -898,8 +898,12 @@
 
     def setpassword(self, pwd):
         """Set default password for encrypted files."""
-        assert isinstance(pwd, bytes)
-        self.pwd = pwd
+        if pwd and not isinstance(pwd, bytes):
+            raise TypeError("pwd: expected bytes, got %s" % type(pwd))
+        if pwd:
+            self.pwd = pwd
+        else:
+            self.pwd = None
 
     def read(self, name, pwd=None):
         """Return file bytes (as a string) for name."""
@@ -909,6 +913,8 @@
         """Return file-like object for 'name'."""
         if mode not in ("r", "U", "rU"):
             raise RuntimeError('open() requires mode "r", "U", or "rU"')
+        if pwd and not isinstance(pwd, bytes):
+            raise TypeError("pwd: expected bytes, got %s" % type(pwd))
         if not self.fp:
             raise RuntimeError(
                   "Attempt to read ZIP archive that was already closed")
@@ -940,7 +946,13 @@
         if fheader[_FH_EXTRA_FIELD_LENGTH]:
             zef_file.read(fheader[_FH_EXTRA_FIELD_LENGTH])
 
-        if fname != zinfo.orig_filename.encode("utf-8"):
+        if zinfo.flag_bits & 0x800:
+            # UTF-8 filename
+            fname_str = fname.decode("utf-8")
+        else:
+            fname_str = fname.decode("cp437")
+
+        if fname_str != zinfo.orig_filename:
             raise BadZipfile(
                   'File name in directory %r and header %r differ.'
                   % (zinfo.orig_filename, fname))
@@ -961,8 +973,8 @@
             #  completely random, while the 12th contains the MSB of the CRC,
             #  or the MSB of the file time depending on the header type
             #  and is used to check the correctness of the password.
-            bytes = zef_file.read(12)
-            h = list(map(zd, bytes[0:12]))
+            header = zef_file.read(12)
+            h = list(map(zd, header[0:12]))
             if zinfo.flag_bits & 0x8:
                 # compare against the file type from extended local headers
                 check_byte = (zinfo._raw_time >> 8) & 0xff