Skip to content
134 changes: 116 additions & 18 deletions specifyweb/backend/trees/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from specifyweb.celery_tasks import LogErrorsTask, app
import specifyweb.specify.models as spmodels
from specifyweb.backend.trees.utils import get_models, SPECIFY_TREES, TREE_ROOT_NODES
from specifyweb.backend.trees.extras import renumber_tree, set_fullnames

import logging
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -136,20 +137,105 @@ def __init__(self, tree_type: str, tree_name: str):
self.tree_def_model, self.tree_rank_model, self.tree_node_model = get_models(tree_type)

self.tree_def = self.tree_def_model.objects.get(name=tree_name)
self.tree_def_item_map = self.create_rank_map()

self.create_rank_map()
self.root_parent = self.tree_node_model.objects.filter(
definitionitem__rankid=0,
definition=self.tree_def
).first()

self.counter = 0
self.batch_size = 1000

def create_rank_map(self):
    """Build the rank lookup tables and the empty batching state.

    The ranks of ``self.tree_def`` are queried once and used to populate:

    * ``self.tree_def_item_map`` -- rank name -> rank object
    * ``self.rankid_map``        -- rank id   -> rank object
    * ``self.buffers``           -- rank id   -> {row id: unsaved node} for the current batch
    * ``self.created``           -- rank id   -> {node name: saved node id} for the whole import

    Returns:
        The name -> rank mapping (the same object stored in
        ``self.tree_def_item_map``), so callers may either use the return
        value or rely on the attributes being set.
    """
    # Materialise the queryset so the four lookups below share a single query.
    ranks = list(self.tree_rank_model.objects.filter(treedef=self.tree_def))
    self.tree_def_item_map = {rank.name: rank for rank in ranks}
    # Buffers for batches: each rank gets its own, independent dict.
    self.rankid_map = {rank.rankid: rank for rank in ranks}
    self.buffers = {rank.rankid: {} for rank in ranks}
    self.created = {rank.rankid: {} for rank in ranks}
    return self.tree_def_item_map

def add_node_to_buffer(self, node, rank_id, row_id):
    """Queue ``node`` for creation in the current batch and hand it back.

    The node is stored in the buffer of ``rank_id`` under ``row_id``; an
    entry already stored under that row id for the rank is replaced. A rank
    id that has no buffer yet gets a fresh buffer and a fresh (empty)
    created-ids map.
    """
    try:
        pending = self.buffers[rank_id]
    except KeyError:
        # First node seen for this rank: start both per-rank maps empty.
        pending = self.buffers[rank_id] = {}
        self.created[rank_id] = {}
    pending[row_id] = node
    return node

def get_node_in_buffer(self, rank_id: int, name: str):
    """Find a not-yet-flushed node by name within one rank's buffer.

    Returns the first buffered node at ``rank_id`` whose ``name`` equals
    ``name``, or ``None`` when the rank has no buffer or no such node. Used
    to avoid queueing the same node twice within a batch.
    """
    pending = self.buffers.get(rank_id, {})
    # Buffer values are scanned in insertion order; the first match wins.
    return next(
        (candidate for candidate in pending.values() if candidate.name == name),
        None,
    )

def get_existing_node_id(self, rank_id: int, name: str) -> Optional[int]:
    """Look up the saved id of a node created earlier in this import.

    Returns the id recorded for ``name`` at ``rank_id``, or ``None`` when
    nothing has been recorded for that rank or that name. Used to avoid
    creating the same node again in a later batch.
    """
    # A missing or empty per-rank map both mean "nothing created yet".
    ids_by_name = self.created.get(rank_id) or {}
    return ids_by_name.get(name)

