Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions discojs/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ export {
EpochLogs,
Tokenizer,
ValidationMetrics,
ModelMetadata,
} from "./models/index.js";
export * as models from './models/index.js'

Expand Down
2 changes: 1 addition & 1 deletion discojs/src/models/index.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export { Model } from './model.js'
export { Model, ModelMetadata } from './model.js'
export { BatchLogs, EpochLogs, ValidationMetrics } from "./logs.js";
export { Tokenizer } from "./tokenizer.js";

Expand Down
8 changes: 8 additions & 0 deletions discojs/src/models/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ import type {
} from "../index.js";

import type { BatchLogs, EpochLogs } from "./logs.js";
import type { StandardizationStats } from "../processing/tabular.js";

/**
 * Optional, serializable metadata attached to a model.
 *
 * Currently carries the per-column standardization statistics computed on
 * tabular training data, so the same scaling can be re-applied when
 * preprocessing validation/inference data (see `preprocess` usage of
 * `metadata?.tabularStandardization`).
 */
export type ModelMetadata = {
tabularStandardization?: StandardizationStats;
};

/**
* Trainable predictor
Expand All @@ -21,6 +26,9 @@ export abstract class Model<D extends DataType> implements Disposable {
/** Set training state */
abstract set weights(ws: WeightsContainer);

/** Optional metadata for tabular task data standardization */
metadata?: ModelMetadata;

