Skip to content

Commit b186c8f

Browse files
committed
feat: Add rollback_to_snapshot to ManageSnapshots API
1 parent b0880c8 commit b186c8f

File tree

2 files changed

+143
-0
lines changed

2 files changed

+143
-0
lines changed

pyiceberg/table/update/snapshot.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
Snapshot,
6565
SnapshotSummaryCollector,
6666
Summary,
67+
ancestors_of,
6768
update_snapshot_summaries,
6869
)
6970
from pyiceberg.table.update import (
@@ -985,6 +986,40 @@ def set_current_snapshot(self, snapshot_id: int | None = None, ref_name: str | N
985986
self._transaction._stage(update, requirement)
986987
return self
987988

989+
def rollback_to_snapshot(self, snapshot_id: int) -> ManageSnapshots:
990+
"""Rollback the table to the given snapshot id.
991+
992+
The snapshot needs to be an ancestor of the current table state.
993+
994+
Args:
995+
snapshot_id (int): rollback to this snapshot_id that used to be current.
996+
997+
Returns:
998+
This for method chaining
999+
1000+
Raises:
1001+
ValueError: If the snapshot does not exist or is not an ancestor of the current table state.
1002+
"""
1003+
if not self._transaction.table_metadata.snapshot_by_id(snapshot_id):
1004+
raise ValueError(f"Cannot roll back to unknown snapshot id: {snapshot_id}")
1005+
1006+
if not self._is_current_ancestor(snapshot_id):
1007+
raise ValueError(f"Cannot roll back to snapshot, not an ancestor of the current state: {snapshot_id}")
1008+
1009+
return self.set_current_snapshot(snapshot_id=snapshot_id)
1010+
1011+
def _is_current_ancestor(self, snapshot_id: int) -> bool:
1012+
return snapshot_id in self._current_ancestors()
1013+
1014+
def _current_ancestors(self) -> set[int]:
1015+
return {
1016+
a.snapshot_id
1017+
for a in ancestors_of(
1018+
self._transaction.table_metadata.current_snapshot(),
1019+
self._transaction.table_metadata,
1020+
)
1021+
}
1022+
9881023

9891024
class ExpireSnapshots(UpdateTableMetadata["ExpireSnapshots"]):
9901025
"""Expire snapshots by ID.

tests/integration/test_snapshot_operations.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,44 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17+
import uuid
18+
from collections.abc import Generator
19+
20+
import pyarrow as pa
1721
import pytest
1822

1923
from pyiceberg.catalog import Catalog
24+
from pyiceberg.table import Table
2025
from pyiceberg.table.refs import SnapshotRef
2126

2227

28+
@pytest.fixture
29+
def table_with_snapshots(session_catalog: Catalog) -> Generator[Table, None, None]:
30+
session_catalog.create_namespace_if_not_exists("default")
31+
identifier = f"default.test_table_snapshot_ops_{uuid.uuid4().hex[:8]}"
32+
33+
arrow_schema = pa.schema(
34+
[
35+
pa.field("id", pa.int64(), nullable=False),
36+
pa.field("data", pa.string(), nullable=True),
37+
]
38+
)
39+
40+
tbl = session_catalog.create_table(identifier=identifier, schema=arrow_schema)
41+
42+
data1 = pa.Table.from_pylist([{"id": 1, "data": "a"}, {"id": 2, "data": "b"}], schema=arrow_schema)
43+
tbl.append(data1)
44+
45+
data2 = pa.Table.from_pylist([{"id": 3, "data": "c"}, {"id": 4, "data": "d"}], schema=arrow_schema)
46+
tbl.append(data2)
47+
48+
tbl = session_catalog.load_table(identifier)
49+
50+
yield tbl
51+
52+
session_catalog.drop_table(identifier)
53+
54+
2355
@pytest.mark.integration
2456
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
2557
def test_create_tag(catalog: Catalog) -> None:
@@ -160,3 +192,79 @@ def test_set_current_snapshot_chained_with_create_tag(catalog: Catalog) -> None:
160192
tbl = catalog.load_table(identifier)
161193
tbl.manage_snapshots().remove_tag(tag_name=tag_name).commit()
162194
assert tbl.metadata.refs.get(tag_name, None) is None
195+
196+
197+
@pytest.mark.integration
198+
def test_rollback_to_snapshot(table_with_snapshots: Table) -> None:
199+
history = table_with_snapshots.history()
200+
assert len(history) >= 2
201+
202+
ancestor_snapshot_id = history[-2].snapshot_id
203+
204+
table_with_snapshots.manage_snapshots().rollback_to_snapshot(snapshot_id=ancestor_snapshot_id).commit()
205+
206+
updated = table_with_snapshots.current_snapshot()
207+
assert updated is not None
208+
assert updated.snapshot_id == ancestor_snapshot_id
209+
210+
211+
@pytest.mark.integration
212+
def test_rollback_to_current_snapshot(table_with_snapshots: Table) -> None:
213+
current = table_with_snapshots.current_snapshot()
214+
assert current is not None
215+
216+
table_with_snapshots.manage_snapshots().rollback_to_snapshot(snapshot_id=current.snapshot_id).commit()
217+
218+
updated = table_with_snapshots.current_snapshot()
219+
assert updated is not None
220+
assert updated.snapshot_id == current.snapshot_id
221+
222+
223+
@pytest.mark.integration
224+
def test_rollback_to_snapshot_chained_with_tag(table_with_snapshots: Table) -> None:
225+
history = table_with_snapshots.history()
226+
assert len(history) >= 2
227+
228+
ancestor_snapshot_id = history[-2].snapshot_id
229+
tag_name = "my-tag"
230+
231+
(
232+
table_with_snapshots.manage_snapshots()
233+
.create_tag(snapshot_id=ancestor_snapshot_id, tag_name=tag_name)
234+
.rollback_to_snapshot(snapshot_id=ancestor_snapshot_id)
235+
.commit()
236+
)
237+
238+
updated = table_with_snapshots.current_snapshot()
239+
assert updated is not None
240+
assert updated.snapshot_id == ancestor_snapshot_id
241+
assert table_with_snapshots.metadata.refs[tag_name] == SnapshotRef(snapshot_id=ancestor_snapshot_id, snapshot_ref_type="tag")
242+
243+
244+
@pytest.mark.integration
245+
def test_rollback_to_snapshot_not_ancestor(table_with_snapshots: Table) -> None:
246+
history = table_with_snapshots.history()
247+
assert len(history) >= 2
248+
249+
snapshot_a = history[-2].snapshot_id
250+
251+
branch_name = "my-branch"
252+
table_with_snapshots.manage_snapshots().create_branch(snapshot_id=snapshot_a, branch_name=branch_name).commit()
253+
254+
data = pa.Table.from_pylist([{"id": 5, "data": "e"}], schema=table_with_snapshots.schema().as_arrow())
255+
table_with_snapshots.append(data, branch=branch_name)
256+
257+
snapshot_c = table_with_snapshots.metadata.snapshot_by_name(branch_name)
258+
assert snapshot_c is not None
259+
assert snapshot_c.snapshot_id != snapshot_a
260+
261+
with pytest.raises(ValueError, match="not an ancestor"):
262+
table_with_snapshots.manage_snapshots().rollback_to_snapshot(snapshot_id=snapshot_c.snapshot_id).commit()
263+
264+
265+
@pytest.mark.integration
266+
def test_rollback_to_snapshot_unknown_id(table_with_snapshots: Table) -> None:
267+
invalid_snapshot_id = 1234567890000
268+
269+
with pytest.raises(ValueError, match="Cannot roll back to unknown snapshot id"):
270+
table_with_snapshots.manage_snapshots().rollback_to_snapshot(snapshot_id=invalid_snapshot_id).commit()

0 commit comments

Comments
 (0)