Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 78 additions & 5 deletions cli/src/args.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,37 @@ import { Map, Set } from 'immutable'
import type { DataType, Network, TaskProvider } from "@epfml/discojs";
import { defaultTasks } from '@epfml/discojs'

interface BenchmarkArguments {
/** Weight-aggregation strategies that can be selected from the command line. */
type AggregationStrategy = "mean" | "byzantine" | "secure";

/**
 * Validates a raw command-line value and narrows it to an {@link AggregationStrategy}.
 *
 * @param raw - value given for the aggregator option
 * @returns `raw` unchanged when it names a supported strategy
 * @throws Error when `raw` is not one of the supported strategies
 */
function parseAggregator(raw: string): AggregationStrategy {
  // strict equality for every comparison (previously a mix of `===` and `==`)
  if (raw === "mean" || raw === "byzantine" || raw === "secure") return raw;

  throw new Error(`Aggregator ${raw} is not supported.`);
}

/**
 * Validated command-line arguments of the benchmark CLI.
 *
 * Optional members are `undefined` when the matching option was not given.
 */
export interface BenchmarkArguments {
  /** Provider of the task to benchmark. */
  provider: TaskProvider<DataType, Network>;
  /** ID of the testcase. */
  testID: string;
  /** Number of users. */
  numberOfUsers: number;
  /** Number of epochs. */
  epochs: number;
  /** Round duration (in epochs). */
  roundDuration: number;
  /** Training batch size. */
  batchSize: number;
  /** Validation dataset ratio. */
  validationSplit: number;

  // --- Differential privacy ---
  /** Privacy budget. */
  epsilon?: number;
  /** Probability of failure, slack parameter. */
  delta?: number;
  /** Default clipping radius for DP. */
  dpDefaultClippingRadius?: number;

  // --- Aggregator selection ---
  /** Type of weight aggregator. */
  aggregator: AggregationStrategy;

  // --- Byzantine aggregator ---
  /** Clipping radius for centered clipping. */
  clippingRadius?: number;
  /** Maximum centered clipping iterations. */
  maxIterations?: number;
  /** Momentum coefficient to smooth the aggregation over multiple rounds. */
  beta?: number;

  // --- Secure aggregator ---
  /** Maximum absolute value over all the weights. */
  maxShareValue?: number;

