4
4
"""
5
5
6
6
import re
7
- from _csv import Error , writer , reader , \
7
+ import types
8
+ from _csv import Error , __version__ , writer , reader , register_dialect , \
9
+ unregister_dialect , get_dialect , list_dialects , \
10
+ field_size_limit , \
8
11
QUOTE_MINIMAL , QUOTE_ALL , QUOTE_NONNUMERIC , QUOTE_NONE , \
12
+ QUOTE_STRINGS , QUOTE_NOTNULL , \
9
13
__doc__
14
+ from _csv import Dialect as _Dialect
10
15
11
- from collections import OrderedDict
12
16
from io import StringIO
13
17
14
18
__all__ = ["QUOTE_MINIMAL" , "QUOTE_ALL" , "QUOTE_NONNUMERIC" , "QUOTE_NONE" ,
19
+ "QUOTE_STRINGS" , "QUOTE_NOTNULL" ,
15
20
"Error" , "Dialect" , "__doc__" , "excel" , "excel_tab" ,
16
21
"field_size_limit" , "reader" , "writer" ,
17
- "Sniffer" ,
22
+ "register_dialect" , "get_dialect" , "list_dialects" , " Sniffer" ,
18
23
"unregister_dialect" , "__version__" , "DictReader" , "DictWriter" ,
19
24
"unix_dialect" ]
20
25
@@ -57,10 +62,12 @@ class excel(Dialect):
57
62
skipinitialspace = False
58
63
lineterminator = '\r \n '
59
64
quoting = QUOTE_MINIMAL
65
+ register_dialect ("excel" , excel )
60
66
61
67
class excel_tab (excel ):
62
68
"""Describe the usual properties of Excel-generated TAB-delimited files."""
63
69
delimiter = '\t '
70
+ register_dialect ("excel-tab" , excel_tab )
64
71
65
72
class unix_dialect (Dialect ):
66
73
"""Describe the usual properties of Unix-generated CSV files."""
@@ -70,11 +77,14 @@ class unix_dialect(Dialect):
70
77
skipinitialspace = False
71
78
lineterminator = '\n '
72
79
quoting = QUOTE_ALL
80
+ register_dialect ("unix" , unix_dialect )
73
81
74
82
75
83
class DictReader :
76
84
def __init__ (self , f , fieldnames = None , restkey = None , restval = None ,
77
85
dialect = "excel" , * args , ** kwds ):
86
+ if fieldnames is not None and iter (fieldnames ) is fieldnames :
87
+ fieldnames = list (fieldnames )
78
88
self ._fieldnames = fieldnames # list of keys for the dict
79
89
self .restkey = restkey # key to catch long rows
80
90
self .restval = restval # default value for short rows
@@ -111,7 +121,7 @@ def __next__(self):
111
121
# values
112
122
while row == []:
113
123
row = next (self .reader )
114
- d = OrderedDict (zip (self .fieldnames , row ))
124
+ d = dict (zip (self .fieldnames , row ))
115
125
lf = len (self .fieldnames )
116
126
lr = len (row )
117
127
if lf < lr :
@@ -121,21 +131,26 @@ def __next__(self):
121
131
d [key ] = self .restval
122
132
return d
123
133
134
+ __class_getitem__ = classmethod (types .GenericAlias )
135
+
124
136
125
137
class DictWriter :
126
138
def __init__ (self , f , fieldnames , restval = "" , extrasaction = "raise" ,
127
139
dialect = "excel" , * args , ** kwds ):
140
+ if fieldnames is not None and iter (fieldnames ) is fieldnames :
141
+ fieldnames = list (fieldnames )
128
142
self .fieldnames = fieldnames # list of keys for the dict
129
143
self .restval = restval # for writing short dicts
130
- if extrasaction .lower () not in ("raise" , "ignore" ):
144
+ extrasaction = extrasaction .lower ()
145
+ if extrasaction not in ("raise" , "ignore" ):
131
146
raise ValueError ("extrasaction (%s) must be 'raise' or 'ignore'"
132
147
% extrasaction )
133
148
self .extrasaction = extrasaction
134
149
self .writer = writer (f , dialect , * args , ** kwds )
135
150
136
151
def writeheader (self ):
137
152
header = dict (zip (self .fieldnames , self .fieldnames ))
138
- self .writerow (header )
153
+ return self .writerow (header )
139
154
140
155
def _dict_to_list (self , rowdict ):
141
156
if self .extrasaction == "raise" :
@@ -151,11 +166,8 @@ def writerow(self, rowdict):
151
166
def writerows (self , rowdicts ):
152
167
return self .writer .writerows (map (self ._dict_to_list , rowdicts ))
153
168
154
- # Guard Sniffer's type checking against builds that exclude complex()
155
- try :
156
- complex
157
- except NameError :
158
- complex = float
169
+ __class_getitem__ = classmethod (types .GenericAlias )
170
+
159
171
160
172
class Sniffer :
161
173
'''
@@ -404,14 +416,10 @@ def has_header(self, sample):
404
416
continue # skip rows that have irregular number of columns
405
417
406
418
for col in list (columnTypes .keys ()):
407
-
408
- for thisType in [int , float , complex ]:
409
- try :
410
- thisType (row [col ])
411
- break
412
- except (ValueError , OverflowError ):
413
- pass
414
- else :
419
+ thisType = complex
420
+ try :
421
+ thisType (row [col ])
422
+ except (ValueError , OverflowError ):
415
423
# fallback to length of string
416
424
thisType = len (row [col ])
417
425
@@ -427,7 +435,7 @@ def has_header(self, sample):
427
435
# on whether it's a header
428
436
hasHeader = 0
429
437
for col , colType in columnTypes .items ():
430
- if type (colType ) == type ( 0 ): # it's a length
438
+ if isinstance (colType , int ): # it's a length
431
439
if len (header [col ]) != colType :
432
440
hasHeader += 1
433
441
else :
0 commit comments