Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
name: Run Unit Tests

on:
push:
branches:
- '**'
pull_request:
branches:
- '**'

concurrency:
group: ${{ github.ref }}
cancel-in-progress: true

jobs:
publish-npm:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2

- uses: actions/setup-node@v2
with:
node-version: '18.x'
registry-url: 'https://registry.npmjs.org'

- name: Install pnpm
run: npm install -g pnpm
- run: pnpm install
- run: pnpm run test
- run: pnpm run build

33 changes: 33 additions & 0 deletions .github/workflows/e2e-test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
name: Run E2E Tests

on:
pull_request:
branches:
- 'main'

concurrency:
group: ${{ github.ref }}-e2e-tests
cancel-in-progress: true

env:
CI: true
DEEPINFRA_API_KEY: ${{ secrets.DEEPINFRA_API_KEY }}

jobs:
e2e-test:
if: false
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2

- uses: actions/setup-node@v2
with:
node-version: '18.x'
registry-url: 'https://registry.npmjs.org'

- name: Install pnpm
run: npm install -g pnpm
- run: pnpm install
- run: pnpm run test:e2e
- run: pnpm run build

1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -144,3 +144,4 @@ docs
dist
**misc.ts
docs
.husky
1 change: 0 additions & 1 deletion .husky/pre-commit

This file was deleted.

6 changes: 0 additions & 6 deletions .husky/prepare-commit-msg

This file was deleted.

