Skip to content

Commit 4db5586

Browse files
committed
Make _ensure_file_obj a method of a generic base class. Passes mypy and pytest
1 parent 70896ce commit 4db5586

1 file changed

Lines changed: 118 additions & 81 deletions

File tree

src/shapefile.py

Lines changed: 118 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -2402,30 +2402,6 @@ def close(self) -> None:
24022402
FileProtoT = TypeVar("FileProtoT")
24032403

24042404

2405-
def _ensure_file_obj(
2406-
f: str | FileProtoT | None,
2407-
FileProto: type[FileProtoT],
2408-
exit_stack: ExitStack,
2409-
new_file_mode: Literal["rb", "wb+"] = "wb+",
2410-
ExceptionClass: type[ShapefileException] = ShapefileException,
2411-
) -> FileProtoT:
2412-
"""Safety handler to verify file-like objects"""
2413-
if not f:
2414-
raise ExceptionClass(f"No file-like object recieved. Got: {f}")
2415-
if isinstance(f, str):
2416-
pth = os.path.split(f)[0]
2417-
if pth and not os.path.exists(pth):
2418-
os.makedirs(pth)
2419-
fp = open(f, new_file_mode)
2420-
2421-
# Only push files created here to the exit stack.
2422-
# The user must close their own file objects.
2423-
exit_stack.enter_context(fp)
2424-
return cast(FileProtoT, fp)
2425-
2426-
if isinstance(f, FileProto):
2427-
return f
2428-
raise ExceptionClass(f"Unsupported file-like object: {f}")
24292405

24302406

24312407
class _FileChecker(_HasExitStack, Generic[FileProtoT]):
@@ -2441,6 +2417,8 @@ def new_file_mode(self) -> Literal["rb", "wb+"]: ...
24412417
@abc.abstractmethod
24422418
def ext(self) -> Literal[".shp", ".shx", ".dbf"]: ...
24432419

