Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 27 additions & 30 deletions ufirestore/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,7 @@ def to_value_type(cls, value):
elif value.startswith("/g"):
typ = "geoPointValue"
value = value[2:].split(",")
value = {
"latitude": value[0],
"longitude": value[1]
}
value = {"latitude": value[0], "longitude": value[1]}
elif isinstance(value, str):
typ = "stringValue"
elif isinstance(value, bytes):
Expand All @@ -37,7 +34,9 @@ def to_value_type(cls, value):
return {typ: {"values": [cls.to_value_type(item) for item in value]}}
elif isinstance(value, dict):
typ = "mapValue"
return {typ: {"fields": {k: cls.to_value_type(v) for k, v in value.items()}}}
return {
typ: {"fields": {k: cls.to_value_type(v) for k, v in value.items()}}
}

return {typ: str(value)}

Expand Down Expand Up @@ -120,22 +119,25 @@ def cb(cur, s):

return self.cursor(path, cb)

def process(self, name):
def process(
self,
# name,
):
return {
"name": name,
# "name": name,
"fields": self.data
}

@classmethod
def from_raw(cls, raw):
print(raw)
fields = raw["fields"]
doc_data = {
"name": raw["name"],
"createTime": raw["createTime"],
"updateTime": raw["updateTime"]
"updateTime": raw["updateTime"],
}
doc_data.update(fields={k: cls.from_value_type(v)
for k, v in fields.items()})
doc_data.update(fields={k: cls.from_value_type(v) for k, v in fields.items()})
return FirebaseJson(doc_data)


Expand All @@ -150,33 +152,27 @@ class Query(FirebaseJson):
"array-contains": "ARRAY_CONTAINS",
"in": "IN",
"array-contains-any": "ARRAY_CONTAINS_ANY",
"not-in": "NOT_IN"
"not-in": "NOT_IN",
}

def __init__(self, *args, **kwargs):
self.num_filters = 0
super().__init__(*args, **kwargs)

def from_(self, collection_id, all_descendants=False):
self.add_item("from", {
"collectionId": collection_id,
"allDescendants": all_descendants
})
self.add_item(
"from", {"collectionId": collection_id, "allDescendants": all_descendants}
)
return self

def select(self, field):
self.add_item("select/fields", {
"fieldPath": field
})
self.add_item("select/fields", {"fieldPath": field})
return self

def order_by(self, field, direction="DESCENDING"):
self.add_item("orderBy", {
"field": {
"fieldPath": field
},
"direction": direction
})
self.add_item(
"orderBy", {"field": {"fieldPath": field}, "direction": direction}
)
return self