22 changes: 22 additions & 0 deletions jest.e2e.config.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/** @type {import('ts-jest').JstConfigWithTsJest} */
module.exports = {
testEnvironment: 'node',
extensionsToTreatAsEsm: [".ts"],
testTimeout: 10000,
coveragePathIgnorePatterns: [
"/node_modules/",
"/examples/",
"/test/"
],
moduleNameMapper: {
"^@/(.*)$": "<rootDir>/src/$1"
},
testMatch: [
"<rootDir>/test/**/*.e2e-test.ts"
],
rootDir: ".",
transform: {
"^.+\\.tsx?$": "ts-jest"
},

};
4 changes: 3 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"misc": "npx ts-node -r tsconfig-paths/register src/misc.ts",
"prepare": "husky",
"test": "jest --passWithNoTests",
"test:e2e": "jest --config=jest.e2e.config.js --passWithNoTests --runInBand",
"lint": "eslint . --ext .ts --fix",
"prettier": "prettier --write ./src ./test",
"build-docs": "typedoc --out docs src",
Expand Down Expand Up @@ -40,7 +41,8 @@
"dependencies": {
"@swc/core": "^1.4.6",
"@swc/wasm": "^1.4.6",
"axios": "^1.6.7"
"axios": "^1.6.7",
"p-limit": "^5.0.0"
},
"devDependencies": {
"@types/jest": "^29.5.12",
Expand Down
15 changes: 15 additions & 0 deletions pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

33 changes: 33 additions & 0 deletions src/lib/types/common/models.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
export const enum ModelTypes {
EMBEDDINGS = 'embeddings',
FILL_MASK = 'fill-mask',
TEXT_GENERATION = 'text-generation',
AUTOMATIC_SPEECH_RECOGNITION = 'automatic-speech-recognition',
TOKEN_CLASSIFICATION = 'token-classification',
TEXT2TEXT_GENERATION = 'text2text-generation',
OBJECT_DETECTION = 'object-detection',
QUESTION_ANSWERING = 'question-answering',
IMAGE_CLASSIFICATION = 'image-classification',
TEXT_TO_IMAGE = 'text-to-image',
ZERO_SHOT_IMAGE_CLASSIFICATION = 'zero-shot-image-classification',
CUSTOM = 'custom',
TEXT_CLASSIFICATION = 'text-classification',
DREAMBOOTH = 'dreambooth',
}


export interface ModelDefinition {
model_name: string;
type: ModelTypes;
reported_type: ModelTypes;
description: string;
cover_img_url: string;
tags: string[];
pricing: {
cents_per_sec: number;
type: string;
};
max_tokens: number | null;
}

export type ModelDefinitionList = ModelDefinition[];
174 changes: 174 additions & 0 deletions test/e2e/index.e2e-test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
import axios from 'axios';
import {ModelDefinition, ModelDefinitionList, ModelTypes} from "@/lib/types/common/models";
import {
AutomaticSpeechRecognition,
Embeddings,
FillMask,
ObjectDetection,
QuestionAnswering,
Sdxl,
TextClassification,
TextGeneration,
TextToImage,
TokenClassification
} from "@/index";

const GET_MODELS = "https://api.deepinfra.com/models/list";
const SDXL_MODEL = "stability-ai/sdxl";
const TEXT_TO_IMAGE_PROMPT = "The quick brown fox jumps over the lazy dog.";
const TEXT_INPUT = "This is a test.";

/*
TODO: Add mock audio file for ASR.
TODO: Add mock image file for object detection.
TODO: Implement p-limit
*/


describe('E2E tests', () => {

let allModels: ModelDefinitionList;

beforeAll(async () => {
const response = await axios.get(GET_MODELS);
allModels = response.data as ModelDefinitionList;
});

it('should have at least one model', () => {
expect(allModels.length).toBeGreaterThan(0);
});
it('Text to image models should infer correctly.', () => {
const textToImageModels = allModels.filter(model => model.reported_type === ModelTypes.TEXT_TO_IMAGE).map(m => m.model_name);
textToImageModels.forEach(async (modelName) => {

if (modelName === SDXL_MODEL) {
const model = new Sdxl();
expect(model).toBeDefined();
await model.generate({input: {prompt: TEXT_TO_IMAGE_PROMPT}}).then(response => {
expect(response).toBeDefined();
});
} else {
const model = new TextToImage(modelName);
expect(model).toBeDefined();

await model.generate({prompt: TEXT_TO_IMAGE_PROMPT}).then(response => {
expect(response).toBeDefined();
});
}

});
});

it('Text classification models should infer correctly.', () => {
const textClassificationModels = allModels.filter(model => model.reported_type === ModelTypes.TEXT_CLASSIFICATION).map(m => m.model_name);
textClassificationModels.forEach(async (modelName) => {
const model = new TextClassification(modelName);
expect(model).toBeDefined();

await model.generate({input: TEXT_INPUT}).then(response => {
expect(response).toBeDefined();
});
});
});

it('Text generation models should infer correctly.', () => {
const textGenerationModels = allModels.filter(model => model.reported_type === ModelTypes.TEXT_GENERATION).map(m => m.model_name);
textGenerationModels.forEach(async (modelName) => {
const model = new TextGeneration(modelName);
expect(model).toBeDefined();

await model.generate({input: TEXT_INPUT}).then(response => {
expect(response).toBeDefined();
});
});
});

it('Fill mask models should infer correctly.', () => {
const fillMaskModels = allModels.filter(model => model.reported_type === ModelTypes.FILL_MASK).map(m => m.model_name);
fillMaskModels.forEach(async (modelName) => {
const model = new FillMask(modelName);
expect(model).toBeDefined();

await model.generate({input: "This is a [MASK]"})
.then(response => {
expect(response).toBeDefined();
});
});
});

it('Embeddings models should infer correctly.', () => {
const embeddingsModels = allModels.filter(model => model.reported_type === ModelTypes.EMBEDDINGS).map(m => m.model_name);
embeddingsModels.forEach(async (modelName) => {
const model = new Embeddings(modelName);
expect(model).toBeDefined();

await model.generate({inputs: [TEXT_INPUT]}).then(response => {
expect(response).toBeDefined();
});
});
});



it('Question answering models should infer correctly.', () => {
const questionAnsweringModels = allModels.filter(model => model.reported_type === ModelTypes.QUESTION_ANSWERING).map(m => m.model_name);
questionAnsweringModels.forEach(async (modelName) => {
const model = new QuestionAnswering(modelName);
expect(model).toBeDefined();

await model.generate({question: TEXT_INPUT, context: TEXT_INPUT})
.then(response => {
expect(response).toBeDefined();
});
});
});

it('Token classification models should infer correctly.', () => {
const tokenClassificationModels = allModels.filter(model => model.reported_type === ModelTypes.TOKEN_CLASSIFICATION).map(m => m.model_name);
tokenClassificationModels.forEach(async (modelName) => {
const model = new TokenClassification(modelName);
expect(model).toBeDefined();

await model.generate({input: TEXT_INPUT}).then(response => {
expect(response).toBeDefined();
});
});
});

it('Text2Text generation models should infer correctly.', () => {
const text2TextGenerationModels = allModels.filter(model => model.reported_type === ModelTypes.TEXT2TEXT_GENERATION).map(m => m.model_name);
text2TextGenerationModels.forEach(async (modelName) => {
const model = new TextGeneration(modelName);
expect(model).toBeDefined();

await model.generate({input: TEXT_INPUT}).then(response => {
expect(response).toBeDefined();
});
});
});

it('Object detection models should infer correctly.', () => {
const objectDetectionModels = allModels.filter(model => model.reported_type === ModelTypes.OBJECT_DETECTION).map(m => m.model_name);
objectDetectionModels.forEach(async (modelName) => {
const model = new ObjectDetection(modelName);
expect(model).toBeDefined();

await model.generate({input: TEXT_INPUT}).then(response => {
expect(response).toBeDefined();
});
});
});

it('Automatic speech recognition models should infer correctly.', () => {
const automaticSpeechRecognitionModels = allModels.filter(model => model.reported_type === ModelTypes.AUTOMATIC_SPEECH_RECOGNITION).map(m => m.model_name);
automaticSpeechRecognitionModels.forEach(async (modelName) => {
const model = new AutomaticSpeechRecognition(modelName);
expect(model).toBeDefined();

await model.generate({input: TEXT_INPUT}).then(response => {
expect(response).toBeDefined();
});
});
});

});