diff --git a/src/app.controller.ts b/src/app.controller.ts index 1e1cf90..dcfe056 100644 --- a/src/app.controller.ts +++ b/src/app.controller.ts @@ -19,7 +19,7 @@ export class ChatController { @Get() getHello(): string { - return this.appService.getHello(); + return this.appService.getVersion(); } @Post('completions') diff --git a/src/app.service.ts b/src/app.service.ts index 15fdfb0..46d2635 100644 --- a/src/app.service.ts +++ b/src/app.service.ts @@ -1,10 +1,11 @@ import { HttpService } from '@nestjs/axios'; -import { Header, Injectable } from '@nestjs/common'; +import { Header, Injectable, Logger } from '@nestjs/common'; import { models } from './models'; import { firstValueFrom } from 'rxjs'; @Injectable() export class AppService { + private readonly logger = new Logger(AppService.name); constructor(private readonly httpService: HttpService) {} @Header('Content-Type', 'application/json') @@ -12,17 +13,19 @@ export class AppService { return models; } - getHello(): string { - return 'Hello World! 1.8'; + getVersion(): string { + return 'Hello World! 2.0'; } async getCompletions( endpoint: string, - deployment_id: string, + mapping: string, azureApiKey: string, body: any, stream: boolean, ) { + const deployment_id = this.getDeploymentId(mapping, body['model']); + this.logger.debug(`deployment_id: ${deployment_id}`); const url = `${endpoint}/openai/deployments/${deployment_id}/chat/completions?api-version=2023-03-15-preview`; const headers = { 'api-key': azureApiKey, @@ -35,4 +38,25 @@ export class AppService { const ret = this.httpService.post(url, body, config); return await firstValueFrom(ret); } + private getDeploymentId(mapping: string, model: string): string { + this.logger.debug(`mapping: ${mapping}, model: ${model}`); + if (mapping.includes(',')) { + let defaultDeploymentId = ''; + const modelMapping = mapping + .split(',') + .reduce((acc: Record, pair: string) => { + const [key, value] = pair.split('|'); + if (defaultDeploymentId === '') defaultDeploymentId = value; + acc[key] = value; + return acc; + }, {}); + if (!model) { + return defaultDeploymentId; + } + const deploymentId = modelMapping[model]; + return deploymentId || defaultDeploymentId; + } else { + return mapping; + } + } } diff --git a/src/main.ts b/src/main.ts index fd30a90..4226ba2 100644 --- a/src/main.ts +++ b/src/main.ts @@ -2,7 +2,9 @@ import { NestFactory } from '@nestjs/core'; import { AppModule } from './app.module'; async function bootstrap() { - const app = await NestFactory.create(AppModule); + const app = await NestFactory.create(AppModule, { + logger: ['log'], + }); app.setGlobalPrefix('v1'); await app.listen(3000); }