69 changes: 63 additions & 6 deletions .github/workflows/build-hf-dataset.yaml
@@ -1,24 +1,81 @@
name: Build HF Dataset

# Spins up a single Hetzner machine, downloads a HuggingFace embeddings dataset,
# splits train/test, computes brute-force kNN ground truth, packages
# vectors.npy + tests.jsonl into a tarball, and uploads it to GCS.

on:
  workflow_dispatch:
    inputs:
      hf_dataset:
        description: "HF dataset id, e.g. Qdrant/dbpedia-entities-openai3-text-embedding-3-large-1536-100K"
        required: true
        default: "Qdrant/gte-multilingual-ads-1M"
      output_name:
        description: "Output tarball basename"
        required: true
        default: "ads-gte-multilingual-1M-768-angular"
      vector_column:
        description: "Embedding column name"
        default: "gte"

concurrency:
  group: hetzner-machines

env:
  HCLOUD_TOKEN: ${{ secrets.HCLOUD_TOKEN }}
  SERVER_NAME: build-hf-dataset-${{ github.run_id }}
  GCS_PATH: gs://ann-filtered-benchmark/datasets/

jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4.3.1
      - uses: webfactory/ssh-agent@d4b9b8ff72958532804b70bbe600ad43b36d5f2e # v0.8.0
        with:
          ssh-private-key: ${{ secrets.SSH_PRIVATE_KEY }}

      - name: Setup CI
        run: bash -x tools/setup_ci.sh

      - name: Create server
        uses: ./.github/workflows/actions/create-server-with-retry
        with:
          server_name: ${{ env.SERVER_NAME }}
          server_type: ccx43
          region: fsn1

      - name: Build dataset on remote
        env:
          GCS_KEY: ${{ secrets.GCS_KEY }}
          GCS_SECRET: ${{ secrets.GCS_SECRET }}
        run: |
          set -euo pipefail
          source tools/ssh.sh
          IP=$(bash tools/hetzner/get_public_ip.sh "$SERVER_NAME")

          scp_with_retry -o StrictHostKeyChecking=no \
            scripts/build_hf_dataset.py "root@$IP:/root/"

          ssh_with_retry -o StrictHostKeyChecking=no \
            -o ServerAliveInterval=30 -o ServerAliveCountMax=20 "root@$IP" \
            GCS_KEY="$GCS_KEY" GCS_SECRET="$GCS_SECRET" bash -s <<EOF
          set -euxo pipefail
          export DEBIAN_FRONTEND=noninteractive
          apt-get update
          apt-get install -y python3-pip

          pip3 install --no-cache-dir datasets numpy faiss-cpu boto3

          python3 /root/build_hf_dataset.py \
            --hf-dataset "${{ inputs.hf_dataset }}" \
            --output-name "${{ inputs.output_name }}" \
            --vector-column "${{ inputs.vector_column }}" \
            --gcs-uri "$GCS_PATH"
          EOF

      - name: Teardown server
        if: always()
        continue-on-error: true
        run: bash -x tools/hetzner/remove_server.sh "$SERVER_NAME"
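For reference, a workflow_dispatch run like this can be started from GitHub's REST API as well as the Actions UI; a hedged sketch using the documented dispatches endpoint (OWNER/REPO and the token env var are placeholders, and the input values below just mirror the workflow defaults):

import os
import requests

# POST /repos/{owner}/{repo}/actions/workflows/{file}/dispatches starts a
# workflow_dispatch run on the given ref with the supplied inputs.
resp = requests.post(
    "https://api.github.com/repos/OWNER/REPO/actions/workflows/build-hf-dataset.yaml/dispatches",
    headers={
        "Authorization": f"Bearer {os.environ['GITHUB_TOKEN']}",
        "Accept": "application/vnd.github+json",
    },
    json={
        "ref": "master",
        "inputs": {
            "hf_dataset": "Qdrant/gte-multilingual-ads-1M",
            "output_name": "ads-gte-multilingual-1M-768-angular",
            "vector_column": "gte",
        },
    },
    timeout=30,
)
resp.raise_for_status()  # the API returns 204 No Content on success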
2 changes: 1 addition & 1 deletion dataset_reader/ann_compound_reader.py
@@ -18,7 +18,7 @@ class AnnCompoundReader(JSONReader):
    QUERIES_FILE = "tests.jsonl"

    def read_vectors(self) -> Iterator[List[float]]:
        vectors = np.load(self.path / self.VECTORS_FILE, mmap_mode="r")
        for vector in vectors:
            if self.normalize:
                vector = vector / np.linalg.norm(vector)
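With mmap_mode="r", NumPy memory-maps vectors.npy instead of reading the whole array into RAM up front, so iteration pulls pages on demand and the reader's footprint stays bounded for million-scale datasets like the ones added below. A minimal sketch of the difference (the file path is hypothetical):

import numpy as np

# Hypothetical demo: write a large array, then compare load strategies.
arr = np.random.rand(1_000_000, 768).astype(np.float32)
np.save("/tmp/vectors.npy", arr)

full = np.load("/tmp/vectors.npy")                 # reads ~3 GB into RAM immediately
lazy = np.load("/tmp/vectors.npy", mmap_mode="r")  # maps the file; pages load on access

# Iterating the memmap touches one row at a time; the OS page cache keeps
# resident memory bounded even for files larger than RAM.
for row in lazy[:10]:
    _ = row.sum()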
64 changes: 64 additions & 0 deletions datasets/datasets.json
@@ -418,5 +418,69 @@
    "type": "tar",
    "link": "https://storage.googleapis.com/ann-filtered-benchmark/datasets/laion-1m.tgz",
    "path": "laion-1m/laion_1m"
  },
  {
    "name": "ads-gte-multilingual-1M-768-angular",
    "vector_size": 768,
    "distance": "cosine",
    "type": "tar",
    "link": "https://storage.googleapis.com/ann-filtered-benchmark/datasets/ads-gte-multilingual-1M-768-angular.tgz",
    "path": "ads-gte-multilingual-1M-768-angular/ads-gte-multilingual-1M-768-angular"
  },
  {
    "name": "arxiv-titles-instructorxl-768-angular",
    "vector_size": 768,
    "distance": "cosine",
    "type": "tar",
    "link": "https://storage.googleapis.com/ann-filtered-benchmark/datasets/arxiv-titles-instructorxl-768-angular.tgz",
    "path": "arxiv-titles-instructorxl-768-angular/arxiv-titles-instructorxl-768-angular"
  },
  {
    "name": "dbpedia-gemini-100K-768-angular",
    "vector_size": 768,
    "distance": "cosine",
    "type": "tar",
    "link": "https://storage.googleapis.com/ann-filtered-benchmark/datasets/dbpedia-gemini-100K-768-angular.tgz",
    "path": "dbpedia-gemini-100K-768-angular/dbpedia-gemini-100K-768-angular"
  },
  {
    "name": "dbpedia-openai3-large-100K-1536-angular",
    "vector_size": 1536,
    "distance": "cosine",
    "type": "tar",
    "link": "https://storage.googleapis.com/ann-filtered-benchmark/datasets/dbpedia-openai3-large-100K-1536-angular.tgz",
    "path": "dbpedia-openai3-large-100K-1536-angular/dbpedia-openai3-large-100K-1536-angular"
  },
  {
    "name": "dbpedia-openai3-small-100K-1536-angular",
    "vector_size": 1536,
    "distance": "cosine",
    "type": "tar",
    "link": "https://storage.googleapis.com/ann-filtered-benchmark/datasets/dbpedia-openai3-small-100K-1536-angular.tgz",
    "path": "dbpedia-openai3-small-100K-1536-angular/dbpedia-openai3-small-100K-1536-angular"
  },
  {
    "name": "sift-128-euclidean",
    "vector_size": 128,
    "distance": "l2",
    "type": "h5",
    "path": "sift-128-euclidean/sift-128-euclidean.hdf5",
    "link": "http://ann-benchmarks.com/sift-128-euclidean.hdf5"
  },
  {
    "name": "mnist-784-euclidean",
    "vector_size": 784,
    "distance": "l2",
    "type": "h5",
    "path": "mnist-784-euclidean/mnist-784-euclidean.hdf5",
    "link": "http://ann-benchmarks.com/mnist-784-euclidean.hdf5"
  },
  {
    "name": "fashion-mnist-784-euclidean",
    "vector_size": 784,
    "distance": "l2",
    "type": "h5",
    "path": "fashion-mnist-784-euclidean/fashion-mnist-784-euclidean.hdf5",
    "link": "http://ann-benchmarks.com/fashion-mnist-784-euclidean.hdf5"
  }
]
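Each entry carries what a runner needs to fetch and open a dataset: "tar" entries unpack to the vectors.npy + tests.jsonl pair read by AnnCompoundReader, while "h5" entries point at a single ann-benchmarks HDF5 file. A hedged sketch of resolving an entry by name (the helper function and config path here are illustrative, not the benchmark's actual loader):

import json
from pathlib import Path

def find_dataset(name: str, config: Path = Path("datasets/datasets.json")) -> dict:
    """Return the config entry for a dataset: its link, type, distance, etc."""
    for entry in json.loads(config.read_text()):
        if entry["name"] == name:
            return entry
    raise KeyError(f"unknown dataset: {name}")

meta = find_dataset("ads-gte-multilingual-1M-768-angular")
print(meta["distance"], meta["vector_size"], meta["link"])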
100 changes: 100 additions & 0 deletions scripts/build_hf_dataset.py
@@ -0,0 +1,100 @@
#!/usr/bin/env python3
"""Build an AnnCompoundReader-format dataset from a HuggingFace dataset and upload to GCS."""
import argparse
import json
import os
import tarfile
from pathlib import Path

import boto3
import faiss
import numpy as np
from botocore.config import Config

from datasets import load_dataset


def main():
    p = argparse.ArgumentParser()
    p.add_argument("--hf-dataset", required=True)
    p.add_argument("--output-name", required=True)
    p.add_argument("--vector-column", default="embedding")
    p.add_argument("--gcs-uri", required=True, help="e.g. gs://bucket/prefix/")
    p.add_argument("--split", default="train")
    p.add_argument("--num-queries", type=int, default=1000)
    p.add_argument("--top-k", type=int, default=100)
    p.add_argument("--seed", type=int, default=42)
    args = p.parse_args()

    print(f"Loading {args.hf_dataset} (split={args.split})...")
    ds = (
        load_dataset(args.hf_dataset, split=args.split)
        .select_columns([args.vector_column])
        .with_format("numpy")
    )

    print(f"Extracting column '{args.vector_column}' from {len(ds)} rows...")
    vectors = np.asarray(ds[:][args.vector_column], dtype=np.float32)
    n, d = vectors.shape
    print(f"Loaded {n} vectors of dimension {d}")

    # Normalize rows to unit length so inner-product search below is equivalent
    # to cosine similarity; guard zero-norm rows to avoid division by zero.
    norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    norms[norms == 0] = 1.0
    vectors /= norms

    # Deterministic shuffle, then hold out the first num_queries rows as queries.
    rng = np.random.default_rng(args.seed)
    perm = rng.permutation(n)
    test = vectors[perm[: args.num_queries]]
    train = np.ascontiguousarray(vectors[perm[args.num_queries :]])
    del vectors
    print(f"Train: {train.shape}, Test: {test.shape}")

    # IndexFlatIP is exact brute-force search, so these neighbors are true
    # ground truth, not an approximation.
    print(f"Computing top-{args.top_k} neighbors with FAISS IndexFlatIP...")
    index = faiss.IndexFlatIP(d)
    index.add(train)
    scores, ids = index.search(test, args.top_k)

    out_dir = Path("/tmp") / args.output_name
    out_dir.mkdir(parents=True, exist_ok=True)
    np.save(out_dir / "vectors.npy", train)
    with open(out_dir / "tests.jsonl", "w") as f:
        for q, qids, qscores in zip(test, ids, scores):
            f.write(
                json.dumps(
                    {
                        "query": q.tolist(),
                        "conditions": {},
                        "closest_ids": qids.tolist(),
                        "closest_scores": qscores.tolist(),
                    }
                )
                + "\n"
            )
    print(f"Wrote {out_dir}/vectors.npy ({train.nbytes / 1e9:.2f} GB) and tests.jsonl")

    tar_path = Path("/tmp") / f"{args.output_name}.tgz"
    print(f"Creating {tar_path}...")
    with tarfile.open(tar_path, "w:gz") as tar:
        tar.add(out_dir / "vectors.npy", arcname="vectors.npy")
        tar.add(out_dir / "tests.jsonl", arcname="tests.jsonl")

    if not args.gcs_uri.startswith("gs://"):
        raise ValueError("--gcs-uri must start with gs://")
    bucket_name, _, prefix = args.gcs_uri[len("gs://") :].partition("/")
    blob_name = f"{prefix.rstrip('/')}/{args.output_name}.tgz".lstrip("/")
    print(f"Uploading to gs://{bucket_name}/{blob_name}...")
    # GCS exposes an S3-compatible XML API, so boto3 with HMAC credentials can
    # upload directly; recent boto3 sends CRC checksums by default, which this
    # endpoint rejects, hence checksums only "when_required".
    boto3.client(
        "s3",
        endpoint_url="https://storage.googleapis.com",
        aws_access_key_id=os.environ["GCS_KEY"],
        aws_secret_access_key=os.environ["GCS_SECRET"],
        config=Config(
            request_checksum_calculation="when_required",
            response_checksum_validation="when_required",
        ),
    ).upload_file(str(tar_path), bucket_name, blob_name)
    print("Done.")


if __name__ == "__main__":
    main()
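A quick way to sanity-check a produced tarball is to unpack it and compare one stored ground-truth entry against a brute-force NumPy scan; a sketch assuming the unpacked files sit in the working directory (score ties aside):

import json

import numpy as np

vectors = np.load("vectors.npy", mmap_mode="r")
with open("tests.jsonl") as f:
    test = json.loads(f.readline())

# Vectors were L2-normalized at build time, so a plain dot product is the
# cosine score; the best match should be the first stored neighbor.
q = np.asarray(test["query"], dtype=np.float32)
scores = vectors @ q
assert int(scores.argmax()) == test["closest_ids"][0]
print("top-1 ground truth matches brute-force scan")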