|
14 | 14 | # KIND, either express or implied. See the License for the |
15 | 15 | # specific language governing permissions and limitations |
16 | 16 | # under the License. |
| 17 | +import uuid |
| 18 | +from collections.abc import Generator |
| 19 | + |
| 20 | +import pyarrow as pa |
17 | 21 | import pytest |
18 | 22 |
|
19 | 23 | from pyiceberg.catalog import Catalog |
| 24 | +from pyiceberg.table import Table |
20 | 25 | from pyiceberg.table.refs import SnapshotRef |
21 | 26 |
|
22 | 27 |
|
@pytest.fixture
def table_with_snapshots(session_catalog: Catalog) -> Generator[Table, None, None]:
    """Yield a uniquely named table that has received two appends, then drop it.

    The table is created in the ``default`` namespace with an ``id`` (non-nullable
    int64) and a ``data`` (nullable string) column. Two separate appends are
    committed before the table is reloaded from the catalog and handed to the test.

    Args:
        session_catalog: Catalog used to create, load and finally drop the table.

    Yields:
        The table as reloaded from the catalog after both appends.
    """
    session_catalog.create_namespace_if_not_exists("default")
    # A random suffix keeps the identifier unique for every fixture invocation.
    identifier = f"default.test_table_snapshot_ops_{uuid.uuid4().hex[:8]}"

    arrow_schema = pa.schema(
        [
            pa.field("id", pa.int64(), nullable=False),
            pa.field("data", pa.string(), nullable=True),
        ]
    )

    tbl = session_catalog.create_table(identifier=identifier, schema=arrow_schema)

    # From here on the table exists in the catalog, so guarantee it is dropped
    # even when one of the appends or the reload below raises during setup.
    try:
        data1 = pa.Table.from_pylist([{"id": 1, "data": "a"}, {"id": 2, "data": "b"}], schema=arrow_schema)
        tbl.append(data1)

        data2 = pa.Table.from_pylist([{"id": 3, "data": "c"}, {"id": 4, "data": "d"}], schema=arrow_schema)
        tbl.append(data2)

        tbl = session_catalog.load_table(identifier)

        yield tbl
    finally:
        session_catalog.drop_table(identifier)
| 53 | + |
| 54 | + |
23 | 55 | @pytest.mark.integration |
24 | 56 | @pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")]) |
25 | 57 | def test_create_tag(catalog: Catalog) -> None: |
@@ -160,3 +192,79 @@ def test_set_current_snapshot_chained_with_create_tag(catalog: Catalog) -> None: |
160 | 192 | tbl = catalog.load_table(identifier) |
161 | 193 | tbl.manage_snapshots().remove_tag(tag_name=tag_name).commit() |
162 | 194 | assert tbl.metadata.refs.get(tag_name, None) is None |
| 195 | + |
| 196 | + |
@pytest.mark.integration
def test_rollback_to_snapshot(table_with_snapshots: Table) -> None:
    """Rolling back to the second-to-last history entry makes it the current snapshot."""
    log_entries = table_with_snapshots.history()
    assert len(log_entries) >= 2

    # Target the entry right before the most recent one in the history log.
    target_id = log_entries[-2].snapshot_id

    snapshot_manager = table_with_snapshots.manage_snapshots()
    snapshot_manager.rollback_to_snapshot(snapshot_id=target_id).commit()

    current = table_with_snapshots.current_snapshot()
    assert current is not None
    assert current.snapshot_id == target_id
| 209 | + |
| 210 | + |
@pytest.mark.integration
def test_rollback_to_current_snapshot(table_with_snapshots: Table) -> None:
    """Rolling back to the snapshot that is already current leaves it current."""
    before = table_with_snapshots.current_snapshot()
    assert before is not None

    snapshot_manager = table_with_snapshots.manage_snapshots()
    snapshot_manager.rollback_to_snapshot(snapshot_id=before.snapshot_id).commit()

    after = table_with_snapshots.current_snapshot()
    assert after is not None
    assert after.snapshot_id == before.snapshot_id
| 221 | + |
| 222 | + |
@pytest.mark.integration
def test_rollback_to_snapshot_chained_with_tag(table_with_snapshots: Table) -> None:
    """A tag creation and a rollback chained into one commit are both applied."""
    log_entries = table_with_snapshots.history()
    assert len(log_entries) >= 2

    target_id = log_entries[-2].snapshot_id
    tag_name = "my-tag"

    # Queue both operations on the same builder and commit them together.
    pending = table_with_snapshots.manage_snapshots()
    pending = pending.create_tag(snapshot_id=target_id, tag_name=tag_name)
    pending = pending.rollback_to_snapshot(snapshot_id=target_id)
    pending.commit()

    current = table_with_snapshots.current_snapshot()
    assert current is not None
    assert current.snapshot_id == target_id

    actual_ref = table_with_snapshots.metadata.refs[tag_name]
    expected_ref = SnapshotRef(snapshot_id=target_id, snapshot_ref_type="tag")
    assert actual_ref == expected_ref
| 242 | + |
| 243 | + |
@pytest.mark.integration
def test_rollback_to_snapshot_not_ancestor(table_with_snapshots: Table) -> None:
    """Rolling back to a snapshot appended on a separate branch raises ``ValueError``."""
    log_entries = table_with_snapshots.history()
    assert len(log_entries) >= 2

    base_snapshot_id = log_entries[-2].snapshot_id

    # Fork a branch from the earlier snapshot and append to that branch only.
    branch_name = "my-branch"
    brancher = table_with_snapshots.manage_snapshots()
    brancher.create_branch(snapshot_id=base_snapshot_id, branch_name=branch_name).commit()

    rows = [{"id": 5, "data": "e"}]
    branch_rows = pa.Table.from_pylist(rows, schema=table_with_snapshots.schema().as_arrow())
    table_with_snapshots.append(branch_rows, branch=branch_name)

    branch_head = table_with_snapshots.metadata.snapshot_by_name(branch_name)
    assert branch_head is not None
    assert branch_head.snapshot_id != base_snapshot_id

    # The whole chain stays inside the context manager, as either call may raise.
    with pytest.raises(ValueError, match="not an ancestor"):
        table_with_snapshots.manage_snapshots().rollback_to_snapshot(snapshot_id=branch_head.snapshot_id).commit()
| 263 | + |
| 264 | + |
@pytest.mark.integration
def test_rollback_to_snapshot_unknown_id(table_with_snapshots: Table) -> None:
    """Rolling back to a snapshot id the test made up raises ``ValueError``."""
    missing_id = 1_234_567_890_000

    # The whole chain stays inside the context manager, as either call may raise.
    with pytest.raises(ValueError, match="Cannot roll back to unknown snapshot id"):
        table_with_snapshots.manage_snapshots().rollback_to_snapshot(snapshot_id=missing_id).commit()
0 commit comments