5 changes: 3 additions & 2 deletions run_benchmarks.py
@@ -7,6 +7,7 @@
import os
import timeit
from collections.abc import Callable
from os import PathLike
from pathlib import Path
from tempfile import TemporaryFile as TempF
from typing import Iterable, Union, cast
@@ -50,14 +51,14 @@ def benchmark(
shapeRecords = collections.defaultdict(list)


def open_shapefile_with_PyShp(target: Union[str, os.PathLike]):
def open_shapefile_with_PyShp(target: Union[str, PathLike]):
with shapefile.Reader(target) as r:
fields[target] = r.fields
for shapeRecord in r.iterShapeRecords():
shapeRecords[target].append(shapeRecord)


def write_shapefile_with_PyShp(target: Union[str, os.PathLike]):
def write_shapefile_with_PyShp(target: Union[str, PathLike]):
with TempF("wb") as shp, TempF("wb") as dbf, TempF("wb") as shx:
with shapefile.Writer(shp=shp, dbf=dbf, shx=shx) as w: # type: ignore [arg-type]
for field_info_tuple in fields[target]:
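With this change, both benchmark helpers accept anything implementing os.PathLike in addition to plain strings. A minimal usage sketch under that assumption, reusing the repository's shapefiles/blockgroups data (the call site is illustrative, not part of the diff):

from pathlib import Path

target = Path("shapefiles") / "blockgroups"  # an os.PathLike target
open_shapefile_with_PyShp(target)            # fills fields[target] and shapeRecords[target]
write_shapefile_with_PyShp(target)           # writes the cached data back out to temporary files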
15 changes: 8 additions & 7 deletions src/shapefile.py
@@ -20,6 +20,7 @@
import time
import zipfile
from datetime import date
from os import PathLike
from struct import Struct, calcsize, error, pack, unpack
from types import TracebackType
from typing import (
@@ -159,7 +160,7 @@ def read(self, size: int = -1) -> bytes: ...


# File name, file object or anything with a read() method that returns bytes.
BinaryFileT = Union[str, IO[bytes]]
BinaryFileT = Union[str, PathLike[Any], IO[bytes]]
BinaryFileStreamT = Union[IO[bytes], io.BytesIO, WriteSeekableBinStream]

FieldTypeT = Literal["C", "D", "F", "L", "M", "N"]
@@ -341,11 +342,11 @@ class GeoJSONFeatureCollectionWithBBox(GeoJSONFeatureCollection):


@overload
def fsdecode_if_pathlike(path: os.PathLike[Any]) -> str: ...
def fsdecode_if_pathlike(path: PathLike[Any]) -> str: ...
@overload
def fsdecode_if_pathlike(path: T) -> T: ...
def fsdecode_if_pathlike(path: Any) -> Any:
if isinstance(path, os.PathLike):
if isinstance(path, PathLike):
return os.fsdecode(path) # str

return path
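With the overloads above, fsdecode_if_pathlike normalises any os.PathLike input to a str and passes everything else (strings, file objects) through unchanged. A small illustrative sketch of that behaviour, assuming the helper remains importable from the shapefile module:

import os
from pathlib import Path

from shapefile import fsdecode_if_pathlike

# A PathLike object is decoded to its string form via os.fsdecode...
assert fsdecode_if_pathlike(Path("shapefiles") / "blockgroups") == os.path.join("shapefiles", "blockgroups")
# ...while a plain string passes through untouched.
assert fsdecode_if_pathlike("blockgroups.shp") == "blockgroups.shp"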
@@ -2243,7 +2244,7 @@ def _assert_ext_is_supported(self, ext: str) -> None:

def __init__(
self,
shapefile_path: Union[str, os.PathLike[Any]] = "",
shapefile_path: Union[str, PathLike[Any]] = "",
/,
*,
encoding: str = "utf-8",
@@ -2411,7 +2412,7 @@ def __init__(
return

if shp is not _NO_SHP_SENTINEL:
shp = cast(Union[str, IO[bytes], None], shp)
shp = cast(Union[str, PathLike[Any], IO[bytes], None], shp)
self.shp = self.__seek_0_on_file_obj_wrap_or_open_from_name("shp", shp)
self.shx = self.__seek_0_on_file_obj_wrap_or_open_from_name("shx", shx)

@@ -2432,7 +2433,7 @@ def __seek_0_on_file_obj_wrap_or_open_from_name(
if file_ is None:
return None

if isinstance(file_, str):
if isinstance(file_, (str, PathLike)):
baseName, __ = os.path.splitext(file_)
return self._load_constituent_file(baseName, ext)

@@ -3235,7 +3236,7 @@ class Writer:

def __init__(
self,
target: Union[str, os.PathLike[Any], None] = None,
target: Union[str, PathLike[Any], None] = None,
shapeType: Optional[int] = None,
autoBalance: bool = False,
*,
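Taken together, these signature changes let Reader and Writer be handed pathlib.Path objects directly. A minimal round-trip sketch under that assumption (the target name and field are illustrative):

import tempfile
from pathlib import Path

import shapefile

with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "points"  # PathLike target, no str() conversion needed
    with shapefile.Writer(target, shapeType=shapefile.POINT) as w:
        w.field("name", "C")
        w.point(1.0, 2.0)
        w.record("example")

    with shapefile.Reader(target) as r:  # the Reader accepts the same Path
        assert len(r) == 1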
59 changes: 57 additions & 2 deletions test_shapefile.py
@@ -13,6 +13,8 @@
# our imports
import shapefile

shapefiles_dir = Path(__file__).parent / "shapefiles"

# define various test shape tuples of (type, points, parts indexes, and expected geo interface output)
geo_interface_tests = [
(
@@ -719,8 +721,7 @@ def test_reader_pathlike():
"""
Assert that path-like objects can be read.
"""
base = Path("shapefiles")
with shapefile.Reader(base / "blockgroups") as sf:
with shapefile.Reader(shapefiles_dir / "blockgroups") as sf:
assert len(sf) == 663


@@ -736,6 +737,18 @@ def test_reader_dbf_only():
assert record[1:3] == ["060750601001", 4715]


def test_reader_dbf_only_from_Path():
"""
Assert that specifying just the
dbf argument to the shapefile reader as a Path
reads just the dbf file.
"""
with shapefile.Reader(dbf=shapefiles_dir / "blockgroups.dbf") as sf:
assert len(sf) == 663
record = sf.record(3)
assert record[1:3] == ["060750601001", 4715]


def test_reader_shp_shx_only():
"""
Assert that specifying just the
@@ -750,6 +763,20 @@ def test_reader_shp_shx_only():
assert len(shape.points) == 173


def test_reader_shp_shx_only_from_Paths():
"""
Assert that specifying just the
shp and shx argument to the shapefile reader as Paths
reads just the shp and shx file.
"""
with shapefile.Reader(
shp=shapefiles_dir / "blockgroups.shp", shx=shapefiles_dir / "blockgroups.shx"
) as sf:
assert len(sf) == 663
shape = sf.shape(3)
assert len(shape.points) == 173


def test_reader_shp_dbf_only():
"""
Assert that specifying just the
@@ -766,6 +793,22 @@ def test_reader_shp_dbf_only():
assert record[1:3] == ["060750601001", 4715]


def test_reader_shp_dbf_only_from_Paths():
"""
Assert that specifying just the
shp and dbf argument to the shapefile reader as Paths
reads just the shp and dbf file.
"""
with shapefile.Reader(
shp=shapefiles_dir / "blockgroups.shp", dbf=shapefiles_dir / "blockgroups.dbf"
) as sf:
assert len(sf) == 663
shape = sf.shape(3)
assert len(shape.points) == 173
record = sf.record(3)
assert record[1:3] == ["060750601001", 4715]


def test_reader_shp_only():
"""
Assert that specifying just the
@@ -778,6 +821,18 @@ def test_reader_shp_only():
assert len(shape.points) == 173


def test_reader_shp_only_from_Path():
"""
Assert that specifying just the
shp argument to the shapefile reader as a Path
reads just the shp file (shx optional).
"""
with shapefile.Reader(shp=shapefiles_dir / "blockgroups.shp") as sf:
assert len(sf) == 663
shape = sf.shape(3)
assert len(shape.points) == 173


def test_reader_filelike_dbf_only():
"""
Assert that specifying just the