Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 21 additions & 19 deletions google/cloud/firestore_v1/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@

"""Common helpers shared across Google Cloud Firestore modules."""

from abc import ABC, abstractmethod
from dataclasses import dataclass
import datetime
import json

import google
from google.api_core.datetime_helpers import DatetimeWithNanoseconds
from google.api_core import gapic_v1
from google.protobuf import struct_pb2
from google.protobuf import struct_pb2 # type: ignore
from google.type import latlng_pb2 # type: ignore
import grpc # type: ignore

Expand All @@ -41,7 +43,6 @@
Generator,
Iterator,
List,
NoReturn,
Optional,
Tuple,
Union,
Expand Down Expand Up @@ -70,17 +71,16 @@
}


class GeoPoint(object):
@dataclass
class GeoPoint:
"""Simple container for a geo point value.

Args:
latitude (float): Latitude of a point.
longitude (float): Longitude of a point.
"""

def __init__(self, latitude, longitude) -> None:
self.latitude = latitude
self.longitude = longitude
latitude: float
longitude: float

def to_protobuf(self) -> latlng_pb2.LatLng:
"""Convert the current object to protobuf.
Expand Down Expand Up @@ -495,7 +495,7 @@ def __init__(self, document_data) -> None:
self.increments = {}
self.minimums = {}
self.maximums = {}
self.set_fields = {}
self.set_fields: Dict[Any, Any] = {}
self.empty_document = False

prefix_path = FieldPath()
Expand Down Expand Up @@ -557,7 +557,7 @@ def transform_paths(self):
+ list(self.minimums)
)

def _get_update_mask(self, allow_empty_mask=False) -> None:
def _get_update_mask(self, allow_empty_mask=False):
return None

def get_update_pb(
Expand Down Expand Up @@ -721,9 +721,9 @@ class DocumentExtractorForMerge(DocumentExtractor):

def __init__(self, document_data) -> None:
super(DocumentExtractorForMerge, self).__init__(document_data)
self.data_merge = []
self.transform_merge = []
self.merge = []
self.data_merge: List[Any] = []
self.transform_merge: List[Any] = []
self.merge: List[Any] = []

def _apply_merge_all(self) -> None:
self.data_merge = sorted(self.field_paths + self.deleted_fields)
Expand Down Expand Up @@ -777,7 +777,7 @@ def _apply_merge_paths(self, merge) -> None:
self.data_merge.append(field_path)

# Clear out data for fields not merged.
merged_set_fields = {}
merged_set_fields: Dict[str, Any] = {}
for field_path in self.data_merge:
value = get_field_value(self.document_data, field_path)
set_field_value(merged_set_fields, field_path, value)
Expand Down Expand Up @@ -1007,10 +1007,11 @@ def metadata_with_prefix(prefix: str, **kw) -> List[Tuple[str, str]]:
return [("google-cloud-resource-prefix", prefix)]


class WriteOption(object):
class WriteOption(ABC, object):
"""Option used to assert a condition on a write operation."""

def modify_write(self, write, no_create_msg=None) -> NoReturn:
@abstractmethod
def modify_write(self, write, **unused_kwargs) -> None:
"""Modify a ``Write`` protobuf based on the state of this write option.

This is a virtual method intended to be implemented by subclasses.
Expand Down Expand Up @@ -1142,7 +1143,7 @@ def compare_timestamps(
def deserialize_bundle(
serialized: Union[str, bytes],
client: "google.cloud.firestore_v1.client.BaseClient", # type: ignore
) -> "google.cloud.firestore_bundle.FirestoreBundle": # type: ignore
) -> Optional["google.cloud.firestore_bundle.FirestoreBundle"]: # type: ignore
"""Inverse operation to a `FirestoreBundle` instance's `build()` method.

Args:
Expand Down Expand Up @@ -1226,9 +1227,9 @@ def deserialize_bundle(
raise ValueError("Unexpected end to serialized FirestoreBundle")

# Now, finally add the metadata element
bundle._add_bundle_element(
metadata_bundle_element,
client=client,
bundle._add_bundle_element( # type: ignore
metadata_bundle_element, # type: ignore
client=client, # type: ignore
type="metadata", # type: ignore
)

Expand Down Expand Up @@ -1297,3 +1298,4 @@ def _get_document_from_bundle(
bundled_doc = bundle.documents.get(document_id)
if bundled_doc:
return bundled_doc.snapshot
return None
2 changes: 1 addition & 1 deletion google/cloud/firestore_v1/async_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from google.cloud.firestore_v1.transaction import Transaction


class AsyncCollectionReference(BaseCollectionReference):
class AsyncCollectionReference(BaseCollectionReference[async_query.AsyncQuery]):
"""A reference to a collection in a Firestore database.

The collection may already exist or this class can facilitate creation
Expand Down
2 changes: 1 addition & 1 deletion google/cloud/firestore_v1/async_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
)
from google.cloud.firestore_v1 import _helpers
from google.cloud.firestore_v1.types import write
from google.protobuf.timestamp_pb2 import Timestamp
from google.protobuf.timestamp_pb2 import Timestamp # type: ignore
from typing import AsyncGenerator, Iterable


Expand Down
12 changes: 7 additions & 5 deletions google/cloud/firestore_v1/base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,13 +262,15 @@ def _rpc_metadata(self):

