Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,8 @@ jobs:
opt-einsum \
nltk \
fvcore \
scikit-optimize
scikit-optimize \
flair
kill $KA
cd src/main/python
python -m unittest discover -s tests/scuro -p 'test_*.py' -v
27 changes: 26 additions & 1 deletion src/main/python/systemds/scuro/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,13 @@
AggregatedRepresentation,
)
from systemds.scuro.representations.average import Average
from systemds.scuro.representations.bert import Bert
from systemds.scuro.representations.bert import (
Bert,
RoBERTa,
DistillBERT,
ALBERT,
ELECTRA,
)
from systemds.scuro.representations.bow import BoW
from systemds.scuro.representations.concatenation import Concatenation
from systemds.scuro.representations.context import Context
Expand Down Expand Up @@ -101,6 +107,16 @@
from systemds.scuro.drsearch.unimodal_optimizer import UnimodalOptimizer
from systemds.scuro.representations.vgg import VGG19
from systemds.scuro.representations.clip import CLIPText, CLIPVisual
from systemds.scuro.representations.text_context import (
SentenceBoundarySplit,
OverlappingSplit,
)
from systemds.scuro.representations.text_context_with_indices import (
SentenceBoundarySplitIndices,
OverlappingSplitIndices,
)
from systemds.scuro.representations.elmo import ELMoRepresentation


__all__ = [
"BaseLoader",
Expand All @@ -113,6 +129,10 @@
"AggregatedRepresentation",
"Average",
"Bert",
"RoBERTa",
"DistillBERT",
"ALBERT",
"ELECTRA",
"BoW",
"Concatenation",
"Context",
Expand Down Expand Up @@ -177,4 +197,9 @@
"VGG19",
"CLIPVisual",
"CLIPText",
"SentenceBoundarySplit",
"OverlappingSplit",
"ELMoRepresentation",
"SentenceBoundarySplitIndices",
"OverlappingSplitIndices",
]
24 changes: 16 additions & 8 deletions src/main/python/systemds/scuro/drsearch/operator_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@ class Registry:

_instance = None
_representations = {}
_context_operators = []
_context_operators = {}
_fusion_operators = []
_text_context_operators = []
_video_context_operators = []

def __new__(cls):
if not cls._instance:
Expand All @@ -60,8 +62,13 @@ def add_representation(
):
self._representations[modality].append(representation)

def add_context_operator(self, context_operator):
self._context_operators.append(context_operator)
def add_context_operator(self, context_operator, modality_type):
    """
    Register a context operator for one or more modality types.

    @param context_operator: The context operator (class) to register
    @param modality_type: A single modality type, or a list of modality
        types, under which the operator is stored
    """
    # A single modality type is treated as a one-element list so both call
    # styles share the same registration path.
    if not isinstance(modality_type, list):
        modality_type = [modality_type]
    for m_type in modality_type:
        # setdefault creates the per-modality list on first use and returns
        # the existing one afterwards.
        self._context_operators.setdefault(m_type, []).append(context_operator)

def add_fusion_operator(self, fusion_operator):
self._fusion_operators.append(fusion_operator)
Expand All @@ -76,9 +83,8 @@ def get_not_self_contained_representations(self, modality: ModalityType):
reps.append(rep)
return reps

def get_context_operators(self):
# TODO: return modality specific context operations
return self._context_operators
def get_context_operators(self, modality_type):
    """
    Return the context operators registered for a modality type.

    @param modality_type: The modality type to look up
    @return: The list of context operators registered for that modality
        type, or an empty list if none have been registered
    """
    # Callers iterate over the result, so a modality without any registered
    # context operator yields an empty list rather than raising KeyError.
    return self._context_operators.get(modality_type, [])

def get_fusion_operators(self):
return self._fusion_operators
Expand Down Expand Up @@ -121,13 +127,15 @@ def decorator(cls):
return decorator


def register_context_operator():
def register_context_operator(modality_type):
"""
Decorator to register a context operator.

@param modality_type: The modality type for which the context operator is to be registered
"""

def decorator(cls):
Registry().add_context_operator(cls)
Registry().add_context_operator(cls, modality_type)
return cls