2420+
ExceptionClass = ShapefileException
2421+
24442422
def __init__(
24452423
self,
24462424
file: str | PathLike[Any] | FileProtoT,
@@ -2464,13 +2442,57 @@ def __init__(
24642442
self.encoding = encoding
24652443
self.encodingErrors = encodingErrors
24662444

2445+
@functools.cached_property
2446+
def file(self) -> FileProtoT:
2447+
return self._ensure_file_obj()
2448+
# f=self._file,
2449+
# FileProto=self.FileProto,
2450+
# exit_stack=self.exit_stack,
2451+
# new_file_mode="rb",
2452+
# ExceptionClass=dbfFileException,
2453+
#)
2454+
2455+
def _ensure_file_obj(
2456+
self,
2457+
f: str | FileProtoT | None = None,
2458+
# FileProto: type[FileProtoT],
2459+
# exit_stack: ExitStack,
2460+
# new_file_mode: Literal["rb", "wb+"] = "wb+",
2461+
# ExceptionClass: type[ShapefileException] = ShapefileException,
2462+
) -> FileProtoT:
2463+
"""Safety handler to verify file-like objects"""
2464+
2465+
f = f or self._file
2466+
exit_stack = self.exit_stack
2467+
FileProto = self.FileProto
2468+
new_file_mode = self.new_file_mode
2469+
ExceptionClass = self.ExceptionClass
2470+
2471+
2472+
if not f:
2473+
raise ExceptionClass(f"No file-like object received. Got: {f}")
2474+
if isinstance(f, str):
2475+
pth = os.path.split(f)[0]
2476+
if pth and not os.path.exists(pth):
2477+
os.makedirs(pth)
2478+
fp = open(f, new_file_mode)
2479+
2480+
# Only push files created here to the exit stack.
2481+
# The user must close their own file objects.
2482+
exit_stack.enter_context(fp)
2483+
return cast(FileProtoT, fp)
2484+
2485+
if isinstance(f, FileProto):
2486+
return f
2487+
raise ExceptionClass(f"Unsupported file-like object: {f}")
24672488

24682489
class DbfReader(_FileChecker[ReadSeekableBinStream]):
24692490
"""Reads a dbf file. You can instantiate a DbfReader without specifying a shapefile."""
24702491

24712492
FileProto = ReadSeekableBinStream
24722493
new_file_mode = "rb"
24732494
ext = ".dbf"
2495+
ExceptionClass = dbfFileException
24742496

24752497
def __init__(
24762498
self,
@@ -2486,15 +2508,15 @@ def __init__(
24862508

24872509
self._dbfHeader()
24882510

2489-
@functools.cached_property
2490-
def dbf(self) -> ReadSeekableBinStream:
2491-
return _ensure_file_obj(
2492-
f=self._file,
2493-
FileProto=self.FileProto,
2494-
exit_stack=self.exit_stack,
2495-
new_file_mode="rb",
2496-
ExceptionClass=dbfFileException,
2497-
)
2511+
# @functools.cached_property
2512+
# def dbf(self) -> ReadSeekableBinStream:
2513+
# return self._ensure_file_obj(
2514+
# # f=self._file,
2515+
# # FileProto=self.FileProto,
2516+
# # exit_stack=self.exit_stack,
2517+
# # new_file_mode="rb",
2518+
# # ExceptionClass=dbfFileException,
2519+
# )
24982520

24992521
def __len__(self) -> int:
25002522
"""Returns the number of records in the .dbf file."""
@@ -2503,17 +2525,18 @@ def __len__(self) -> int:
25032525
def _dbfHeader(self) -> None:
25042526
"""Reads a dbf header. Xbase-related code borrows heavily from ActiveState Python Cookbook Recipe 362715 by Raymond Hettinger"""
25052527

2528+
dbf = self.file
25062529
# read relevant header parts
2507-
self.dbf.seek(0)
2530+
dbf.seek(0)
25082531
self.numRecords, self.__dbfHdrLength, self._record_length = cast(
2509-
tuple[int, int, int], unpack("<xxxxLHH20x", self.dbf.read(32))
2532+
tuple[int, int, int], unpack("<xxxxLHH20x", dbf.read(32))
25102533
)
25112534

25122535
# read fields
25132536
numFields = (self.__dbfHdrLength - 33) // 32
25142537
for __field in range(numFields):
25152538
encoded_field_tuple: tuple[bytes, bytes, int, int] = unpack(
2516-
"<11sc4xBB14x", self.dbf.read(32)
2539+
"<11sc4xBB14x", dbf.read(32)
25172540
)
25182541
encoded_name, encoded_type_char, size, decimal = encoded_field_tuple
25192542

@@ -2528,7 +2551,7 @@ def _dbfHeader(self) -> None:
25282551
field_type = FIELD_TYPE_ALIASES[encoded_type_char]
25292552

25302553
self.fields.append(Field(name, field_type, size, decimal))
2531-
terminator = self.dbf.read(1)
2554+
terminator = dbf.read(1)
25322555
if terminator != b"\r":
25332556
raise ShapefileException(
25342557
"Shapefile dbf header lacks expected terminator. (likely corrupt?)"
@@ -2618,7 +2641,7 @@ def _record(
26182641
a list of field info Field namedtuples 'fieldTuples', a record name-index dict 'recLookup',
26192642
and a Struct instance 'recStruct' for unpacking these fields.
26202643
"""
2621-
f = self.dbf
2644+
f = self.file
26222645

26232646
# The only format chars in from self._record_fmt, in recStruct from _record_fields,
26242647
# are s and x (ascii encoded str and pad byte) so everything in recordContents is bytes
@@ -2712,7 +2735,7 @@ def record(self, i: int = 0, fields: list[str] | None = None) -> _Record | None:
27122735
To only read some of the fields, specify the 'fields' arg as a
27132736
list of one or more fieldnames.
27142737
"""
2715-
f = self.dbf
2738+
f = self.file
27162739

27172740
i = ensure_within_bounds(i, self.numRecords)
27182741
recSize = self._record_length
@@ -2728,9 +2751,10 @@ def records(self, fields: list[str] | None = None) -> list[_Record]:
27282751
To only read some of the fields, specify the 'fields' arg as a
27292752
list of one or more fieldnames.
27302753
"""
2754+
f = self.file
27312755

27322756
records = []
2733-
self.dbf.seek(self.__dbfHdrLength)
2757+
f.seek(self.__dbfHdrLength)
27342758
fieldTuples, recLookup, recStruct = self._record_fields(fields)
27352759

27362760
for i in range(self.numRecords):
@@ -2758,6 +2782,7 @@ def iterRecords(
27582782
start <= i < number_of_records + stop
27592783
if stop < 0).
27602784
"""
2785+
f = self.file
27612786

27622787
if not isinstance(self.numRecords, int):
27632788
raise ShapefileException(
@@ -2773,7 +2798,7 @@ def iterRecords(
27732798
elif stop < 0:
27742799
stop = range(self.numRecords)[stop]
27752800
recSize = self._record_length
2776-
self.dbf.seek(self.__dbfHdrLength + (start * recSize))
2801+
f.seek(self.__dbfHdrLength + (start * recSize))
27772802
fieldTuples, recLookup, recStruct = self._record_fields(fields)
27782803
for i in range(start, stop):
27792804
r = self._record(
@@ -2794,7 +2819,7 @@ class _NoShpSentinel:
27942819
_NO_SHP_SENTINEL = _NoShpSentinel()
27952820

27962821

2797-
class Reader(_HasExitStack):
2822+
class Reader(_FileChecker[ReadSeekableBinStream]):
27982823
"""Reads the three files of a shapefile as a unit or
27992824
separately. If one of the three files (.shp, .shx,
28002825
.dbf) is missing no exception is thrown until you try
@@ -2815,6 +2840,11 @@ class Reader(_HasExitStack):
28152840
but they can be.
28162841
"""
28172842

2843+
FileProto = ReadSeekableBinStream
2844+
new_file_mode = "rb"
2845+
ext = ".shp"
2846+
ExceptionClass = ShapefileException
2847+
28182848
def __init__(
28192849
self,
28202850
shapefile_path: str | PathLike[Any] = "",
@@ -2893,25 +2923,25 @@ def dbf_reader(self) -> DbfReader:
28932923

28942924
@functools.cached_property
28952925
def shp(self) -> ReadSeekableBinStream:
2896-
return _ensure_file_obj(
2926+
return self._ensure_file_obj(
28972927
f=self._shp,
2898-
FileProto=ReadSeekableBinStream,
2899-
exit_stack=self.exit_stack,
2900-
new_file_mode="rb",
2928+
# FileProto=ReadSeekableBinStream,
2929+
# exit_stack=self.exit_stack,
2930+
# new_file_mode="rb",
29012931
)
29022932

29032933
@functools.cached_property
29042934
def shx(self) -> ReadSeekableBinStream:
2905-
return _ensure_file_obj(
2935+
return self._ensure_file_obj(
29062936
f=self._shx,
2907-
FileProto=ReadSeekableBinStream,
2908-
exit_stack=self.exit_stack,
2909-
new_file_mode="rb",
2937+
# FileProto=ReadSeekableBinStream,
2938+
# exit_stack=self.exit_stack,
2939+
# new_file_mode="rb",
29102940
)
29112941

29122942
@property
29132943
def dbf(self) -> ReadableBinStream:
2914-
return self.dbf_reader.dbf
2944+
return self.dbf_reader.file
29152945

29162946
@property
29172947
def numRecords(self) -> int | None:
@@ -3505,27 +3535,29 @@ def __init__(
35053535
# Keep kwargs even though unused, to preserve PyShp 2.4 API
35063536
**kwargs: Any,
35073537
):
3508-
dbf = fsdecode_if_pathlike(dbf)
3509-
self._dbf: str | WriteSeekableBinStream
3510-
# Encoding
3511-
self.encoding = encoding
3512-
self.encodingErrors = encodingErrors
3513-
if isinstance(dbf, str):
3514-
self._dbf = os.path.splitext(dbf)[0] + ".dbf"
3515-
elif dbf:
3516-
self._dbf = dbf
3517-
else:
3518-
raise TypeError(
3519-
f"dbf must be set to a str, Path or file-like object. Got: {dbf}"
3520-
)
3538+
super().__init__(file=dbf, encoding=encoding, encodingErrors=encodingErrors)
3539+
3540+
# dbf = fsdecode_if_pathlike(dbf)
3541+
# self._dbf: str | WriteSeekableBinStream
3542+
# # Encoding
3543+
# self.encoding = encoding
3544+
# self.encodingErrors = encodingErrors
3545+
# if isinstance(dbf, str):
3546+
# self._dbf = os.path.splitext(dbf)[0] + ".dbf"
3547+
# elif dbf:
3548+
# self._dbf = self.file
3549+
# else:
3550+
# raise TypeError(
3551+
# f"dbf must be set to a str, Path or file-like object. Got: {dbf}"
3552+
# )
35213553

35223554
# Support not closing opened file objects passed in e.g.(handled by some
35233555
# external context manager, or the caller manually calling .close).
35243556
#
35253557
# This will only ever hold at most one context manager.
35263558
# But an ExitStack is the right tool for the job
35273559
# when the number of context manager(s) depends on user input.
3528-
self.exit_stack = ExitStack()
3560+
# self.exit_stack = ExitStack()
35293561

35303562
self.fields: list[Field] = []
35313563
self.max_num_fields = max_num_fields
@@ -3534,12 +3566,12 @@ def __init__(
35343566

35353567
@functools.cached_property
35363568
def dbf(self) -> WriteSeekableBinStream:
3537-
return _ensure_file_obj(
3538-
f=self._dbf,
3539-
FileProto=WriteSeekableBinStream,
3540-
exit_stack=self.exit_stack,
3541-
new_file_mode="wb+",
3542-
ExceptionClass=dbfFileException,
3569+
return self._ensure_file_obj(
3570+
# f=self._dbf,
3571+
# FileProto=WriteSeekableBinStream,
3572+
# exit_stack=self.exit_stack,
3573+
# new_file_mode="wb+",
3574+
# ExceptionClass=dbfFileException,
35433575
)
35443576

35453577
def close(self) -> None:
@@ -3751,10 +3783,15 @@ def __dbfRecord(self, record: list[RecordValue]) -> None:
37513783
f.write(encoded)
37523784

37533785

3754-
class Writer:
3786+
class Writer(_FileChecker[WriteSeekableBinStream]):
37553787
"""Provides write support for ESRI Shapefiles."""
37563788

3757-
W = TypeVar("W", bound=WriteSeekableBinStream)
3789+
# W = TypeVar("W", bound=WriteSeekableBinStream)
3790+
3791+
FileProto = WriteSeekableBinStream
3792+
new_file_mode = "wb+"
3793+
ext = ".shp"
3794+
ExceptionClass = ShapefileException
37583795

37593796
def __init__(
37603797
self,
@@ -3821,20 +3858,20 @@ def __init__(
38213858

38223859
@functools.cached_property
38233860
def shp(self) -> WriteSeekableBinStream:
3824-
return _ensure_file_obj(
3861+
return self._ensure_file_obj(
38253862
f=self._shp,
3826-
FileProto=WriteSeekableBinStream,
3827-
exit_stack=self.exit_stack,
3828-
new_file_mode="wb+",
3863+
# FileProto=WriteSeekableBinStream,
3864+
# exit_stack=self.exit_stack,
3865+
# new_file_mode="wb+",
38293866
)
38303867

38313868
@functools.cached_property
38323869
def shx(self) -> WriteSeekableBinStream:
3833-
return _ensure_file_obj(
3870+
return self._ensure_file_obj(
38343871
f=self._shx,
3835-
FileProto=WriteSeekableBinStream,
3836-
exit_stack=self.exit_stack,
3837-
new_file_mode="wb+",
3872+
# FileProto=WriteSeekableBinStream,
3873+
# exit_stack=self.exit_stack,
3874+
# new_file_mode="wb+",
38383875
)
38393876

38403877
@functools.cached_property

0 commit comments

Comments
 (0)