return self._rpc_metadata_internal

def collection(self, *collection_path) -> BaseCollectionReference:
def collection(self, *collection_path) -> BaseCollectionReference[BaseQuery]:
raise NotImplementedError

def collection_group(self, collection_id: str) -> BaseQuery:
raise NotImplementedError

def _get_collection_reference(self, collection_id: str) -> BaseCollectionReference:
def _get_collection_reference(
self, collection_id: str
) -> BaseCollectionReference[BaseQuery]:
"""Checks validity of collection_id and then uses subclasses collection implementation.

Args:
Expand Down Expand Up @@ -325,7 +327,7 @@ def _document_path_helper(self, *document_path) -> List[str]:

def recursive_delete(
self,
reference: Union[BaseCollectionReference, BaseDocumentReference],
reference: Union[BaseCollectionReference[BaseQuery], BaseDocumentReference],
bulk_writer: Optional["BulkWriter"] = None, # type: ignore
) -> int:
raise NotImplementedError
Expand Down Expand Up @@ -459,8 +461,8 @@ def collections(
retry: retries.Retry = None,
timeout: float = None,
) -> Union[
AsyncGenerator[BaseCollectionReference, Any],
Generator[BaseCollectionReference, Any, Any],
AsyncGenerator[BaseCollectionReference[BaseQuery], Any],
Generator[BaseCollectionReference[BaseQuery], Any, Any],
]:
raise NotImplementedError

Expand Down
27 changes: 14 additions & 13 deletions google/cloud/firestore_v1/base_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
AsyncGenerator,
Coroutine,
Generator,
Generic,
AsyncIterator,
Iterator,
Iterable,
Expand All @@ -38,13 +39,13 @@

# Types needed only for Type Hints
from google.cloud.firestore_v1.base_document import DocumentSnapshot
from google.cloud.firestore_v1.base_query import BaseQuery
from google.cloud.firestore_v1.base_query import QueryType
from google.cloud.firestore_v1.transaction import Transaction

_AUTO_ID_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"


class BaseCollectionReference(object):
class BaseCollectionReference(Generic[QueryType]):
"""A reference to a collection in a Firestore database.

The collection may already exist or this class can facilitate creation
Expand Down Expand Up @@ -108,7 +109,7 @@ def parent(self):
parent_path = self._path[:-1]
return self._client.document(*parent_path)

def _query(self) -> BaseQuery:
def _query(self) -> QueryType:
raise NotImplementedError

def _aggregation_query(self) -> BaseAggregationQuery:
Expand Down Expand Up @@ -215,10 +216,10 @@ def list_documents(
]:
raise NotImplementedError

def recursive(self) -> "BaseQuery":
def recursive(self) -> QueryType:
return self._query().recursive()

def select(self, field_paths: Iterable[str]) -> BaseQuery:
def select(self, field_paths: Iterable[str]) -> QueryType:
"""Create a "select" query with this collection as parent.

See
Expand All @@ -244,7 +245,7 @@ def where(
value=None,
*,
filter=None
) -> BaseQuery:
) -> QueryType:
"""Create a "where" query with this collection as parent.

See
Expand Down Expand Up @@ -290,7 +291,7 @@ def where(
else:
return query.where(filter=filter)

def order_by(self, field_path: str, **kwargs) -> BaseQuery:
def order_by(self, field_path: str, **kwargs) -> QueryType:
"""Create an "order by" query with this collection as parent.

See
Expand All @@ -312,7 +313,7 @@ def order_by(self, field_path: str, **kwargs) -> BaseQuery:
query = self._query()
return query.order_by(field_path, **kwargs)

def limit(self, count: int) -> BaseQuery:
def limit(self, count: int) -> QueryType:
"""Create a limited query with this collection as parent.

.. note::
Expand Down Expand Up @@ -355,7 +356,7 @@ def limit_to_last(self, count: int):
query = self._query()
return query.limit_to_last(count)

def offset(self, num_to_skip: int) -> BaseQuery:
def offset(self, num_to_skip: int) -> QueryType:
"""Skip to an offset in a query with this collection as parent.

See
Expand All @@ -375,7 +376,7 @@ def offset(self, num_to_skip: int) -> BaseQuery:

def start_at(
self, document_fields: Union[DocumentSnapshot, dict, list, tuple]
) -> BaseQuery:
) -> QueryType:
"""Start query at a cursor with this collection as parent.

See
Expand All @@ -398,7 +399,7 @@ def start_at(

def start_after(
self, document_fields: Union[DocumentSnapshot, dict, list, tuple]
) -> BaseQuery:
) -> QueryType:
"""Start query after a cursor with this collection as parent.

See
Expand All @@ -421,7 +422,7 @@ def start_after(

def end_before(
self, document_fields: Union[DocumentSnapshot, dict, list, tuple]
) -> BaseQuery:
) -> QueryType:
"""End query before a cursor with this collection as parent.

See
Expand All @@ -444,7 +445,7 @@ def end_before(

def end_at(
self, document_fields: Union[DocumentSnapshot, dict, list, tuple]
) -> BaseQuery:
) -> QueryType:
"""End query at a cursor with this collection as parent.

See
Expand Down
Loading