  /** Save logs of benchmark. */
  save: boolean;
  /** Server URL (presumably the host the clients connect to; confirm against the option definition). */
  host: URL;
}
Expand All @@ -27,15 +48,13 @@ const argExample = 'e.g. npm start -- -u 2 -e 3 # runs 2 users for 3 epochs'

const unsafeArgs = parse<BenchmarkUnsafeArguments>(
{
testID: { type: String, alias: 'i', description: 'ID of the testcase' },
task: { type: String, alias: 't', description: 'Task: tinder_dog, titanic, simple_face, cifar10 or lus_covid', defaultValue: 'tinder_dog' },
numberOfUsers: { type: Number, alias: 'u', description: 'Number of users', defaultValue: 2 },
epochs: { type: Number, alias: 'e', description: 'Number of epochs', defaultValue: 10 },
roundDuration: { type: Number, alias: 'r', description: 'Round duration (in epochs)', defaultValue: 2 },
batchSize: { type: Number, alias: 'b', description: 'Training batch size', defaultValue: 10 },
validationSplit : { type: Number, alias: 'v', description: 'Validation dataset ratio', defaultValue: 0.2 },
epsilon: { type: Number, alias: 'n', description: 'Privacy budget', optional: true, defaultValue: undefined},
delta: { type: Number, alias: 'd', description: 'Probability of failure, slack parameter', optional: true, defaultValue: undefined},
dpDefaultClippingRadius: {type: Number, alias: 'f', description: 'Default clipping radius for DP', optional: true, defaultValue: undefined},
save: { type: Boolean, alias: 's', description: 'Save logs of benchmark', defaultValue: false },
host: {
type: (raw: string) => new URL(raw),
Expand All @@ -44,6 +63,22 @@ const unsafeArgs = parse<BenchmarkUnsafeArguments>(
defaultValue: new URL("http://localhost:8080"),
},

// Aggregator
aggregator: { type: parseAggregator, description: 'Type of weight aggregator', defaultValue: 'mean' },

// Byzantine aggregator
clippingRadius: { type: Number, description: "Clipping radius for centered clipping", optional: true },
maxIterations: { type: Number, description: "Maximum centered clipping iterations", optional: true },
beta: { type: Number, description: "Momentum coefficient to smooth the aggregation over multiple rounds", optional: true },

// Secure aggregator
maxShareValue: { type: Number, description: "Maximum absolute value over all the weights", optional: true },

// Differential Privacy
epsilon: { type: Number, description: 'Privacy budget', optional: true, defaultValue: undefined},
delta: { type: Number, description: 'Probability of failure, slack parameter', optional: true, defaultValue: undefined},
dpDefaultClippingRadius: {type: Number, description: 'Default clipping radius for DP', optional: true, defaultValue: undefined},

help: { type: Boolean, optional: true, alias: 'h', description: 'Prints this usage guide' }
},
{
Expand Down Expand Up @@ -88,6 +123,44 @@ export const args: BenchmarkArguments = {
task.trainingInformation.epochs = unsafeArgs.epochs;
task.trainingInformation.validationSplit = unsafeArgs.validationSplit;

// Aggregator-related options as parsed from the command line.
const {aggregator, clippingRadius, maxIterations, beta, maxShareValue} = unsafeArgs;

// Aggregation strategy: a value given on the command line replaces the task's own strategy.
// NOTE(review): if the option declares a default value, this branch presumably always runs
// and always overrides the task's strategy; confirm that is intended.
if (aggregator !== undefined)
task.trainingInformation.aggregationStrategy = aggregator;

// Byzantine aggregator: applied only when all three parameters are provided.
// A partial set (one or two of them) does not enter this branch and is silently ignored.
if (
clippingRadius !== undefined &&
maxIterations !== undefined &&
beta !== undefined
){
// Checked in this order: the training scheme is validated before the strategy.
if (task.trainingInformation.scheme === "local")
throw new Error("Byzantine aggregator is not supported for local training");
if (task.trainingInformation.aggregationStrategy !== "byzantine")
throw new Error("Byzantine parameters can be set only when aggregationStrategy is byzantine");

// Keep any privacy settings already present on the task; only byzantineFaultTolerance is replaced.
task.trainingInformation.privacy = {
...task.trainingInformation.privacy,
byzantineFaultTolerance: {
clippingRadius,
maxIterations,
beta,
},
};
}

// Secure aggregator: maxShareValue is accepted only for a decentralized task whose
// aggregation strategy is "secure"; any other combination is rejected with an error.
if (maxShareValue !== undefined) {
  if (task.trainingInformation.scheme !== "decentralized")
    throw new Error("Secure aggregation is only supported for decentralized learning");
  if (task.trainingInformation.aggregationStrategy !== "secure")
    throw new Error("maxShareValue can be set only when aggregationStrategy is secure");

  task.trainingInformation.maxShareValue = maxShareValue;
}

// For DP
const {dpDefaultClippingRadius, epsilon, delta} = unsafeArgs;

Expand All @@ -102,7 +175,7 @@ export const args: BenchmarkArguments = {
const defaultRadius = dpDefaultClippingRadius ? dpDefaultClippingRadius : 1;

// for the case where privacy parameters are not defined in the default tasks
task.trainingInformation.privacy ??= {}
task.trainingInformation.privacy ??= {};
task.trainingInformation.privacy.differentialPrivacy = {
clippingRadius: defaultRadius,
epsilon: epsilon,
Expand Down
35 changes: 29 additions & 6 deletions cli/src/cli.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@ import "@tensorflow/tfjs-node"

import { List, Range } from 'immutable'
import fs from 'node:fs/promises'
import path from "node:path";

import type {
Dataset,
DataFormat,
DataType,
RoundLogs,
SummaryLogs,
Task,
TaskProvider,
Network,
Expand All @@ -17,6 +18,8 @@ import { Disco, aggregator as aggregators, client as clients } from '@epfml/disc

import { getTaskData } from './data.js'
import { args } from './args.js'
import { makeUserLogFile } from "./user_log.js";
import type { UserLogFile } from "./user_log.js";

// Array.fromAsync not yet widely used (2024)
async function arrayFromAsync<T>(iter: AsyncIterable<T>): Promise<T[]> {
Expand All @@ -29,16 +32,32 @@ async function runUser<D extends DataType, N extends Network>(
task: Task<D, N>,
url: URL,
data: Dataset<DataFormat.Raw[D]>,
): Promise<List<RoundLogs>> {
userIndex: number,
numberOfUsers: number,
): Promise<List<SummaryLogs>> {
// cast as typescript isn't good with generics
const trainingScheme = task.trainingInformation.scheme as N
const aggregator = aggregators.getAggregator(task)
const client = clients.getClient(trainingScheme, url, task, aggregator)
const disco = new Disco(task, client, { scheme: trainingScheme });

const logs = List(await arrayFromAsync(disco.trainByRound(data)));
const logs = List(await arrayFromAsync(disco.trainSummary(data)));
await new Promise((res, _) => setTimeout(() => res('timeout'), 1000)) // Wait for other peers to finish

// saving per-user logs
if (args.save) {
const dir = path.join(".", `${args.testID}`, `${task.id}`);
await fs.mkdir(dir, { recursive: true });

const filePath = path.join(dir, `client${userIndex}_local_log.json`);

const userLog: UserLogFile = makeUserLogFile(task, numberOfUsers, userIndex, client.ownId, logs.toArray());

await fs.writeFile(filePath, JSON.stringify(userLog, null, 2));
}

await disco.close();

return logs;
}

Expand All @@ -47,19 +66,23 @@ async function main<D extends DataType, N extends Network>(
numberOfUsers: number,
): Promise<void> {
const task = await provider.getTask();
console.log(`Test ID: ${args.testID}`)
console.log(`Started ${task.trainingInformation.scheme} training of ${task.id}`)
console.log({ args })

const dataSplits = await Promise.all(
Range(0, numberOfUsers).map(async i => getTaskData(task.id, i, numberOfUsers))
)
const logs = await Promise.all(
dataSplits.map(async data => await runUser(task, args.host, data as Dataset<DataFormat.Raw[D]>))
dataSplits.map((data, i) => runUser(task, args.host, data as Dataset<DataFormat.Raw[D]>, i, numberOfUsers))
)

if (args.save) {
const fileName = `${task.id}_${numberOfUsers}users.csv`;
await fs.writeFile(fileName, JSON.stringify(logs, null, 2));
const dir = path.join(".", `${args.testID}`, `${task.id}`);
await fs.mkdir(dir, { recursive: true });

const filePath = path.join(dir, `${task.id}_${numberOfUsers}users.json`);
await fs.writeFile(filePath, JSON.stringify(logs, null, 2));
}
}

Expand Down
95 changes: 32 additions & 63 deletions cli/src/data.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import path from "node:path";
import { promises as fs } from "fs";
import { Dataset, processing, defaultTasks } from "@epfml/discojs";
import { Dataset, processing } from "@epfml/discojs";
import {
DataFormat,
DataType,
Expand All @@ -20,7 +19,9 @@ async function loadSimpleFaceData(userIdx: number, totalClient: number): Promise

const combinded = adults.chain(childs);

return combinded.filter((_, i) => i % totalClient === userIdx);
const sharded = combinded.filter((_, i) => i % totalClient === userIdx);

return sharded;
}

async function loadLusCovidData(userIdx: number, totalClient: number): Promise<Dataset<DataFormat.Raw["image"]>> {
Expand Down Expand Up @@ -66,62 +67,28 @@ function loadTinderDogData(split: number): Dataset<DataFormat.Raw["image"]> {
});
}

/**
 * Loads one client's shard of the extended CIFAR-10 dataset.
 *
 * Images are read from `../datasets/extended_cifar10/client_<userIdx>`. Only
 * entries named `image_<n>_label_<k>.png` (case-insensitive) are used, where
 * `<k>` indexes into the CIFAR-10 task's `LABEL_LIST`; other entries are ignored.
 *
 * @param userIdx - index of the client whose folder is read
 * @returns a dataset yielding `[image, label]` pairs
 * @throws Error (when the dataset is iterated) if a file's label index is
 *   outside `LABEL_LIST`
 */
async function loadExtCifar10(userIdx: number): Promise<Dataset<[Image, string]>> {
  const cifar10Task = await defaultTasks.cifar10.getTask();
  const labels = Array.from(cifar10Task.trainingInformation.LABEL_LIST);
  const clientFolder = path.join("..", "datasets", "extended_cifar10", `client_${userIdx}`);
  // built once instead of once per directory entry; no `g` flag, so it is stateless
  const fileNamePattern = /^image_(\d+)_label_(\d+)\.png$/i;

  return new Dataset(async function* () {
    const entries = await fs.readdir(clientFolder, { withFileTypes: true });

    // Every label is resolved before the first image is loaded, so a bad
    // file name fails before anything is yielded.
    const items = entries.flatMap((entry) => {
      const match = fileNamePattern.exec(entry.name);
      if (match === null) return []; // not a dataset image: skip

      const labelIdx = Number.parseInt(match[2], 10);
      if (labelIdx >= labels.length)
        throw new Error(`${entry.name}: too big label index`);

      return [{ name: entry.name, label: labels[labelIdx] }];
    });

    for (const { name, label } of items) {
      const image = await loadImage(path.join(clientFolder, name));
      yield [image, label] as const;
    }
  });
}

function loadMnistData(split: number): Dataset<DataFormat.Raw["image"]>{
const folder = path.join("..", "datasets", "mnist", `${split + 1}`);
function loadData(dataName: string, split: number): Dataset<DataFormat.Raw["image"]>{
const folder = path.join("..", "datasets", `${dataName}`, `client_${split}`);
return loadCSV(path.join(folder, "labels.csv"))
.map(
(row) =>
[
processing.extractColumn(row, "filename"),
processing.extractColumn(row, "label"),
] as const,
(row) => [
processing.extractColumn(row, "filename"),
processing.extractColumn(row, "label"),
] as const,
)
.map(async ([filename, label]) => {
try {
const image = await Promise.any(
["png", "jpg", "jpeg"].map((ext) =>
loadImage(path.join(folder, `${filename}.${ext}`)),
),
);
return [image, label];
} catch {
throw Error(`${filename} not found in ${folder}`);
.map(
async ([filename, label]) => {
try {
const img = await Promise.any(
["png", "jpg", "jpeg"].map((ext) =>
loadImage(path.join(folder, `${filename}.${ext}`)))
);
return [img, label]
} catch {
throw Error(`${filename} not found in ${folder}`);
}
}
});
);
}

export async function getTaskData<D extends DataType>(
Expand All @@ -130,25 +97,27 @@ export async function getTaskData<D extends DataType>(
totalClient: number
): Promise<Dataset<DataFormat.Raw[D]>> {
switch (taskID) {
case "simple_face":
case "simple_face": // remove
return (await loadSimpleFaceData(userIdx, totalClient)) as Dataset<DataFormat.Raw[D]>;
case "titanic":
case "titanic_decentralized":
const titanicData = loadCSV(
path.join("..", "datasets", "titanic_train.csv"),
) as Dataset<DataFormat.Raw[D]>;
return titanicData.filter((_, i) => i % totalClient === userIdx);
case "cifar10":
return (
await loadImagesInDir(path.join("..", "datasets", "CIFAR10"))
).zip(Repeat("cat")) as Dataset<DataFormat.Raw[D]>;
return loadData("cifar10-agent", userIdx) as Dataset<DataFormat.Raw[D]>;
case "cifar10_federated_simple_model":
case "cifar10_simple_model":
return loadData("cifar10_ext", userIdx) as Dataset<DataFormat.Raw[D]>;
case "lus_covid":
case "lus_covid_decentralized":
return (await loadLusCovidData(userIdx, totalClient)) as Dataset<DataFormat.Raw[D]>;
case "tinder_dog":
case "tinder_dog": // remove
return loadTinderDogData(userIdx) as Dataset<DataFormat.Raw[D]>;
case "extended_cifar10":
return (await loadExtCifar10(userIdx)) as Dataset<DataFormat.Raw[D]>;
case "mnist_federated":
case "mnist":
return loadMnistData(userIdx) as Dataset<DataFormat.Raw[D]>;
return loadData("mnist", userIdx) as Dataset<DataFormat.Raw[D]>;
default:
throw new Error(`Data loader for ${taskID} not implemented.`);
}
Expand Down
Loading