Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion langfun/core/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,13 +659,20 @@ def from_value(
) -> 'Template':
"""Creates a template object from a value."""
if isinstance(value, cls):
return value.clone(override=kwargs) if kwargs else value # pylint: disable=no-value-for-parameter
if kwargs:
lfun = value.clone(override=kwargs) # pylint: disable=no-value-for-parameter
lfun._referred_modalities = value._referred_modalities # pylint: disable=protected-access
return lfun
return value
if isinstance(value, str):
return cls(template_str=value, **kwargs)
if isinstance(value, Template):
lfun = cls(template_str=value.template_str, **kwargs) # pylint: disable=attribute-error
# So lfun could acccess all attributes from value.
lfun.sym_setparent(value)
# Assign _referred_modalities AFTER sym_setparent, since
# sym_setparent may trigger _on_bound which wipes this field.
lfun._referred_modalities = value._referred_modalities # pylint: disable=protected-access
return lfun
if message_lib.Message.is_convertible(type(value)):
value = message_lib.Message.from_value(value)
Expand Down
37 changes: 37 additions & 0 deletions langfun/core/template_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,43 @@ class MyTemplate(Template):
self.assertEqual(t2.template_str, t.template_str)
self.assertEqual(t2.x, 2)

def test_from_same_template_with_override_preserves_modalities(self):
"""Tests that clone(override=kwargs) preserves _referred_modalities."""

class CustomModality(modality.Modality):
content: str

def to_bytes(self):
return self.content.encode()

t = Template('{{image}}', image=CustomModality('foo'))
t._referred_modalities = {'foo': CustomModality('foo')}

# Clone with override (kwargs non-empty) must preserve modalities.
t2 = Template.from_value(t, image=CustomModality('bar'))
self.assertIsInstance(t2, Template)
self.assertIsNot(t2, t)
self.assertEqual(t2._referred_modalities, t._referred_modalities) # pylint: disable=protected-access

def test_from_different_template_with_modalities(self):
"""Tests that cross-class conversion preserves _referred_modalities."""

class CustomModality(modality.Modality):
content: str

def to_bytes(self):
return self.content.encode()

t = Template('{{image}}', image=CustomModality('foo'))
t._referred_modalities = {'foo': CustomModality('foo')}

class MyTemplate(Template):
pass

t2 = MyTemplate.from_value(t)
self.assertIsInstance(t2, MyTemplate)
self.assertEqual(t2._referred_modalities, t._referred_modalities) # pylint: disable=protected-access

def test_from_python_object(self):
t = Template.from_value(pg.Dict(x=1, y=2))
self.assertEqual(t.template_str, '{{input}}')
Expand Down
Loading