/**
* Improve predictor
*
Expand Down
11 changes: 8 additions & 3 deletions discojs/src/models/tfjs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,20 @@ import {
import { BatchLogs } from './index.js'
import { Model } from './index.js'
import { EpochLogs } from './logs.js'
import { ModelMetadata } from "./model.js";

type Serialized<D extends DataType> = [D, tf.io.ModelArtifacts];
type Serialized<D extends DataType> = [D, tf.io.ModelArtifacts, ModelMetadata?];

/** TensorFlow JavaScript model with standard training */
export class TFJS<D extends "image" | "tabular"> extends Model<D> {
/** Wrap the given trainable model */
constructor (
public readonly datatype: D,
private readonly model: tf.LayersModel
private readonly model: tf.LayersModel,
metadata?: ModelMetadata,
) {
super()
this.metadata = metadata;

if (model.loss === undefined) {
throw new Error('TFJS models need to be compiled to be used')
Expand Down Expand Up @@ -176,12 +179,14 @@ export class TFJS<D extends "image" | "tabular"> extends Model<D> {
static async deserialize<D extends "image" | "tabular">([
datatype,
artifacts,
metadata
]: Serialized<D>): Promise<TFJS<D>> {
return new this(
datatype,
await tf.loadLayersModel({
load: () => Promise.resolve(artifacts),
}),
metadata
);
}

Expand All @@ -204,7 +209,7 @@ export class TFJS<D extends "image" | "tabular"> extends Model<D> {
includeOptimizer: true // keep model compiled
})

return [this.datatype, await ret]
return [this.datatype, await ret, this.metadata]
}

[Symbol.dispose](): void{
Expand Down
17 changes: 15 additions & 2 deletions discojs/src/processing/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import type {
Tabular,
Task,
Network,
ModelMetadata,
} from "../index.js";

import * as processing from "./index.js";
Expand All @@ -19,6 +20,7 @@ export * from "./tabular.js";
export function preprocess<D extends DataType, N extends Network>(
task: Task<D, N>,
dataset: Dataset<DataFormat.Raw[D]>,
metadata?: ModelMetadata,
): Dataset<DataFormat.ModelEncoded[D]> {
switch (task.dataType) {
case "image": {
Expand All @@ -37,12 +39,17 @@ export function preprocess<D extends DataType, N extends Network>(
// cast as typescript doesn't reduce generic type
const d = dataset as Dataset<DataFormat.Raw["tabular"]>;
const { inputColumns, outputColumn } = task.trainingInformation;
const stats = metadata?.tabularStandardization;

return d.map((row) => {
const output = processing.extractColumn(row, outputColumn);

const inputs = stats
? List(processing.standardizeRow(row, inputColumns, stats))
: extractToNumbers(inputColumns, row);

return [
extractToNumbers(inputColumns, row),
inputs,
// TODO sanitization doesn't care about column distribution
output !== "" ? processing.convertToNumber(output) : 0,
];
Expand All @@ -68,6 +75,7 @@ export function preprocess<D extends DataType, N extends Network>(
export function preprocessWithoutLabel<D extends DataType>(
task: Task<D, Network>,
dataset: Dataset<DataFormat.RawWithoutLabel[D]>,
metadata?: ModelMetadata,
): Dataset<DataFormat.ModelEncoded[D][0]> {
switch (task.dataType) {
case "image": {
Expand All @@ -85,8 +93,13 @@ export function preprocessWithoutLabel<D extends DataType>(
// cast as typescript doesn't reduce generic type
const d = dataset as Dataset<DataFormat.Raw["tabular"]>;
const { inputColumns } = task.trainingInformation;
const stats = metadata?.tabularStandardization;

return d.map((row) => extractToNumbers(inputColumns, row));
return d.map((row) =>
stats
? List(processing.standardizeRow(row, inputColumns, stats))
: extractToNumbers(inputColumns, row)
);
}
case "text": {
// cast as typescript doesn't reduce generic type
Expand Down
65 changes: 65 additions & 0 deletions discojs/src/processing/tabular.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import { List } from "immutable";

/**
 * Per-column statistics used for z-score standardization.
 *
 * `means` and `stds` are keyed by column name; each value is the column's
 * mean / standard deviation computed over the training rows
 * (see `computeStandardizationStats`).
 */
export type StandardizationStats = {
means: Record<string, number>;
stds: Record<string, number>;
};

/**
* Convert a string to a number
*
Expand Down Expand Up @@ -38,3 +43,63 @@ export function indexInList(
if (ret === -1) throw new Error(`${element} not found in list`);
return ret;
}

/**
 * Compute the mean and standard deviation of each requested column.
 *
 * Empty cells are substituted with "0" before conversion, matching the
 * substitution done in `standardizeRow` so that training-time statistics
 * and inference-time scaling agree.
 *
 * Note: this is the population standard deviation (divides by N).
 *
 * @param rows dataset rows, each mapping column name to raw string value
 * @param columns columns to compute statistics for
 * @returns means and stds keyed by column name
 * @throws if `rows` is empty — dividing by zero would otherwise silently
 *   produce NaN statistics and poison every standardized value downstream
 */
export function computeStandardizationStats(
  rows: Array<Partial<Record<string, string>>>,
  columns: Array<string>,
): StandardizationStats {
  if (rows.length === 0)
    throw new Error("cannot compute standardization stats on an empty dataset");

  const means: Record<string, number> = {};
  const stds: Record<string, number> = {};

  for (const col of columns) {
    const values = rows.map((row) => {
      const rawValue = extractColumn(row, col);
      return convertToNumber(rawValue !== "" ? rawValue : "0");
    });

    const mean = values.reduce((acc, v) => acc + v, 0) / values.length;
    const variance =
      values.reduce((acc, v) => acc + (v - mean) ** 2, 0) / values.length;

    means[col] = mean;
    stds[col] = Math.sqrt(variance);
  }

  return { means, stds };
}

/**
 * Z-score a single value: `(value - mean) / std`.
 *
 * @param value raw numeric value
 * @param mean column mean
 * @param std column standard deviation
 * @returns the standardized value, or 0 when `std` is 0 (constant column)
 */
export function standardizeValue(
  value: number,
  mean: number,
  std: number,
): number {
  // A zero std means the column is constant; map everything to 0 rather
  // than dividing by zero (which would yield NaN or Infinity).
  if (std === 0) return 0;
  return (value - mean) / std;
}

/**
 * Standardize a row's input columns using precomputed statistics.
 *
 * Called once per dataset row. Empty cells are substituted with "0"
 * before conversion; per the original note, empty strings only occur in
 * test datasets, as the web app rejects them at validation time.
 *
 * @param row dataset row mapping column name to raw string value
 * @param columns input columns to extract, in output order
 * @param stats per-column means/stds (from `computeStandardizationStats`)
 * @returns standardized numeric values, one per entry of `columns`
 * @throws if `stats` lacks an entry for one of `columns`
 */
export function standardizeRow(
  row: Partial<Record<string, string>>,
  columns: Array<string>,
  stats: StandardizationStats,
): Array<number> {
  return columns.map((col) => {
    const rawValue = extractColumn(row, col);
    const value = convertToNumber(rawValue !== "" ? rawValue : "0");

    const mean = stats.means[col];
    const std = stats.stds[col];
    // Fail loudly on a columns/stats mismatch: arithmetic with undefined
    // would silently turn the whole output into NaN.
    if (mean === undefined || std === undefined)
      throw new Error(`no standardization stats for column ${col}`);

    return standardizeValue(value, mean, std);
  });
}
10 changes: 6 additions & 4 deletions discojs/src/serialization/model.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import type tf from '@tensorflow/tfjs'

import type { DataType, Model } from '../index.js'
import type { DataType, Model, ModelMetadata } from '../index.js'
import { models, serialization } from '../index.js'
import { GPTConfig } from '../models/index.js'

Expand Down Expand Up @@ -41,11 +41,11 @@ export async function decode(encoded: Encoded): Promise<Model<DataType>> {
const rawModel = raw[1] as unknown
switch (type) {
case Type.TFJS: {
if (raw.length !== 3)
if (raw.length !== 3 && raw.length !== 4)
throw new Error(
"invalid TFJS model encoding: should be an array of length 3",
"invalid TFJS model encoding: should be an array of length 3 or 4",
);
const [rawDatatype, rawModel] = raw.slice(1) as unknown[];
const [rawDatatype, rawModel, rawMetadata] = raw.slice(1) as unknown[];

let datatype;
switch (rawDatatype) {
Expand All @@ -63,6 +63,8 @@ export async function decode(encoded: Encoded): Promise<Model<DataType>> {
datatype,
// TODO totally unsafe casting
rawModel as tf.io.ModelArtifacts,
// metadata for tabular task standardization
rawMetadata as ModelMetadata,
]);
}
case Type.GPT: {
Expand Down
85 changes: 71 additions & 14 deletions discojs/src/training/disco.ts
Original file line number Diff line number Diff line change
Expand Up @@ -155,13 +155,13 @@ export class Disco<D extends DataType, N extends Network> extends EventEmitter<{
> {
this.#logger.success("Training started");

const [trainingDataset, validationDataset] =
await this.#preprocessSplitAndBatch(dataset);

// the client fetches the latest weights upon connection
// TODO unsafe cast
this.trainer.model = (await this.#client.connect()) as Model<D>;

const [trainingDataset, validationDataset] =
await this.#preprocessSplitAndBatch(dataset);

for await (const [round, epochs] of enumerate(
this.trainer.train(trainingDataset, validationDataset),
)) {
Expand Down Expand Up @@ -213,21 +213,78 @@ export class Disco<D extends DataType, N extends Network> extends EventEmitter<{
> {
const { batchSize, validationSplit } = this.#task.trainingInformation;

let preprocessed = processing.preprocess(this.#task, dataset);
if (validationSplit === 0){
if (this.#task.dataType === "tabular"){
const rows = await arrayFromAsync(dataset as Dataset<DataFormat.Raw["tabular"]>);
const inputColumns = this.#task.trainingInformation.inputColumns;

const stats = processing.computeStandardizationStats(rows, inputColumns);
this.trainer.model.metadata = {
tabularStandardization: stats,
};

const preprocessed = processing.preprocess(
this.#task,
dataset,
this.trainer.model.metadata,
);
return [preprocessed.batch(batchSize).cached(), undefined];
}
// If task datatype is not tabular
let preprocessed = processing.preprocess(this.#task, dataset);

preprocessed = (
this.#preprocessOnce
? new Dataset(await arrayFromAsync(preprocessed))
: preprocessed
)
return [preprocessed.batch(batchSize).cached(), undefined];
}

// If training/validation splitting ratio is defined
const [training, validation] = dataset.split(validationSplit);

if (this.#task.dataType == "tabular"){
const trainingRows = await arrayFromAsync(training as Dataset<DataFormat.Raw["tabular"]>);
const inputColumns = this.#task.trainingInformation.inputColumns;
const stats = processing.computeStandardizationStats(trainingRows, inputColumns);

this.trainer.model.metadata = {
tabularStandardization: stats,
};

let preprocessedTraining = processing.preprocess(this.#task, training, this.trainer.model.metadata);
let preprocessedValidation = processing.preprocess(this.#task, validation, this.trainer.model.metadata);
preprocessedTraining = this.#preprocessOnce
? new Dataset(await arrayFromAsync(preprocessedTraining))
: preprocessedTraining;

preprocessedValidation = this.#preprocessOnce
? new Dataset(await arrayFromAsync(preprocessedValidation))
: preprocessedValidation;

return [
preprocessedTraining.batch(batchSize).cached(),
preprocessedValidation.batch(batchSize).cached(),
];
}

// if task datatype is not tabular
let preprocessedTraining = processing.preprocess(this.#task, training);
let preprocessedValidation = processing.preprocess(this.#task, validation);

preprocessed = (
this.#preprocessOnce
? new Dataset(await arrayFromAsync(preprocessed))
: preprocessed
)
if (validationSplit === 0) return [preprocessed.batch(batchSize).cached(), undefined];
preprocessedTraining = this.#preprocessOnce
? new Dataset(await arrayFromAsync(preprocessedTraining))
: preprocessedTraining;

const [training, validation] = preprocessed.split(validationSplit);
preprocessedValidation = this.#preprocessOnce
? new Dataset(await arrayFromAsync(preprocessedValidation))
: preprocessedValidation;

return [
training.batch(batchSize).cached(),
validation.batch(batchSize).cached(),
];
preprocessedTraining.batch(batchSize).cached(),
preprocessedValidation.batch(batchSize).cached(),
];
}
}

Expand Down
29 changes: 22 additions & 7 deletions webapp/src/components/dataset_input/validate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,30 @@ import { Range, Set } from "immutable";

import type { LabeledDataset } from "./types";

/**
 * Decide whether a raw cell value should be treated as NaN.
 *
 * A value counts as NaN when it is absent, blank after trimming, or the
 * literal text "nan" in any letter case.
 */
function isNaNValue(value: string | undefined): boolean{
  // Missing cells count as NaN.
  if (value === undefined) return true;

  const normalized = value.trim().toLowerCase();
  return normalized.length === 0 || normalized === "nan";
}

export async function tabular(
wantedColumns: Set<string>,
dataset: LabeledDataset["tabular"],
): Promise<void> {
for await (const [columns, i] of dataset
.map((row) => Set(Object.keys(row)))
.zip(Range(1, Number.POSITIVE_INFINITY)))
if (!columns.isSuperset(wantedColumns))
throw new Error(
`row ${i} is missing columns ${wantedColumns.subtract(columns).join(", ")}`,
);
for await (const [row, i] of dataset
.zip(Range(1, Number.POSITIVE_INFINITY))){
const columns = Set(Object.keys(row));

if (!columns.isSuperset(wantedColumns))
throw new Error(
`row ${i} is missing columns ${wantedColumns.subtract(columns).join(", ")}`,
);

for (const col of wantedColumns){
if (isNaNValue(row[col]))
throw new Error(`row ${i} column "${col}" contains NaN`);
}
}
}