def limit(self, value):
Expand All @@ -192,13 +188,14 @@ def where(self, field, op, value):
self.remove("where/fieldFilter")
self.set("where/compositeFilter/op", "AND")
self.add_item("where/compositeFilter/filters", cur_filter)
self.add_item("where/compositeFilter/filters", {
"field": {
"fieldPath": field
self.add_item(
"where/compositeFilter/filters",
{
"field": {"fieldPath": field},
"op": self.OPERATIONS[op],
"value": self.to_value_type(value),
},
"op": self.OPERATIONS[op],
"value": self.to_value_type(value)
})
)
else:
self.set("where/fieldFilter/field/fieldPath", field)
self.set("where/fieldFilter/op", self.OPERATIONS[op])
Expand Down
57 changes: 31 additions & 26 deletions ufirestore/ufirestore.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import ujson
import urequests
import ujson # type: ignore
import _thread

import urequests # type: ignore

class FirestoreException(Exception):
def __init__(self, message, code=400):
Expand Down Expand Up @@ -34,46 +33,47 @@ def construct_url(resource_path=None):


def to_url_params(params=dict()):
return "?" + "&".join(
[(str(k) + "=" + str(v)) for k, v in params.items() if v is not None])
return "?" + "&".join([(f"{k}={v}") for k, v in params.items() if v is not None])


def get_resource_name(url):
return url[url.find("projects"):]


def send_request(path, method="get", params=dict(), data=None, dump=True):
def send_request(path, method="GET", params=dict(), data=None, dump=True):
headers = {}

if FIREBASE_GLOBAL_VAR.ACCESS_TOKEN:
headers["Authorization"] = "Bearer " + FIREBASE_GLOBAL_VAR.ACCESS_TOKEN

response = urequests.request(
method, path, params=params, headers=headers, json=data)
if method in ["POST", "PATCH", "PUT"]:
# headers["Accept"] = "application/json"
# headers["Content-Type"] = "application/json"
response = urequests.request(method, path, data=None, json=data, headers=headers)
else:
response = urequests.request(method, path, headers=headers)

if dump == True:
if response.status_code < 200 or response.status_code > 299:
print(response.text)
raise FirestoreException(response.reason, response.status_code)

json = response.json()
if json.get("error"):
error = json["error"]
jsonResponse = response.json()
if jsonResponse.get("error"):
error = jsonResponse["error"]
code = error["code"]
message = error["message"]
raise FirestoreException(message, code)
return json
return jsonResponse


class INTERNAL:
def patch(DOCUMENT_PATH, DOC, cb, update_mask=None):
PATH = construct_url(DOCUMENT_PATH)
LOCAL_PARAMS = to_url_params()
if update_mask:
for field in update_mask:
LOCAL_PARAMS += "updateMask.fieldPaths=" + field
DATA = DOC.process(get_resource_name(PATH))
LOCAL_OUTPUT = send_request(PATH+LOCAL_PARAMS, "post", data=DATA)
LOCAL_PARAMS = to_url_params() + "&".join([f"updateMask.fieldPaths={field}" for field in update_mask])
DATA = DOC.process() # name in here has been deprecated I believe
# https://firebase.google.com/docs/firestore/reference/rest/v1beta1/projects.databases.documents/patch
# name is added as a part of the path
# updateMask is a query parameter
LOCAL_OUTPUT = send_request(PATH+LOCAL_PARAMS, "PATCH", data=DATA)
if cb:
try:
return cb(LOCAL_OUTPUT)
Expand All @@ -85,8 +85,13 @@ def patch(DOCUMENT_PATH, DOC, cb, update_mask=None):
def create(COLLECTION_PATH, DOC, cb, document_id=None):
PATH = construct_url(COLLECTION_PATH)
PARAMS = {"documentId": document_id}
DATA = DOC.process(get_resource_name(PATH))
LOCAL_OUTPUT = send_request(PATH, "post", PARAMS, DATA)
if document_id is not None:
LOCAL_PARAMS = to_url_params(PARAMS)
else:
LOCAL_PARAMS = ""
# DATA = DOC.process(get_resource_name(PATH))
DATA = DOC.process()
LOCAL_OUTPUT = send_request(PATH+LOCAL_PARAMS, "POST", PARAMS, DATA)
if cb:
try:
return cb(LOCAL_OUTPUT)
Expand All @@ -101,7 +106,7 @@ def get(DOCUMENT_PATH, cb, mask=None):
if mask:
for field in mask:
LOCAL_PARAMS += "mask.fieldPaths=" + field
LOCAL_OUTPUT = send_request(PATH+LOCAL_PARAMS, "get")
LOCAL_OUTPUT = send_request(PATH+LOCAL_PARAMS, "GET")
if cb:
try:
return cb(LOCAL_OUTPUT)
Expand Down Expand Up @@ -167,7 +172,7 @@ def list_collection_ids(DOCUMENT_PATH, cb, page_size=None, page_token=None):
"pageSize": page_size,
"pageToken": page_token
}
LOCAL_OUTPUT = send_request(PATH, "post", data=DATA)
LOCAL_OUTPUT = send_request(PATH, "POST", data=DATA)
if cb:
try:
return cb(LOCAL_OUTPUT.get("collectionIds"),
Expand All @@ -183,7 +188,7 @@ def run_query(DOCUMENT_PATH, query, cb):
DATA = {
"structuredQuery": query.data
}
LOCAL_OUTPUT = send_request(PATH, "post", data=DATA)
LOCAL_OUTPUT = send_request(PATH, "POST", data=DATA)
if cb:
try:
return cb(LOCAL_OUTPUT.get("document"))
Expand Down Expand Up @@ -266,4 +271,4 @@ def run_query(PATH, query, bg=True, cb=None):
_thread.start_new_thread(
INTERNAL.run_query, [PATH, query, cb])
else:
return INTERNAL.run_query(PATH, query, cb)
return INTERNAL.run_query(PATH, query, cb)