def flush(self, force=False):
    """Bulk-create the buffered nodes once a batch is complete.

    Every call counts as one processed row. The buffers are written out when
    ``batch_size`` rows have accumulated, or immediately when ``force`` is
    true (used for the final, partial batch). After a write the row counter
    is reset and the buffers of all processed ranks are emptied.

    Ranks are handled in ascending rank-id order so that a parent has been
    written, and its id recorded in ``self.created``, before the nodes that
    hang off it are written. Nodes whose parent cannot be resolved to a
    saved id are logged and dropped.

    NOTE(review): ``self.created`` is keyed by node name within a rank, so
    two nodes with the same name at the same rank end up sharing one id
    regardless of their parents -- confirm this is intended.

    Args:
        force: Write the buffers out even if the batch is not full yet.
    """
    self.counter += 1
    if not force and self.counter < self.batch_size:
        return
    # Report the rows actually in this batch; a forced flush is usually partial.
    logger.debug(f"Batch creating {self.counter} rows.")

    root = self.root_parent

    # Go through ranks in ascending order and bulk create nodes
    for rank_id in sorted(self.buffers):
        logger.debug(f"On rank {rank_id}")
        buffer = self.buffers.get(rank_id, {})

        rank = self.rankid_map.get(rank_id)
        if rank is None:
            # Can't create nodes because this rank doesn't exist
            continue

        nodes_to_create = []
        for node in buffer.values():
            parent_id = getattr(node, 'parent_id', None)
            parent = None
            # Only look at the parent object when no parent id is known yet.
            # On a Django model that knows its parent only by id, reading
            # ``node.parent`` would lazily fetch the parent row (one query per
            # node), which defeats the point of batching.
            if parent_id is None:
                parent = getattr(node, 'parent', None)
                if parent is not None and getattr(parent, 'pk', None) is None:
                    # Swap the unsaved parent object for the id it received
                    # when its own rank was written.
                    saved_parent_id = self.created.get(parent.rankid, {}).get(parent.name)
                    # Handle root
                    if not saved_parent_id and root is not None and parent.name == root.name:
                        saved_parent_id = root.id
                    if saved_parent_id:
                        node.parent = None
                        node.parent_id = saved_parent_id

            # Create node if its parent has been created
            has_saved_parent = (
                getattr(node, 'parent_id', None) is not None
                or getattr(parent, 'pk', None) is not None
            )
            if has_saved_parent:
                nodes_to_create.append(node)
            else:
                logger.warning(f"Could not create {node.name} because a valid parent could not be resolved. {parent_id}, {str(parent)}")

        if nodes_to_create:
            self.tree_node_model.objects.bulk_create(nodes_to_create, ignore_conflicts=True)

            # Store the ids of the nodes that were created in this batch.
            # The ids are read back with a query rather than taken from the
            # objects passed to bulk_create.
            created_names = [n.name for n in nodes_to_create]
            created_nodes = self.tree_node_model.objects.filter(
                definition=self.tree_def,
                definitionitem=rank,
                name__in=created_names
            )
            self.created[rank_id].update({n.name: n.id for n in created_nodes})

        self.buffers[rank_id] = {}

    self.counter = 0

def add_default_tree_record(context: DefaultTreeContext, row: dict, tree_cfg: dict[str, RankMappingConfiguration], row_id: int):
"""
Given one CSV row and a column mapping / rank configuration dictionary,
walk through the 'ranks' in order, creating or updating each tree record and linking
Expand All @@ -158,6 +244,7 @@ def add_default_tree_record(context: DefaultTreeContext, row: dict, tree_cfg: di
tree_node_model = context.tree_node_model
tree_def = context.tree_def
parent = context.root_parent
parent_id = None
rank_id = 10

for rank_mapping in tree_cfg['ranks']:
Expand Down Expand Up @@ -187,27 +274,32 @@ def add_default_tree_record(context: DefaultTreeContext, row: dict, tree_cfg: di
continue

# Create the node at this rank if it isn't already there.
obj = tree_node_model.objects.filter(
name=record_name,
fullname=record_name,
definition=tree_def,
definitionitem=tree_def_item,
parent=parent,
).first()
if obj is None:
buffered = context.get_node_in_buffer(tree_def_item.rankid, record_name)
existing_id = context.get_existing_node_id(tree_def_item.rankid, record_name)
if existing_id is not None:
parent_id = existing_id
parent = None
elif buffered is not None:
parent_id = None
parent = buffered
else:
data = {
'name': record_name,
'fullname': record_name,
'definition': tree_def,
'definitionitem': tree_def_item,
'parent': parent,
'rankid': tree_def_item.rankid,
**defaults
}
if parent is not None:
data['parent'] = parent
elif parent_id is not None:
data['parent_id'] = parent_id
obj = tree_node_model(**data)
obj.save(skip_tree_extras=True)
obj = context.add_node_to_buffer(obj, tree_def_item.rankid, row_id)

parent = obj
parent = obj
parent_id = None
rank_id += 10

@app.task(base=LogErrorsTask, bind=True)
Expand Down Expand Up @@ -292,8 +384,14 @@ def progress(cur: int, additional_total: int=0) -> None:
progress(0, total_rows)

for row in stream_csv_from_url(url):
add_default_tree_record(context, row, tree_cfg)
add_default_tree_record(context, row, tree_cfg, current)
context.flush()
progress(1, 0)
context.flush(force=True)

# Finalize Tree
renumber_tree(tree_type)
set_fullnames(tree_def)
except Exception as e:
if specify_user_id and specify_collection_id:
Message.objects.create(
Expand Down