return decorator
Expand Down
41 changes: 37 additions & 4 deletions src/main/python/systemds/scuro/drsearch/unimodal_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ def _get_not_self_contained_reps(self, modality_type):
)

@lru_cache(maxsize=32)
def _get_context_operators(self):
return self.operator_registry.get_context_operators()
def _get_context_operators(self, modality_type):
return self.operator_registry.get_context_operators(modality_type)

def store_results(self, file_name=None):
if file_name is None:
Expand Down Expand Up @@ -302,6 +302,39 @@ def _build_modality_dag(
current_node_id = rep_node_id
dags.append(builder.build(current_node_id))

if operator.needs_context:
context_operators = self._get_context_operators(modality.modality_type)
for context_op in context_operators:
if operator.initial_context_length is not None:
context_length = operator.initial_context_length

context_node_id = builder.create_operation_node(
context_op,
[leaf_id],
context_op(context_length).get_current_parameters(),
)
else:
context_node_id = builder.create_operation_node(
context_op,
[leaf_id],
context_op().get_current_parameters(),
)

context_rep_node_id = builder.create_operation_node(
operator.__class__,
[context_node_id],
operator.get_current_parameters(),
)

agg_operator = AggregatedRepresentation()
context_agg_node_id = builder.create_operation_node(
agg_operator.__class__,
[context_rep_node_id],
agg_operator.get_current_parameters(),
)

dags.append(builder.build(context_agg_node_id))

if not operator.self_contained:
not_self_contained_reps = self._get_not_self_contained_reps(
modality.modality_type
Expand Down Expand Up @@ -344,7 +377,7 @@ def _build_modality_dag(

def default_context_operators(self, modality, builder, leaf_id, current_node_id):
dags = []
context_operators = self._get_context_operators()
context_operators = self._get_context_operators(modality.modality_type)
for context_op in context_operators:
if (
modality.modality_type != ModalityType.TEXT
Expand All @@ -368,7 +401,7 @@ def default_context_operators(self, modality, builder, leaf_id, current_node_id)

def temporal_context_operators(self, modality, builder, leaf_id, current_node_id):
aggregators = self.operator_registry.get_representations(modality.modality_type)
context_operators = self._get_context_operators()
context_operators = self._get_context_operators(modality.modality_type)

dags = []
for agg in aggregators:
Expand Down
2 changes: 0 additions & 2 deletions src/main/python/systemds/scuro/modality/transformed.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
# under the License.
#
# -------------------------------------------------------------
from functools import reduce
from operator import or_
from typing import Union, List

from systemds.scuro.modality.type import ModalityType
Expand Down
18 changes: 12 additions & 6 deletions src/main/python/systemds/scuro/modality/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,12 @@ def update_base_metadata(cls, md, data, data_is_single_instance=True):
shape = data.shape
elif data_layout is DataLayout.NESTED_LEVEL:
if data_is_single_instance:
dtype = data.dtype
shape = data.shape
if isinstance(data, list):
dtype = type(data[0])
shape = (len(data), len(data[0]))
else:
dtype = data.dtype
shape = data.shape
else:
shape = data[0].shape
dtype = data[0].dtype
Expand Down Expand Up @@ -306,13 +310,15 @@ def get_data_layout(cls, data, data_is_single_instance):
return None

if data_is_single_instance:
if (
if (isinstance(data, list) and not isinstance(data[0], str)) or (
isinstance(data, np.ndarray) and data.ndim == 1
):
return DataLayout.SINGLE_LEVEL
elif (
isinstance(data, list)
or isinstance(data, np.ndarray)
and data.ndim == 1
or isinstance(data, torch.Tensor)
):
return DataLayout.SINGLE_LEVEL
elif isinstance(data, np.ndarray) or isinstance(data, torch.Tensor):
return DataLayout.NESTED_LEVEL

if isinstance(data[0], list):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def execute(self, modality):
max_len = 0
for i, instance in enumerate(modality.data):
data.append([])
if isinstance(instance, np.ndarray):
if isinstance(instance, np.ndarray) or isinstance(instance, list):
if (
modality.modality_type == ModalityType.IMAGE
or modality.modality_type == ModalityType.VIDEO
Expand Down
Loading
Loading