Phase 1 complete
This commit is contained in:
42
apps/backend/generate-openapi.ts
Normal file
42
apps/backend/generate-openapi.ts
Normal file
@@ -0,0 +1,42 @@
|
||||
import { NestFactory } from '@nestjs/core';
|
||||
import { SwaggerModule, DocumentBuilder } from '@nestjs/swagger';
|
||||
import { AppModule } from './src/app.module';
|
||||
import * as fs from 'fs';
|
||||
import * as path from 'path';
|
||||
|
||||
async function generateOpenApi() {
|
||||
// Create app without starting the server
|
||||
const app = await NestFactory.create(AppModule, {
|
||||
logger: false, // Suppress logs during generation
|
||||
});
|
||||
|
||||
app.setGlobalPrefix('api');
|
||||
|
||||
const config = new DocumentBuilder()
|
||||
.setTitle('DreamChat API')
|
||||
.setDescription('The DreamChat API documentation')
|
||||
.setVersion('1.0.0')
|
||||
.addBearerAuth()
|
||||
.build();
|
||||
|
||||
const document = SwaggerModule.createDocument(app, config);
|
||||
|
||||
// Ensure the output directory exists
|
||||
const outputDir = path.join(__dirname, '..', '..', 'openapi');
|
||||
if (!fs.existsSync(outputDir)) {
|
||||
fs.mkdirSync(outputDir, { recursive: true });
|
||||
}
|
||||
|
||||
// Write the spec file
|
||||
const outputPath = path.join(outputDir, 'openapi.json');
|
||||
fs.writeFileSync(outputPath, JSON.stringify(document, null, 2));
|
||||
console.log(`📄 OpenAPI spec written to: ${outputPath}`);
|
||||
|
||||
await app.close();
|
||||
process.exit(0);
|
||||
}
|
||||
|
||||
generateOpenApi().catch((err) => {
|
||||
console.error('Failed to generate OpenAPI spec:', err);
|
||||
process.exit(1);
|
||||
});
|
||||
@@ -11,22 +11,29 @@
|
||||
"db:migrate": "prisma migrate deploy",
|
||||
"db:generate": "prisma generate",
|
||||
"db:seed": "prisma db seed",
|
||||
"openapi:generate": "node dist/generate-openapi.js",
|
||||
"clean": "rm -r dist"
|
||||
},
|
||||
"dependencies": {
|
||||
"@dreamchat/shared": "workspace:*",
|
||||
"@nestjs/common": "^11.1.14",
|
||||
"@nestjs/core": "^11.1.14",
|
||||
"@nestjs/jwt": "^11.0.0",
|
||||
"@nestjs/passport": "^11.0.5",
|
||||
"@nestjs/platform-express": "^11.1.14",
|
||||
"@nestjs/platform-socket.io": "^11.1.14",
|
||||
"@nestjs/swagger": "^11.0.0",
|
||||
"@nestjs/websockets": "^11.1.14",
|
||||
"@prisma/adapter-pg": "^7.4.1",
|
||||
"@prisma/client": "^7.4.1",
|
||||
"@types/keycloak-connect": "^7.0.0",
|
||||
"@xenova/transformers": "^2.15.0",
|
||||
"bcrypt": "^6.0.0",
|
||||
"class-transformer": "^0.5.1",
|
||||
"class-validator": "^0.14.0",
|
||||
"dotenv": "^17.3.1",
|
||||
"jsonwebtoken": "^9.0.0",
|
||||
"keycloak-connect": "^26.1.1",
|
||||
"passport": "^0.7.0",
|
||||
"passport-jwt": "^4.0.0",
|
||||
"passport-local": "^1.0.0",
|
||||
@@ -40,6 +47,7 @@
|
||||
"@nestjs/testing": "^11.1.14",
|
||||
"@types/bcrypt": "^6.0.0",
|
||||
"@types/jsonwebtoken": "^9.0.0",
|
||||
"@types/multer": "^1.4.12",
|
||||
"@types/node": "^24.10.13",
|
||||
"@types/passport-jwt": "^4.0.0",
|
||||
"@types/passport-local": "^1.0.0",
|
||||
|
||||
267
apps/backend/prisma/migrations/20260224085801_init/migration.sql
Normal file
267
apps/backend/prisma/migrations/20260224085801_init/migration.sql
Normal file
@@ -0,0 +1,267 @@
|
||||
-- Enable pgvector extension
|
||||
CREATE EXTENSION IF NOT EXISTS vector;
|
||||
|
||||
-- CreateEnum
|
||||
CREATE TYPE "ImportSourceType" AS ENUM ('file', 'url', 'manual');
|
||||
|
||||
-- CreateEnum
|
||||
CREATE TYPE "ImportStatus" AS ENUM ('pending', 'processing', 'completed', 'failed');
|
||||
|
||||
-- CreateEnum
|
||||
CREATE TYPE "MessageRole" AS ENUM ('user', 'assistant', 'system');
|
||||
|
||||
-- CreateEnum
|
||||
CREATE TYPE "UserRole" AS ENUM ('USER', 'ADMIN');
|
||||
|
||||
-- CreateEnum
|
||||
CREATE TYPE "MemoryType" AS ENUM ('conversation', 'character');
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "Character" (
|
||||
"id" TEXT NOT NULL,
|
||||
"name" TEXT NOT NULL,
|
||||
"avatarUrl" TEXT,
|
||||
"personalityPrompt" TEXT NOT NULL,
|
||||
"attributes" JSONB NOT NULL DEFAULT '{}',
|
||||
"config" JSONB NOT NULL DEFAULT '{}',
|
||||
"isPublic" BOOLEAN NOT NULL DEFAULT false,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||
"userId" TEXT NOT NULL,
|
||||
|
||||
CONSTRAINT "Character_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "CharacterKnowledge" (
|
||||
"id" TEXT NOT NULL,
|
||||
"name" TEXT NOT NULL,
|
||||
"sourceType" "ImportSourceType" NOT NULL,
|
||||
"sourceName" TEXT NOT NULL,
|
||||
"mimeType" TEXT,
|
||||
"fileSize" BIGINT,
|
||||
"rawContent" TEXT,
|
||||
"status" "ImportStatus" NOT NULL DEFAULT 'pending',
|
||||
"processingInfo" JSONB,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||
"characterId" TEXT NOT NULL,
|
||||
|
||||
CONSTRAINT "CharacterKnowledge_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "Conversation" (
|
||||
"id" TEXT NOT NULL,
|
||||
"title" TEXT,
|
||||
"messageCount" INTEGER NOT NULL DEFAULT 0,
|
||||
"totalTokens" INTEGER NOT NULL DEFAULT 0,
|
||||
"settings" JSONB NOT NULL DEFAULT '{}',
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||
"userId" TEXT NOT NULL,
|
||||
"characterId" TEXT NOT NULL,
|
||||
|
||||
CONSTRAINT "Conversation_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "ConversationParticipant" (
|
||||
"id" TEXT NOT NULL,
|
||||
"isActive" BOOLEAN NOT NULL DEFAULT true,
|
||||
"autoRespond" BOOLEAN NOT NULL DEFAULT true,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"conversationId" TEXT NOT NULL,
|
||||
"characterId" TEXT NOT NULL,
|
||||
|
||||
CONSTRAINT "ConversationParticipant_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "ImportDocument" (
|
||||
"id" TEXT NOT NULL,
|
||||
"sourceType" "ImportSourceType" NOT NULL,
|
||||
"sourceName" TEXT NOT NULL,
|
||||
"mimeType" TEXT,
|
||||
"fileSize" BIGINT,
|
||||
"content" TEXT,
|
||||
"status" "ImportStatus" NOT NULL DEFAULT 'pending',
|
||||
"errorMessage" TEXT,
|
||||
"metadata" JSONB,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||
"userId" TEXT NOT NULL,
|
||||
|
||||
CONSTRAINT "ImportDocument_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "Message" (
|
||||
"id" TEXT NOT NULL,
|
||||
"role" "MessageRole" NOT NULL,
|
||||
"content" TEXT NOT NULL,
|
||||
"tokensUsed" INTEGER,
|
||||
"model" TEXT,
|
||||
"metadata" JSONB,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"conversationId" TEXT NOT NULL,
|
||||
|
||||
CONSTRAINT "Message_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "StoryBranch" (
|
||||
"id" TEXT NOT NULL,
|
||||
"title" TEXT,
|
||||
"content" TEXT NOT NULL,
|
||||
"userDirection" TEXT NOT NULL,
|
||||
"generationParams" JSONB,
|
||||
"depth" INTEGER NOT NULL DEFAULT 0,
|
||||
"branchOrder" INTEGER NOT NULL DEFAULT 0,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"conversationId" TEXT NOT NULL,
|
||||
"parentId" TEXT,
|
||||
|
||||
CONSTRAINT "StoryBranch_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "User" (
|
||||
"id" TEXT NOT NULL,
|
||||
"email" TEXT NOT NULL,
|
||||
"username" TEXT NOT NULL,
|
||||
"passwordHash" TEXT,
|
||||
"keycloakSub" TEXT,
|
||||
"role" "UserRole" NOT NULL DEFAULT 'USER',
|
||||
"isActive" BOOLEAN NOT NULL DEFAULT true,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||
|
||||
CONSTRAINT "User_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "VectorMemory" (
|
||||
"id" TEXT NOT NULL,
|
||||
"content" TEXT NOT NULL,
|
||||
"embedding" vector,
|
||||
"memoryType" "MemoryType" NOT NULL DEFAULT 'conversation',
|
||||
"metadata" JSONB,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"conversationId" TEXT,
|
||||
"characterId" TEXT,
|
||||
"knowledgeId" TEXT,
|
||||
|
||||
CONSTRAINT "VectorMemory_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "Character_userId_idx" ON "Character"("userId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "Character_name_idx" ON "Character"("name");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "CharacterKnowledge_characterId_idx" ON "CharacterKnowledge"("characterId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "CharacterKnowledge_status_idx" ON "CharacterKnowledge"("status");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "Conversation_userId_idx" ON "Conversation"("userId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "Conversation_characterId_idx" ON "Conversation"("characterId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "Conversation_createdAt_idx" ON "Conversation"("createdAt");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "ConversationParticipant_conversationId_idx" ON "ConversationParticipant"("conversationId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "ConversationParticipant_conversationId_characterId_key" ON "ConversationParticipant"("conversationId", "characterId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "ImportDocument_userId_idx" ON "ImportDocument"("userId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "ImportDocument_status_idx" ON "ImportDocument"("status");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "Message_conversationId_idx" ON "Message"("conversationId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "Message_createdAt_idx" ON "Message"("createdAt");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "Message_conversationId_createdAt_idx" ON "Message"("conversationId", "createdAt");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "StoryBranch_conversationId_idx" ON "StoryBranch"("conversationId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "StoryBranch_parentId_idx" ON "StoryBranch"("parentId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "User_email_key" ON "User"("email");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "User_username_key" ON "User"("username");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "User_keycloakSub_key" ON "User"("keycloakSub");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "User_email_idx" ON "User"("email");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "User_keycloakSub_idx" ON "User"("keycloakSub");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "VectorMemory_conversationId_idx" ON "VectorMemory"("conversationId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "VectorMemory_characterId_idx" ON "VectorMemory"("characterId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "VectorMemory_knowledgeId_idx" ON "VectorMemory"("knowledgeId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "VectorMemory_memoryType_idx" ON "VectorMemory"("memoryType");
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "Character" ADD CONSTRAINT "Character_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "CharacterKnowledge" ADD CONSTRAINT "CharacterKnowledge_characterId_fkey" FOREIGN KEY ("characterId") REFERENCES "Character"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "Conversation" ADD CONSTRAINT "Conversation_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "Conversation" ADD CONSTRAINT "Conversation_characterId_fkey" FOREIGN KEY ("characterId") REFERENCES "Character"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "ConversationParticipant" ADD CONSTRAINT "ConversationParticipant_conversationId_fkey" FOREIGN KEY ("conversationId") REFERENCES "Conversation"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "ImportDocument" ADD CONSTRAINT "ImportDocument_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "Message" ADD CONSTRAINT "Message_conversationId_fkey" FOREIGN KEY ("conversationId") REFERENCES "Conversation"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "StoryBranch" ADD CONSTRAINT "StoryBranch_conversationId_fkey" FOREIGN KEY ("conversationId") REFERENCES "Conversation"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "StoryBranch" ADD CONSTRAINT "StoryBranch_parentId_fkey" FOREIGN KEY ("parentId") REFERENCES "StoryBranch"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "VectorMemory" ADD CONSTRAINT "VectorMemory_conversationId_fkey" FOREIGN KEY ("conversationId") REFERENCES "Conversation"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "VectorMemory" ADD CONSTRAINT "VectorMemory_characterId_fkey" FOREIGN KEY ("characterId") REFERENCES "Character"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "VectorMemory" ADD CONSTRAINT "VectorMemory_knowledgeId_fkey" FOREIGN KEY ("knowledgeId") REFERENCES "CharacterKnowledge"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
3
apps/backend/prisma/migrations/migration_lock.toml
Normal file
3
apps/backend/prisma/migrations/migration_lock.toml
Normal file
@@ -0,0 +1,3 @@
|
||||
# Please do not edit this file manually
|
||||
# It should be added in your version-control system (e.g., Git)
|
||||
provider = "postgresql"
|
||||
@@ -7,6 +7,7 @@
|
||||
generator client {
|
||||
provider = "prisma-client-js"
|
||||
previewFeatures = ["strictUndefinedChecks"]
|
||||
engineType = "binary"
|
||||
}
|
||||
|
||||
datasource db {
|
||||
|
||||
@@ -29,7 +29,6 @@ async function main() {
|
||||
id: '00000000-0000-0000-0000-000000000001',
|
||||
name: 'Alice',
|
||||
personalityPrompt: 'You are Alice, a curious and adventurous explorer who loves discovering new things. You are friendly, witty, and always eager to help.',
|
||||
backstory: 'Alice grew up in a small village at the edge of a vast forest. From a young age, she was fascinated by the unknown and would often venture into the woods to explore.',
|
||||
attributes: {
|
||||
traits: ['curious', 'brave', 'witty', 'friendly'],
|
||||
age: 25,
|
||||
|
||||
@@ -1,8 +1,32 @@
|
||||
import { Module } from '@nestjs/common';
|
||||
import { APP_GUARD } from '@nestjs/core';
|
||||
import { PrismaModule } from './prisma/prisma.module';
|
||||
import { AuthModule } from './auth/auth.module';
|
||||
import { UserModule } from './user/user.module';
|
||||
import { CharacterModule } from './character/character.module';
|
||||
import { LLMModule } from './llm/llm.module';
|
||||
import { VectorModule } from './vector/vector.module';
|
||||
import { ChatModule } from './chat/chat.module';
|
||||
import { ImportModule } from './import/import.module';
|
||||
import { JwtAuthGuard } from './auth/guards/jwt-auth.guard';
|
||||
|
||||
@Module({
|
||||
imports: [],
|
||||
imports: [
|
||||
PrismaModule,
|
||||
AuthModule,
|
||||
UserModule,
|
||||
CharacterModule,
|
||||
LLMModule,
|
||||
VectorModule,
|
||||
ChatModule,
|
||||
ImportModule,
|
||||
],
|
||||
controllers: [],
|
||||
providers: [],
|
||||
providers: [
|
||||
{
|
||||
provide: APP_GUARD,
|
||||
useClass: JwtAuthGuard,
|
||||
},
|
||||
],
|
||||
})
|
||||
export class AppModule {}
|
||||
|
||||
125
apps/backend/src/auth/auth.controller.ts
Normal file
125
apps/backend/src/auth/auth.controller.ts
Normal file
@@ -0,0 +1,125 @@
|
||||
import { Controller, Post, Get, Body, Query, HttpCode, HttpStatus, UseGuards, Req, Res } from '@nestjs/common';
|
||||
import { ApiTags, ApiOperation, ApiResponse, ApiBearerAuth, ApiQuery } from '@nestjs/swagger';
|
||||
import { Request, Response } from 'express';
|
||||
import { AuthService } from './auth.service';
|
||||
import { KeycloakService } from './keycloak.service';
|
||||
import { LoginDto, RefreshTokenDto } from './dto/login.dto';
|
||||
import { AuthResponseDto } from './dto/auth-response.dto';
|
||||
import { KeycloakLoginUrlDto, KeycloakCallbackQueryDto, KeycloakConfigDto } from './dto/keycloak.dto';
|
||||
import { Public } from '../common/decorators/public.decorator';
|
||||
import { KeycloakAuthGuard } from './guards/keycloak-auth.guard';
|
||||
|
||||
@ApiTags('auth')
|
||||
@Controller('auth')
|
||||
export class AuthController {
|
||||
constructor(
|
||||
private authService: AuthService,
|
||||
private keycloakService: KeycloakService,
|
||||
) {}
|
||||
|
||||
@Public()
|
||||
@Post('login')
|
||||
@HttpCode(HttpStatus.OK)
|
||||
@ApiOperation({ summary: 'Login with email and password' })
|
||||
@ApiResponse({ status: 200, description: 'Login successful', type: AuthResponseDto })
|
||||
@ApiResponse({ status: 401, description: 'Invalid credentials' })
|
||||
async login(@Body() loginDto: LoginDto): Promise<AuthResponseDto> {
|
||||
return this.authService.login(loginDto);
|
||||
}
|
||||
|
||||
@Public()
|
||||
@Post('refresh')
|
||||
@HttpCode(HttpStatus.OK)
|
||||
@ApiOperation({ summary: 'Refresh access token' })
|
||||
@ApiResponse({ status: 200, description: 'Token refreshed', type: AuthResponseDto })
|
||||
@ApiResponse({ status: 401, description: 'Invalid refresh token' })
|
||||
async refreshTokens(@Body() refreshTokenDto: RefreshTokenDto): Promise<AuthResponseDto> {
|
||||
return this.authService.refreshTokens(refreshTokenDto.refreshToken);
|
||||
}
|
||||
|
||||
// ==================== KEYCLOAK OAUTH FLOW ====================
|
||||
|
||||
@Public()
|
||||
@Get('keycloak/config')
|
||||
@ApiOperation({ summary: 'Get Keycloak configuration for frontend' })
|
||||
@ApiResponse({ status: 200, description: 'Keycloak config', type: KeycloakConfigDto })
|
||||
getKeycloakConfig(): KeycloakConfigDto {
|
||||
return this.keycloakService.getConfig();
|
||||
}
|
||||
|
||||
@Public()
|
||||
@Get('keycloak/login')
|
||||
@ApiOperation({ summary: 'Get Keycloak login URL (initiates OAuth flow)' })
|
||||
@ApiQuery({ name: 'redirectTo', required: false, description: 'Frontend path to redirect after login' })
|
||||
@ApiResponse({ status: 200, description: 'Login URL generated', type: KeycloakLoginUrlDto })
|
||||
@ApiResponse({ status: 400, description: 'Keycloak not enabled' })
|
||||
keycloakLogin(@Query('redirectTo') redirectTo?: string): KeycloakLoginUrlDto {
|
||||
return this.keycloakService.generateLoginUrl(redirectTo);
|
||||
}
|
||||
|
||||
@Public()
|
||||
@Get('keycloak/callback')
|
||||
@ApiOperation({ summary: 'Keycloak OAuth callback endpoint' })
|
||||
@ApiQuery({ name: 'code', required: true, description: 'Authorization code from Keycloak' })
|
||||
@ApiQuery({ name: 'state', required: true, description: 'State parameter for CSRF validation' })
|
||||
@ApiQuery({ name: 'error', required: false, description: 'Error message if authentication failed' })
|
||||
@ApiQuery({ name: 'error_description', required: false, description: 'Error description' })
|
||||
@ApiResponse({ status: 302, description: 'Redirect to frontend with tokens' })
|
||||
@ApiResponse({ status: 401, description: 'Authentication failed' })
|
||||
async keycloakCallback(
|
||||
@Query() query: KeycloakCallbackQueryDto,
|
||||
@Res() res: Response,
|
||||
): Promise<void> {
|
||||
// Handle errors from Keycloak
|
||||
if (query.error) {
|
||||
const frontendUrl = process.env.FRONTEND_URL || 'http://localhost:5173';
|
||||
const errorMessage = encodeURIComponent(query.error_description || query.error);
|
||||
return res.redirect(`${frontendUrl}/login?error=${errorMessage}`);
|
||||
}
|
||||
|
||||
if (!query.code) {
|
||||
const frontendUrl = process.env.FRONTEND_URL || 'http://localhost:5173';
|
||||
return res.redirect(`${frontendUrl}/login?error=Missing authorization code`);
|
||||
}
|
||||
|
||||
// Exchange code for tokens
|
||||
const result = await this.keycloakService.handleCallback(query.code, query.state);
|
||||
|
||||
// Redirect to frontend with tokens
|
||||
const frontendUrl = process.env.FRONTEND_URL || 'http://localhost:5173';
|
||||
const redirectPath = result.redirectTo || '/characters';
|
||||
|
||||
// Build redirect URL with tokens
|
||||
const params = new URLSearchParams({
|
||||
accessToken: result.authResponse.accessToken,
|
||||
refreshToken: result.authResponse.refreshToken,
|
||||
});
|
||||
|
||||
return res.redirect(`${frontendUrl}${redirectPath}?${params.toString()}`);
|
||||
}
|
||||
|
||||
// ==================== KEYCLOAK BEARER TOKEN (Legacy/Alternative) ====================
|
||||
|
||||
@Public()
|
||||
@Post('keycloak')
|
||||
@HttpCode(HttpStatus.OK)
|
||||
@UseGuards(KeycloakAuthGuard)
|
||||
@ApiOperation({ summary: 'Login with Keycloak bearer token (Authorization: Bearer <keycloak-jwt>)' })
|
||||
@ApiBearerAuth()
|
||||
@ApiResponse({ status: 200, description: 'Login successful', type: AuthResponseDto })
|
||||
@ApiResponse({ status: 401, description: 'Invalid Keycloak token' })
|
||||
async keycloakBearerLogin(@Req() req: Request): Promise<AuthResponseDto> {
|
||||
// The Keycloak guard validates the token and attaches the user to req.user
|
||||
const keycloakUser = req.user as {
|
||||
userId: string;
|
||||
email: string;
|
||||
role: string;
|
||||
};
|
||||
|
||||
return this.authService.generateTokensFromUser(
|
||||
keycloakUser.userId,
|
||||
keycloakUser.email,
|
||||
keycloakUser.role,
|
||||
);
|
||||
}
|
||||
}
|
||||
33
apps/backend/src/auth/auth.module.ts
Normal file
33
apps/backend/src/auth/auth.module.ts
Normal file
@@ -0,0 +1,33 @@
|
||||
import { Module } from '@nestjs/common';
|
||||
import { JwtModule } from '@nestjs/jwt';
|
||||
import { PassportModule } from '@nestjs/passport';
|
||||
import { AuthService } from './auth.service';
|
||||
import { AuthController } from './auth.controller';
|
||||
import { KeycloakService } from './keycloak.service';
|
||||
import { JwtStrategy } from './strategies/jwt.strategy';
|
||||
import { LocalStrategy } from './strategies/local.strategy';
|
||||
import { KeycloakStrategy } from './strategies/keycloak.strategy';
|
||||
import { PrismaModule } from '../prisma/prisma.module';
|
||||
|
||||
@Module({
|
||||
imports: [
|
||||
PrismaModule,
|
||||
PassportModule,
|
||||
JwtModule.registerAsync({
|
||||
useFactory: () => ({
|
||||
secret: process.env.JWT_SECRET || 'dev-jwt-secret-change-in-production',
|
||||
signOptions: { expiresIn: '1h' },
|
||||
}),
|
||||
}),
|
||||
],
|
||||
providers: [
|
||||
AuthService,
|
||||
KeycloakService,
|
||||
JwtStrategy,
|
||||
LocalStrategy,
|
||||
KeycloakStrategy,
|
||||
],
|
||||
controllers: [AuthController],
|
||||
exports: [AuthService, KeycloakService],
|
||||
})
|
||||
export class AuthModule {}
|
||||
112
apps/backend/src/auth/auth.service.ts
Normal file
112
apps/backend/src/auth/auth.service.ts
Normal file
@@ -0,0 +1,112 @@
|
||||
import { Injectable, UnauthorizedException, ConflictException } from '@nestjs/common';
|
||||
import { JwtService } from '@nestjs/jwt';
|
||||
import { PrismaService } from '../prisma/prisma.service';
|
||||
import * as bcrypt from 'bcrypt';
|
||||
import { LoginDto, RegisterDto } from './dto/login.dto';
|
||||
import { AuthResponseDto } from './dto/auth-response.dto';
|
||||
import { User } from '@prisma/client';
|
||||
|
||||
@Injectable()
|
||||
export class AuthService {
|
||||
constructor(
|
||||
private prisma: PrismaService,
|
||||
private jwtService: JwtService,
|
||||
) {}
|
||||
|
||||
async validateUser(email: string, password: string): Promise<Omit<User, 'passwordHash'> | null> {
|
||||
const user = await this.prisma.user.findUnique({
|
||||
where: { email },
|
||||
});
|
||||
|
||||
if (user && user.passwordHash) {
|
||||
const isMatch = await bcrypt.compare(password, user.passwordHash);
|
||||
if (isMatch) {
|
||||
const { passwordHash, ...result } = user;
|
||||
return result;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
async login(loginDto: LoginDto): Promise<AuthResponseDto> {
|
||||
const user = await this.validateUser(loginDto.email, loginDto.password);
|
||||
|
||||
if (!user) {
|
||||
throw new UnauthorizedException('Invalid credentials');
|
||||
}
|
||||
|
||||
if (!user.isActive) {
|
||||
throw new UnauthorizedException('Account is deactivated');
|
||||
}
|
||||
|
||||
return this.generateTokens(user);
|
||||
}
|
||||
|
||||
async register(registerDto: RegisterDto): Promise<AuthResponseDto> {
|
||||
const existingUser = await this.prisma.user.findFirst({
|
||||
where: {
|
||||
OR: [
|
||||
{ email: registerDto.email },
|
||||
{ username: registerDto.username },
|
||||
],
|
||||
},
|
||||
});
|
||||
|
||||
if (existingUser) {
|
||||
throw new ConflictException('Email or username already exists');
|
||||
}
|
||||
|
||||
const hashedPassword = await bcrypt.hash(registerDto.password, 10);
|
||||
|
||||
const user = await this.prisma.user.create({
|
||||
data: {
|
||||
email: registerDto.email,
|
||||
username: registerDto.username,
|
||||
passwordHash: hashedPassword,
|
||||
},
|
||||
});
|
||||
|
||||
const { passwordHash, ...userWithoutPassword } = user;
|
||||
return this.generateTokens(userWithoutPassword);
|
||||
}
|
||||
|
||||
async refreshTokens(refreshToken: string): Promise<AuthResponseDto> {
|
||||
try {
|
||||
const payload = this.jwtService.verify(refreshToken, {
|
||||
secret: process.env.JWT_SECRET,
|
||||
});
|
||||
|
||||
const user = await this.prisma.user.findUnique({
|
||||
where: { id: payload.sub },
|
||||
});
|
||||
|
||||
if (!user || !user.isActive) {
|
||||
throw new UnauthorizedException('Invalid refresh token');
|
||||
}
|
||||
|
||||
const { passwordHash, ...userWithoutPassword } = user;
|
||||
return this.generateTokens(userWithoutPassword);
|
||||
} catch {
|
||||
throw new UnauthorizedException('Invalid refresh token');
|
||||
}
|
||||
}
|
||||
|
||||
generateTokensFromUser(userId: string, email: string, role: string): AuthResponseDto {
|
||||
const payload = { sub: userId, email, role };
|
||||
|
||||
return {
|
||||
accessToken: this.jwtService.sign(payload),
|
||||
refreshToken: this.jwtService.sign(payload, { expiresIn: '7d' }),
|
||||
user: {
|
||||
id: userId,
|
||||
email,
|
||||
username: email.split('@')[0],
|
||||
role,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
private generateTokens(user: Omit<User, 'passwordHash'>): AuthResponseDto {
|
||||
return this.generateTokensFromUser(user.id, user.email, user.role);
|
||||
}
|
||||
}
|
||||
26
apps/backend/src/auth/dto/auth-response.dto.ts
Normal file
26
apps/backend/src/auth/dto/auth-response.dto.ts
Normal file
@@ -0,0 +1,26 @@
|
||||
import { ApiProperty } from '@nestjs/swagger';
|
||||
|
||||
class UserDto {
|
||||
@ApiProperty({ description: 'User ID', example: '550e8400-e29b-41d4-a716-446655440000' })
|
||||
id: string;
|
||||
|
||||
@ApiProperty({ description: 'User email', example: 'admin@dreamchat.local' })
|
||||
email: string;
|
||||
|
||||
@ApiProperty({ description: 'User username', example: 'admin' })
|
||||
username: string;
|
||||
|
||||
@ApiProperty({ description: 'User role', example: 'USER', enum: ['USER', 'ADMIN'] })
|
||||
role: string;
|
||||
}
|
||||
|
||||
export class AuthResponseDto {
|
||||
@ApiProperty({ description: 'JWT access token' })
|
||||
accessToken: string;
|
||||
|
||||
@ApiProperty({ description: 'JWT refresh token' })
|
||||
refreshToken: string;
|
||||
|
||||
@ApiProperty({ description: 'User information', type: UserDto })
|
||||
user: UserDto;
|
||||
}
|
||||
37
apps/backend/src/auth/dto/keycloak.dto.ts
Normal file
37
apps/backend/src/auth/dto/keycloak.dto.ts
Normal file
@@ -0,0 +1,37 @@
|
||||
import { ApiProperty, ApiPropertyOptional } from '@nestjs/swagger';
|
||||
|
||||
export class KeycloakConfigDto {
|
||||
@ApiProperty({ description: 'Whether Keycloak authentication is enabled' })
|
||||
enabled: boolean;
|
||||
|
||||
@ApiPropertyOptional({ description: 'Keycloak realm URL' })
|
||||
url?: string;
|
||||
|
||||
@ApiPropertyOptional({ description: 'Keycloak realm name' })
|
||||
realm?: string;
|
||||
|
||||
@ApiPropertyOptional({ description: 'Keycloak client ID' })
|
||||
clientId?: string;
|
||||
}
|
||||
|
||||
export class KeycloakLoginUrlDto {
|
||||
@ApiProperty({ description: 'Keycloak login URL to redirect the user to' })
|
||||
loginUrl: string;
|
||||
|
||||
@ApiProperty({ description: 'State parameter for CSRF protection' })
|
||||
state: string;
|
||||
}
|
||||
|
||||
export class KeycloakCallbackQueryDto {
|
||||
@ApiPropertyOptional({ description: 'Authorization code from Keycloak' })
|
||||
code?: string;
|
||||
|
||||
@ApiPropertyOptional({ description: 'Error message if authentication failed' })
|
||||
error?: string;
|
||||
|
||||
@ApiPropertyOptional({ description: 'Error description' })
|
||||
error_description?: string;
|
||||
|
||||
@ApiProperty({ description: 'State parameter for CSRF validation' })
|
||||
state: string;
|
||||
}
|
||||
35
apps/backend/src/auth/dto/login.dto.ts
Normal file
35
apps/backend/src/auth/dto/login.dto.ts
Normal file
@@ -0,0 +1,35 @@
|
||||
import { IsString, IsEmail, MinLength } from 'class-validator';
|
||||
import { ApiProperty, ApiPropertyOptional } from '@nestjs/swagger';
|
||||
|
||||
export class LoginDto {
|
||||
@ApiProperty({ description: 'User email address', example: 'admin@dreamchat.local' })
|
||||
@IsEmail()
|
||||
email: string;
|
||||
|
||||
@ApiProperty({ description: 'User password', example: 'password123' })
|
||||
@IsString()
|
||||
@MinLength(6)
|
||||
password: string;
|
||||
}
|
||||
|
||||
export class RegisterDto {
|
||||
@ApiProperty({ description: 'User email address', example: 'user@example.com' })
|
||||
@IsEmail()
|
||||
email: string;
|
||||
|
||||
@ApiProperty({ description: 'Username', example: 'myusername' })
|
||||
@IsString()
|
||||
@MinLength(3)
|
||||
username: string;
|
||||
|
||||
@ApiProperty({ description: 'User password', example: 'password123' })
|
||||
@IsString()
|
||||
@MinLength(6)
|
||||
password: string;
|
||||
}
|
||||
|
||||
export class RefreshTokenDto {
|
||||
@ApiProperty({ description: 'Refresh token', example: 'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...' })
|
||||
@IsString()
|
||||
refreshToken: string;
|
||||
}
|
||||
24
apps/backend/src/auth/guards/jwt-auth.guard.ts
Normal file
24
apps/backend/src/auth/guards/jwt-auth.guard.ts
Normal file
@@ -0,0 +1,24 @@
|
||||
import { Injectable, ExecutionContext } from '@nestjs/common';
|
||||
import { AuthGuard } from '@nestjs/passport';
|
||||
import { Reflector } from '@nestjs/core';
|
||||
import { IS_PUBLIC_KEY } from '../../common/decorators/public.decorator';
|
||||
|
||||
@Injectable()
|
||||
export class JwtAuthGuard extends AuthGuard('jwt') {
|
||||
constructor(private reflector: Reflector) {
|
||||
super();
|
||||
}
|
||||
|
||||
canActivate(context: ExecutionContext) {
|
||||
const isPublic = this.reflector.getAllAndOverride<boolean>(IS_PUBLIC_KEY, [
|
||||
context.getHandler(),
|
||||
context.getClass(),
|
||||
]);
|
||||
|
||||
if (isPublic) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return super.canActivate(context);
|
||||
}
|
||||
}
|
||||
13
apps/backend/src/auth/guards/keycloak-auth.guard.ts
Normal file
13
apps/backend/src/auth/guards/keycloak-auth.guard.ts
Normal file
@@ -0,0 +1,13 @@
|
||||
import { Injectable, ExecutionContext } from '@nestjs/common';
|
||||
import { AuthGuard } from '@nestjs/passport';
|
||||
|
||||
@Injectable()
|
||||
export class KeycloakAuthGuard extends AuthGuard('keycloak') {
|
||||
canActivate(context: ExecutionContext) {
|
||||
// Skip if Keycloak is not enabled
|
||||
if (process.env.KEYCLOAK_ENABLED !== 'true') {
|
||||
return false;
|
||||
}
|
||||
return super.canActivate(context);
|
||||
}
|
||||
}
|
||||
5
apps/backend/src/auth/guards/local-auth.guard.ts
Normal file
5
apps/backend/src/auth/guards/local-auth.guard.ts
Normal file
@@ -0,0 +1,5 @@
|
||||
import { Injectable } from '@nestjs/common';
|
||||
import { AuthGuard } from '@nestjs/passport';
|
||||
|
||||
@Injectable()
|
||||
export class LocalAuthGuard extends AuthGuard('local') {}
|
||||
333
apps/backend/src/auth/keycloak.service.ts
Normal file
333
apps/backend/src/auth/keycloak.service.ts
Normal file
@@ -0,0 +1,333 @@
|
||||
import { Injectable, UnauthorizedException, BadRequestException } from '@nestjs/common';
|
||||
import { JwtService } from '@nestjs/jwt';
|
||||
import { PrismaService } from '../prisma/prisma.service';
|
||||
import { UserRole } from '@prisma/client';
|
||||
import { AuthResponseDto } from './dto/auth-response.dto';
|
||||
import * as crypto from 'crypto';
|
||||
|
||||
interface KeycloakTokenResponse {
|
||||
access_token: string;
|
||||
expires_in: number;
|
||||
refresh_expires_in: number;
|
||||
refresh_token: string;
|
||||
token_type: string;
|
||||
id_token?: string;
|
||||
session_state?: string;
|
||||
scope: string;
|
||||
}
|
||||
|
||||
interface KeycloakUserInfo {
|
||||
sub: string;
|
||||
email?: string;
|
||||
email_verified?: boolean;
|
||||
name?: string;
|
||||
preferred_username?: string;
|
||||
given_name?: string;
|
||||
family_name?: string;
|
||||
groups?: string[];
|
||||
// Keycloak may also put groups in these locations
|
||||
realm_access?: {
|
||||
roles?: string[];
|
||||
};
|
||||
resource_access?: {
|
||||
[key: string]: {
|
||||
roles?: string[];
|
||||
};
|
||||
};
|
||||
[key: string]: any;
|
||||
}
|
||||
|
||||
@Injectable()
|
||||
export class KeycloakService {
|
||||
private readonly keycloakEnabled: boolean;
|
||||
private readonly keycloakUrl: string;
|
||||
private readonly keycloakRealm: string;
|
||||
private readonly clientId: string;
|
||||
private readonly clientSecret: string;
|
||||
private readonly redirectUri: string;
|
||||
private stateStore: Map<string, { createdAt: number; redirectTo?: string }> = new Map();
|
||||
|
||||
constructor(
|
||||
private prisma: PrismaService,
|
||||
private jwtService: JwtService,
|
||||
) {
|
||||
this.keycloakEnabled = process.env.KEYCLOAK_ENABLED === 'true';
|
||||
this.keycloakUrl = process.env.KEYCLOAK_URL || '';
|
||||
this.keycloakRealm = process.env.KEYCLOAK_REALM || '';
|
||||
this.clientId = process.env.KEYCLOAK_CLIENT_ID || '';
|
||||
this.clientSecret = process.env.KEYCLOAK_CLIENT_SECRET || '';
|
||||
this.redirectUri = process.env.KEYCLOAK_REDIRECT_URI || 'http://localhost:3000/api/auth/keycloak/callback';
|
||||
|
||||
// Clean up old state entries every 5 minutes
|
||||
setInterval(() => this.cleanupState(), 5 * 60 * 1000);
|
||||
}
|
||||
|
||||
isEnabled(): boolean {
|
||||
return this.keycloakEnabled;
|
||||
}
|
||||
|
||||
getConfig() {
|
||||
return {
|
||||
enabled: this.keycloakEnabled,
|
||||
url: this.keycloakUrl || undefined,
|
||||
realm: this.keycloakRealm || undefined,
|
||||
clientId: this.clientId || undefined,
|
||||
};
|
||||
}
|
||||
|
||||
generateLoginUrl(redirectTo?: string): { loginUrl: string; state: string } {
|
||||
if (!this.keycloakEnabled) {
|
||||
throw new BadRequestException('Keycloak is not enabled');
|
||||
}
|
||||
|
||||
const state = crypto.randomBytes(32).toString('hex');
|
||||
this.stateStore.set(state, {
|
||||
createdAt: Date.now(),
|
||||
redirectTo,
|
||||
});
|
||||
|
||||
const baseUrl = `${this.keycloakUrl}/realms/${this.keycloakRealm}/protocol/openid-connect/auth`;
|
||||
const params = new URLSearchParams({
|
||||
client_id: this.clientId,
|
||||
redirect_uri: this.redirectUri,
|
||||
response_type: 'code',
|
||||
scope: 'openid email profile',
|
||||
state,
|
||||
});
|
||||
|
||||
return {
|
||||
loginUrl: `${baseUrl}?${params.toString()}`,
|
||||
state,
|
||||
};
|
||||
}
|
||||
|
||||
async handleCallback(code: string, state: string): Promise<{
|
||||
authResponse: AuthResponseDto;
|
||||
redirectTo?: string;
|
||||
}> {
|
||||
if (!this.keycloakEnabled) {
|
||||
throw new BadRequestException('Keycloak is not enabled');
|
||||
}
|
||||
|
||||
// Validate state
|
||||
const stateData = this.stateStore.get(state);
|
||||
if (!stateData) {
|
||||
throw new UnauthorizedException('Invalid or expired state parameter');
|
||||
}
|
||||
this.stateStore.delete(state);
|
||||
|
||||
// Exchange code for tokens
|
||||
const tokens = await this.exchangeCodeForTokens(code);
|
||||
|
||||
// Get user info from Keycloak
|
||||
const userInfo = await this.getUserInfo(tokens.access_token);
|
||||
|
||||
// DEBUG: Log the full userinfo to see what groups are actually returned
|
||||
console.log('[Keycloak Debug] UserInfo:', JSON.stringify(userInfo, null, 2));
|
||||
|
||||
// Validate user info
|
||||
if (!userInfo.sub) {
|
||||
throw new UnauthorizedException('Invalid user info from Keycloak');
|
||||
}
|
||||
|
||||
// Check authorization requirements
|
||||
this.checkAuthorization(userInfo);
|
||||
|
||||
// Find or create user
|
||||
const user = await this.findOrCreateUser(userInfo);
|
||||
|
||||
if (!user.isActive) {
|
||||
throw new UnauthorizedException('Account is deactivated');
|
||||
}
|
||||
|
||||
// Generate DreamChat tokens
|
||||
const authResponse = this.generateTokens(user);
|
||||
|
||||
return {
|
||||
authResponse,
|
||||
redirectTo: stateData.redirectTo,
|
||||
};
|
||||
}
|
||||
|
||||
private async exchangeCodeForTokens(code: string): Promise<KeycloakTokenResponse> {
|
||||
const tokenUrl = `${this.keycloakUrl}/realms/${this.keycloakRealm}/protocol/openid-connect/token`;
|
||||
|
||||
const params = new URLSearchParams({
|
||||
grant_type: 'authorization_code',
|
||||
client_id: this.clientId,
|
||||
client_secret: this.clientSecret,
|
||||
code,
|
||||
redirect_uri: this.redirectUri,
|
||||
});
|
||||
|
||||
const response = await fetch(tokenUrl, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
},
|
||||
body: params.toString(),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const error = await response.text();
|
||||
throw new UnauthorizedException(`Failed to exchange code for tokens: ${error}`);
|
||||
}
|
||||
|
||||
return response.json();
|
||||
}
|
||||
|
||||
private async getUserInfo(accessToken: string): Promise<KeycloakUserInfo> {
|
||||
const userInfoUrl = `${this.keycloakUrl}/realms/${this.keycloakRealm}/protocol/openid-connect/userinfo`;
|
||||
|
||||
const response = await fetch(userInfoUrl, {
|
||||
headers: {
|
||||
Authorization: `Bearer ${accessToken}`,
|
||||
},
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new UnauthorizedException('Failed to get user info from Keycloak');
|
||||
}
|
||||
|
||||
return response.json();
|
||||
}
|
||||
|
||||
private async findOrCreateUser(userInfo: KeycloakUserInfo) {
|
||||
const keycloakSub = userInfo.sub;
|
||||
const email = userInfo.email;
|
||||
const username = userInfo.preferred_username || email || keycloakSub;
|
||||
|
||||
// Try to find user by keycloakSub first
|
||||
let user = await this.prisma.user.findUnique({
|
||||
where: { keycloakSub },
|
||||
});
|
||||
|
||||
if (user) {
|
||||
return user;
|
||||
}
|
||||
|
||||
// Try to find by email and link accounts
|
||||
if (email) {
|
||||
const existingUser = await this.prisma.user.findUnique({
|
||||
where: { email },
|
||||
});
|
||||
|
||||
if (existingUser) {
|
||||
// Link existing user to Keycloak
|
||||
return this.prisma.user.update({
|
||||
where: { id: existingUser.id },
|
||||
data: { keycloakSub },
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Check if auto-create is enabled
|
||||
const autoCreate = process.env.KEYCLOAK_AUTO_CREATE_USER !== 'false';
|
||||
if (!autoCreate) {
|
||||
throw new UnauthorizedException('User not found and auto-creation is disabled');
|
||||
}
|
||||
|
||||
// Create new user
|
||||
const defaultRole =
|
||||
process.env.KEYCLOAK_DEFAULT_USER_ROLE === 'ADMIN'
|
||||
? UserRole.ADMIN
|
||||
: UserRole.USER;
|
||||
|
||||
return this.prisma.user.create({
|
||||
data: {
|
||||
email: email || `${keycloakSub}@keycloak.local`,
|
||||
username: username,
|
||||
keycloakSub,
|
||||
role: defaultRole,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
private checkAuthorization(userInfo: KeycloakUserInfo): void {
|
||||
const requiredGroup = process.env.KEYCLOAK_REQUIRED_GROUP;
|
||||
const requiredRole = process.env.KEYCLOAK_REQUIRED_ROLE;
|
||||
const requiredClientRole = process.env.KEYCLOAK_REQUIRED_CLIENT_ROLE;
|
||||
const requiredAttribute = process.env.KEYCLOAK_REQUIRED_ATTRIBUTE;
|
||||
|
||||
// Collect all possible sources of groups/roles
|
||||
const groups = userInfo.groups || [];
|
||||
const realmRoles = userInfo.realm_access?.roles || [];
|
||||
const clientRoles = userInfo.resource_access?.[this.clientId]?.roles || [];
|
||||
|
||||
console.log('[Keycloak Debug] Authorization Check:');
|
||||
console.log(' Required Group:', requiredGroup);
|
||||
console.log(' User Groups:', groups);
|
||||
console.log(' Realm Roles:', realmRoles);
|
||||
console.log(' Client Roles:', clientRoles);
|
||||
|
||||
// Check required group - try groups array, then realm roles, then client roles
|
||||
if (requiredGroup) {
|
||||
// Check in groups claim (most common location)
|
||||
const hasGroup = groups.includes(requiredGroup);
|
||||
// Also check in realm roles (sometimes groups are mapped as roles)
|
||||
const hasGroupAsRole = realmRoles.includes(requiredGroup);
|
||||
// Also check in client roles
|
||||
const hasGroupAsClientRole = clientRoles.includes(requiredGroup);
|
||||
|
||||
if (!hasGroup && !hasGroupAsRole && !hasGroupAsClientRole) {
|
||||
throw new UnauthorizedException(
|
||||
`Access denied: required group '${requiredGroup}' not found. Your groups: [${groups.join(', ')}], realm roles: [${realmRoles.join(', ')}], client roles: [${clientRoles.join(', ')}]`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Check required realm role
|
||||
if (requiredRole) {
|
||||
if (!realmRoles.includes(requiredRole)) {
|
||||
throw new UnauthorizedException(
|
||||
`Access denied: required role '${requiredRole}' not found. Your roles: [${realmRoles.join(', ')}]`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Check required client role
|
||||
if (requiredClientRole) {
|
||||
if (!clientRoles.includes(requiredClientRole)) {
|
||||
throw new UnauthorizedException(
|
||||
`Access denied: required client role '${requiredClientRole}' not found. Your client roles: [${clientRoles.join(', ')}]`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Check required attribute
|
||||
if (requiredAttribute) {
|
||||
const [attrKey, attrValue] = requiredAttribute.split(':');
|
||||
if (userInfo[attrKey] !== attrValue) {
|
||||
throw new UnauthorizedException(
|
||||
`Access denied: required attribute '${attrKey}=${attrValue}' not found`,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private generateTokens(user: { id: string; email: string; username: string; role: string }): AuthResponseDto {
|
||||
const payload = { sub: user.id, email: user.email, role: user.role };
|
||||
|
||||
return {
|
||||
accessToken: this.jwtService.sign(payload),
|
||||
refreshToken: this.jwtService.sign(payload, { expiresIn: '7d' }),
|
||||
user: {
|
||||
id: user.id,
|
||||
email: user.email,
|
||||
username: user.username,
|
||||
role: user.role,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
private cleanupState(): void {
|
||||
const now = Date.now();
|
||||
const maxAge = 10 * 60 * 1000; // 10 minutes
|
||||
|
||||
for (const [state, data] of this.stateStore.entries()) {
|
||||
if (now - data.createdAt > maxAge) {
|
||||
this.stateStore.delete(state);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
37
apps/backend/src/auth/strategies/jwt.strategy.ts
Normal file
37
apps/backend/src/auth/strategies/jwt.strategy.ts
Normal file
@@ -0,0 +1,37 @@
|
||||
import { Injectable, UnauthorizedException } from '@nestjs/common';
|
||||
import { PassportStrategy } from '@nestjs/passport';
|
||||
import { ExtractJwt, Strategy } from 'passport-jwt';
|
||||
import { PrismaService } from '../../prisma/prisma.service';
|
||||
|
||||
interface JwtPayload {
|
||||
sub: string;
|
||||
email: string;
|
||||
role: string;
|
||||
}
|
||||
|
||||
@Injectable()
|
||||
export class JwtStrategy extends PassportStrategy(Strategy, 'jwt') {
|
||||
constructor(private prisma: PrismaService) {
|
||||
super({
|
||||
jwtFromRequest: ExtractJwt.fromAuthHeaderAsBearerToken(),
|
||||
ignoreExpiration: false,
|
||||
secretOrKey: process.env.JWT_SECRET || 'dev-jwt-secret-change-in-production',
|
||||
});
|
||||
}
|
||||
|
||||
async validate(payload: JwtPayload) {
|
||||
const user = await this.prisma.user.findUnique({
|
||||
where: { id: payload.sub },
|
||||
});
|
||||
|
||||
if (!user || !user.isActive) {
|
||||
throw new UnauthorizedException('User not found or inactive');
|
||||
}
|
||||
|
||||
return {
|
||||
userId: payload.sub,
|
||||
email: payload.email,
|
||||
role: payload.role,
|
||||
};
|
||||
}
|
||||
}
|
||||
191
apps/backend/src/auth/strategies/keycloak.strategy.ts
Normal file
191
apps/backend/src/auth/strategies/keycloak.strategy.ts
Normal file
@@ -0,0 +1,191 @@
|
||||
import { Injectable, UnauthorizedException } from '@nestjs/common';
|
||||
import { PassportStrategy } from '@nestjs/passport';
|
||||
import { Strategy, ExtractJwt } from 'passport-jwt';
|
||||
import { PrismaService } from '../../prisma/prisma.service';
|
||||
import { UserRole } from '@prisma/client';
|
||||
|
||||
interface KeycloakJwtPayload {
|
||||
sub: string;
|
||||
email?: string;
|
||||
preferred_username?: string;
|
||||
realm_access?: {
|
||||
roles?: string[];
|
||||
};
|
||||
resource_access?: {
|
||||
[key: string]: {
|
||||
roles?: string[];
|
||||
};
|
||||
};
|
||||
groups?: string[];
|
||||
[key: string]: any;
|
||||
}
|
||||
|
||||
@Injectable()
|
||||
export class KeycloakStrategy extends PassportStrategy(Strategy, 'keycloak') {
|
||||
private keycloakEnabled: boolean;
|
||||
|
||||
constructor(private prisma: PrismaService) {
|
||||
const keycloakEnabled = process.env.KEYCLOAK_ENABLED === 'true';
|
||||
const keycloakUrl = process.env.KEYCLOAK_URL || '';
|
||||
const keycloakRealm = process.env.KEYCLOAK_REALM || '';
|
||||
|
||||
super({
|
||||
jwtFromRequest: ExtractJwt.fromAuthHeaderAsBearerToken(),
|
||||
ignoreExpiration: false,
|
||||
secretOrKey: keycloakEnabled
|
||||
? undefined
|
||||
: 'keycloak-not-enabled-placeholder',
|
||||
secretOrKeyProvider: keycloakEnabled
|
||||
? async (request, rawJwtToken, done) => {
|
||||
try {
|
||||
// Fetch Keycloak realm public key
|
||||
const response = await fetch(
|
||||
`${keycloakUrl}/realms/${keycloakRealm}`,
|
||||
);
|
||||
const realmInfo = await response.json();
|
||||
const publicKey = realmInfo.public_key;
|
||||
if (!publicKey) {
|
||||
throw new Error('No public key found in Keycloak realm');
|
||||
}
|
||||
done(
|
||||
null,
|
||||
`-----BEGIN PUBLIC KEY-----\n${publicKey}\n-----END PUBLIC KEY-----`,
|
||||
);
|
||||
} catch (error) {
|
||||
done(error as Error, '');
|
||||
}
|
||||
}
|
||||
: undefined,
|
||||
algorithms: ['RS256'],
|
||||
});
|
||||
|
||||
this.keycloakEnabled = keycloakEnabled;
|
||||
}
|
||||
|
||||
async validate(payload: KeycloakJwtPayload): Promise<{
|
||||
userId: string;
|
||||
email: string;
|
||||
role: UserRole;
|
||||
keycloakSub: string;
|
||||
}> {
|
||||
if (!this.keycloakEnabled) {
|
||||
throw new UnauthorizedException('Keycloak is not enabled');
|
||||
}
|
||||
|
||||
const keycloakSub = payload.sub;
|
||||
const email = payload.email;
|
||||
const username = payload.preferred_username || email;
|
||||
|
||||
if (!keycloakSub) {
|
||||
throw new UnauthorizedException('Invalid Keycloak token');
|
||||
}
|
||||
|
||||
// Check authorization requirements
|
||||
this.checkAuthorization(payload);
|
||||
|
||||
// Find or create user
|
||||
let user = await this.prisma.user.findUnique({
|
||||
where: { keycloakSub },
|
||||
});
|
||||
|
||||
if (!user) {
|
||||
// Auto-create user if enabled
|
||||
const autoCreate = process.env.KEYCLOAK_AUTO_CREATE_USER !== 'false';
|
||||
if (!autoCreate) {
|
||||
throw new UnauthorizedException(
|
||||
'User not found and auto-creation is disabled',
|
||||
);
|
||||
}
|
||||
|
||||
// Check if email already exists
|
||||
if (email) {
|
||||
const existingUser = await this.prisma.user.findUnique({
|
||||
where: { email },
|
||||
});
|
||||
if (existingUser) {
|
||||
// Link existing user to Keycloak
|
||||
user = await this.prisma.user.update({
|
||||
where: { id: existingUser.id },
|
||||
data: { keycloakSub },
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (!user) {
|
||||
// Create new user
|
||||
const defaultRole =
|
||||
process.env.KEYCLOAK_DEFAULT_USER_ROLE === 'ADMIN'
|
||||
? UserRole.ADMIN
|
||||
: UserRole.USER;
|
||||
|
||||
user = await this.prisma.user.create({
|
||||
data: {
|
||||
email: email || `${keycloakSub}@keycloak.local`,
|
||||
username: username || keycloakSub,
|
||||
keycloakSub,
|
||||
role: defaultRole,
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (!user.isActive) {
|
||||
throw new UnauthorizedException('Account is deactivated');
|
||||
}
|
||||
|
||||
return {
|
||||
userId: user.id,
|
||||
email: user.email,
|
||||
role: user.role,
|
||||
keycloakSub,
|
||||
};
|
||||
}
|
||||
|
||||
private checkAuthorization(payload: KeycloakJwtPayload): void {
|
||||
const requiredGroup = process.env.KEYCLOAK_REQUIRED_GROUP;
|
||||
const requiredRole = process.env.KEYCLOAK_REQUIRED_ROLE;
|
||||
const requiredClientRole = process.env.KEYCLOAK_REQUIRED_CLIENT_ROLE;
|
||||
const requiredAttribute = process.env.KEYCLOAK_REQUIRED_ATTRIBUTE;
|
||||
|
||||
// Check required group
|
||||
if (requiredGroup) {
|
||||
const groups = payload.groups || [];
|
||||
if (!groups.includes(requiredGroup)) {
|
||||
throw new UnauthorizedException(
|
||||
`Access denied: required group '${requiredGroup}' not found`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Check required realm role
|
||||
if (requiredRole) {
|
||||
const roles = payload.realm_access?.roles || [];
|
||||
if (!roles.includes(requiredRole)) {
|
||||
throw new UnauthorizedException(
|
||||
`Access denied: required role '${requiredRole}' not found`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Check required client role
|
||||
if (requiredClientRole) {
|
||||
const clientId = process.env.KEYCLOAK_CLIENT_ID || '';
|
||||
const clientRoles = payload.resource_access?.[clientId]?.roles || [];
|
||||
if (!clientRoles.includes(requiredClientRole)) {
|
||||
throw new UnauthorizedException(
|
||||
`Access denied: required client role '${requiredClientRole}' not found`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Check required attribute
|
||||
if (requiredAttribute) {
|
||||
const [attrKey, attrValue] = requiredAttribute.split(':');
|
||||
if (payload[attrKey] !== attrValue) {
|
||||
throw new UnauthorizedException(
|
||||
`Access denied: required attribute '${attrKey}=${attrValue}' not found`,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
21
apps/backend/src/auth/strategies/local.strategy.ts
Normal file
21
apps/backend/src/auth/strategies/local.strategy.ts
Normal file
@@ -0,0 +1,21 @@
|
||||
import { Injectable, UnauthorizedException } from '@nestjs/common';
|
||||
import { PassportStrategy } from '@nestjs/passport';
|
||||
import { Strategy } from 'passport-local';
|
||||
import { AuthService } from '../auth.service';
|
||||
|
||||
@Injectable()
|
||||
export class LocalStrategy extends PassportStrategy(Strategy, 'local') {
|
||||
constructor(private authService: AuthService) {
|
||||
super({
|
||||
usernameField: 'email',
|
||||
});
|
||||
}
|
||||
|
||||
async validate(email: string, password: string): Promise<any> {
|
||||
const user = await this.authService.validateUser(email, password);
|
||||
if (!user) {
|
||||
throw new UnauthorizedException('Invalid credentials');
|
||||
}
|
||||
return user;
|
||||
}
|
||||
}
|
||||
85
apps/backend/src/character/character.controller.ts
Normal file
85
apps/backend/src/character/character.controller.ts
Normal file
@@ -0,0 +1,85 @@
|
||||
import {
|
||||
Controller,
|
||||
Get,
|
||||
Post,
|
||||
Put,
|
||||
Delete,
|
||||
Body,
|
||||
Param,
|
||||
} from '@nestjs/common';
|
||||
import { ApiTags, ApiOperation, ApiResponse, ApiBearerAuth, ApiParam } from '@nestjs/swagger';
|
||||
import { CharacterService } from './character.service';
|
||||
import { CreateCharacterDto, UpdateCharacterDto } from './dto/create-character.dto';
|
||||
import { CharacterResponseDto } from './dto/character-response.dto';
|
||||
import { CurrentUser } from '../common/decorators/current-user.decorator';
|
||||
import { Character } from '@prisma/client';
|
||||
|
||||
@ApiTags('characters')
|
||||
@ApiBearerAuth()
|
||||
@Controller('characters')
|
||||
export class CharacterController {
|
||||
constructor(private characterService: CharacterService) {}
|
||||
|
||||
@Post()
|
||||
@ApiOperation({ summary: 'Create a new character' })
|
||||
@ApiResponse({ status: 201, description: 'Character created', type: CharacterResponseDto })
|
||||
@ApiResponse({ status: 401, description: 'Unauthorized' })
|
||||
async create(
|
||||
@CurrentUser('userId') userId: string,
|
||||
@Body() createCharacterDto: CreateCharacterDto,
|
||||
): Promise<Character> {
|
||||
return this.characterService.create(userId, createCharacterDto);
|
||||
}
|
||||
|
||||
@Get()
|
||||
@ApiOperation({ summary: 'Get all characters for current user' })
|
||||
@ApiResponse({ status: 200, description: 'List of characters', type: [CharacterResponseDto] })
|
||||
@ApiResponse({ status: 401, description: 'Unauthorized' })
|
||||
async findAll(@CurrentUser('userId') userId: string): Promise<Character[]> {
|
||||
return this.characterService.findAllByUser(userId);
|
||||
}
|
||||
|
||||
@Get(':id')
|
||||
@ApiOperation({ summary: 'Get character by ID' })
|
||||
@ApiParam({ name: 'id', description: 'Character ID' })
|
||||
@ApiResponse({ status: 200, description: 'Character found', type: CharacterResponseDto })
|
||||
@ApiResponse({ status: 401, description: 'Unauthorized' })
|
||||
@ApiResponse({ status: 403, description: 'Access denied' })
|
||||
@ApiResponse({ status: 404, description: 'Character not found' })
|
||||
async findOne(
|
||||
@Param('id') id: string,
|
||||
@CurrentUser('userId') userId: string,
|
||||
): Promise<Character> {
|
||||
return this.characterService.findById(id, userId);
|
||||
}
|
||||
|
||||
@Put(':id')
|
||||
@ApiOperation({ summary: 'Update character' })
|
||||
@ApiParam({ name: 'id', description: 'Character ID' })
|
||||
@ApiResponse({ status: 200, description: 'Character updated', type: CharacterResponseDto })
|
||||
@ApiResponse({ status: 401, description: 'Unauthorized' })
|
||||
@ApiResponse({ status: 403, description: 'Access denied' })
|
||||
@ApiResponse({ status: 404, description: 'Character not found' })
|
||||
async update(
|
||||
@Param('id') id: string,
|
||||
@CurrentUser('userId') userId: string,
|
||||
@Body() updateCharacterDto: UpdateCharacterDto,
|
||||
): Promise<Character> {
|
||||
return this.characterService.update(id, userId, updateCharacterDto);
|
||||
}
|
||||
|
||||
@Delete(':id')
|
||||
@ApiOperation({ summary: 'Delete character' })
|
||||
@ApiParam({ name: 'id', description: 'Character ID' })
|
||||
@ApiResponse({ status: 200, description: 'Character deleted' })
|
||||
@ApiResponse({ status: 401, description: 'Unauthorized' })
|
||||
@ApiResponse({ status: 403, description: 'Access denied' })
|
||||
@ApiResponse({ status: 404, description: 'Character not found' })
|
||||
async delete(
|
||||
@Param('id') id: string,
|
||||
@CurrentUser('userId') userId: string,
|
||||
): Promise<{ message: string }> {
|
||||
await this.characterService.delete(id, userId);
|
||||
return { message: 'Character deleted successfully' };
|
||||
}
|
||||
}
|
||||
10
apps/backend/src/character/character.module.ts
Normal file
10
apps/backend/src/character/character.module.ts
Normal file
@@ -0,0 +1,10 @@
|
||||
import { Module } from '@nestjs/common';
|
||||
import { CharacterService } from './character.service';
|
||||
import { CharacterController } from './character.controller';
|
||||
|
||||
@Module({
|
||||
providers: [CharacterService],
|
||||
controllers: [CharacterController],
|
||||
exports: [CharacterService],
|
||||
})
|
||||
export class CharacterModule {}
|
||||
74
apps/backend/src/character/character.service.ts
Normal file
74
apps/backend/src/character/character.service.ts
Normal file
@@ -0,0 +1,74 @@
|
||||
import { Injectable, NotFoundException, ForbiddenException } from '@nestjs/common';
|
||||
import { PrismaService } from '../prisma/prisma.service';
|
||||
import { CreateCharacterDto, UpdateCharacterDto } from './dto/create-character.dto';
|
||||
import { Character } from '@prisma/client';
|
||||
|
||||
@Injectable()
|
||||
export class CharacterService {
|
||||
constructor(private prisma: PrismaService) {}
|
||||
|
||||
async create(userId: string, createCharacterDto: CreateCharacterDto): Promise<Character> {
|
||||
return this.prisma.character.create({
|
||||
data: {
|
||||
...createCharacterDto,
|
||||
userId,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
async findAllByUser(userId: string): Promise<Character[]> {
|
||||
return this.prisma.character.findMany({
|
||||
where: { userId },
|
||||
orderBy: { createdAt: 'desc' },
|
||||
});
|
||||
}
|
||||
|
||||
async findById(id: string, userId?: string): Promise<Character> {
|
||||
const character = await this.prisma.character.findUnique({
|
||||
where: { id },
|
||||
include: {
|
||||
knowledgeSources: true,
|
||||
},
|
||||
});
|
||||
|
||||
if (!character) {
|
||||
throw new NotFoundException('Character not found');
|
||||
}
|
||||
|
||||
// If userId is provided, check if user owns the character or if it's public
|
||||
if (userId && character.userId !== userId && !character.isPublic) {
|
||||
throw new ForbiddenException('You do not have access to this character');
|
||||
}
|
||||
|
||||
return character;
|
||||
}
|
||||
|
||||
async update(
|
||||
id: string,
|
||||
userId: string,
|
||||
updateCharacterDto: UpdateCharacterDto,
|
||||
): Promise<Character> {
|
||||
const character = await this.findById(id, userId);
|
||||
|
||||
if (character.userId !== userId) {
|
||||
throw new ForbiddenException('You can only update your own characters');
|
||||
}
|
||||
|
||||
return this.prisma.character.update({
|
||||
where: { id },
|
||||
data: updateCharacterDto,
|
||||
});
|
||||
}
|
||||
|
||||
async delete(id: string, userId: string): Promise<void> {
|
||||
const character = await this.findById(id, userId);
|
||||
|
||||
if (character.userId !== userId) {
|
||||
throw new ForbiddenException('You can only delete your own characters');
|
||||
}
|
||||
|
||||
await this.prisma.character.delete({
|
||||
where: { id },
|
||||
});
|
||||
}
|
||||
}
|
||||
38
apps/backend/src/character/dto/character-response.dto.ts
Normal file
38
apps/backend/src/character/dto/character-response.dto.ts
Normal file
@@ -0,0 +1,38 @@
|
||||
import { ApiProperty, ApiPropertyOptional } from '@nestjs/swagger';
|
||||
|
||||
export class CharacterResponseDto {
|
||||
@ApiProperty({ description: 'Character ID', example: '550e8400-e29b-41d4-a716-446655440000' })
|
||||
id: string;
|
||||
|
||||
@ApiProperty({ description: 'Character name', example: 'Alice the Explorer' })
|
||||
name: string;
|
||||
|
||||
@ApiPropertyOptional({ description: 'Avatar URL', example: 'https://example.com/avatar.jpg' })
|
||||
avatarUrl: string | null;
|
||||
|
||||
@ApiProperty({ description: 'Personality prompt', example: 'You are Alice, a curious explorer...' })
|
||||
personalityPrompt: string;
|
||||
|
||||
@ApiProperty({ description: 'Custom attributes', example: { age: 25, traits: ['curious'] } })
|
||||
attributes: Record<string, any>;
|
||||
|
||||
@ApiProperty({ description: 'Character configuration', example: {} })
|
||||
config: Record<string, any>;
|
||||
|
||||
@ApiProperty({ description: 'Whether character is public', example: false })
|
||||
isPublic: boolean;
|
||||
|
||||
@ApiProperty({ description: 'Creation date' })
|
||||
createdAt: Date;
|
||||
|
||||
@ApiProperty({ description: 'Last update date' })
|
||||
updatedAt: Date;
|
||||
|
||||
@ApiProperty({ description: 'User ID', example: '550e8400-e29b-41d4-a716-446655440000' })
|
||||
userId: string;
|
||||
}
|
||||
|
||||
export class CharacterListResponseDto {
|
||||
@ApiProperty({ description: 'List of characters', type: [CharacterResponseDto] })
|
||||
characters: CharacterResponseDto[];
|
||||
}
|
||||
68
apps/backend/src/character/dto/create-character.dto.ts
Normal file
68
apps/backend/src/character/dto/create-character.dto.ts
Normal file
@@ -0,0 +1,68 @@
|
||||
import { IsString, IsOptional, IsBoolean, IsObject, MinLength } from 'class-validator';
|
||||
import { ApiProperty, ApiPropertyOptional } from '@nestjs/swagger';
|
||||
|
||||
export class CreateCharacterDto {
|
||||
@ApiProperty({ description: 'Character name', example: 'Alice the Explorer' })
|
||||
@IsString()
|
||||
@MinLength(1)
|
||||
name: string;
|
||||
|
||||
@ApiPropertyOptional({ description: 'Avatar URL', example: 'https://example.com/avatar.jpg' })
|
||||
@IsOptional()
|
||||
@IsString()
|
||||
avatarUrl?: string;
|
||||
|
||||
@ApiProperty({ description: 'Personality prompt that guides AI responses', example: 'You are Alice, a curious and adventurous explorer...' })
|
||||
@IsString()
|
||||
@MinLength(10)
|
||||
personalityPrompt: string;
|
||||
|
||||
@ApiPropertyOptional({ description: 'Custom attributes (JSON)', example: { age: 25, traits: ['curious', 'brave'] } })
|
||||
@IsOptional()
|
||||
@IsObject()
|
||||
attributes?: Record<string, any>;
|
||||
|
||||
@ApiPropertyOptional({ description: 'Character configuration (JSON)', example: { voice: 'friendly' } })
|
||||
@IsOptional()
|
||||
@IsObject()
|
||||
config?: Record<string, any>;
|
||||
|
||||
@ApiPropertyOptional({ description: 'Whether the character is publicly visible', example: false })
|
||||
@IsOptional()
|
||||
@IsBoolean()
|
||||
isPublic?: boolean;
|
||||
}
|
||||
|
||||
export class UpdateCharacterDto {
|
||||
@ApiPropertyOptional({ description: 'Character name', example: 'Alice the Explorer' })
|
||||
@IsOptional()
|
||||
@IsString()
|
||||
@MinLength(1)
|
||||
name?: string;
|
||||
|
||||
@ApiPropertyOptional({ description: 'Avatar URL', example: 'https://example.com/avatar.jpg' })
|
||||
@IsOptional()
|
||||
@IsString()
|
||||
avatarUrl?: string;
|
||||
|
||||
@ApiPropertyOptional({ description: 'Personality prompt', example: 'You are Alice...' })
|
||||
@IsOptional()
|
||||
@IsString()
|
||||
@MinLength(10)
|
||||
personalityPrompt?: string;
|
||||
|
||||
@ApiPropertyOptional({ description: 'Custom attributes (JSON)' })
|
||||
@IsOptional()
|
||||
@IsObject()
|
||||
attributes?: Record<string, any>;
|
||||
|
||||
@ApiPropertyOptional({ description: 'Character configuration (JSON)' })
|
||||
@IsOptional()
|
||||
@IsObject()
|
||||
config?: Record<string, any>;
|
||||
|
||||
@ApiPropertyOptional({ description: 'Whether the character is publicly visible' })
|
||||
@IsOptional()
|
||||
@IsBoolean()
|
||||
isPublic?: boolean;
|
||||
}
|
||||
92
apps/backend/src/chat/chat.controller.ts
Normal file
92
apps/backend/src/chat/chat.controller.ts
Normal file
@@ -0,0 +1,92 @@
|
||||
import {
|
||||
Controller,
|
||||
Get,
|
||||
Post,
|
||||
Delete,
|
||||
Body,
|
||||
Param,
|
||||
HttpCode,
|
||||
HttpStatus,
|
||||
} from '@nestjs/common';
|
||||
import { ApiTags, ApiOperation, ApiResponse, ApiBearerAuth, ApiParam } from '@nestjs/swagger';
|
||||
import { ChatService } from './chat.service';
|
||||
import { CreateConversationDto, SendMessageDto } from './dto/chat.dto';
|
||||
import {
|
||||
ConversationResponseDto,
|
||||
ConversationWithMessagesResponseDto,
|
||||
SendMessageResponseDto
|
||||
} from './dto/conversation-response.dto';
|
||||
import { CurrentUser } from '../common/decorators/current-user.decorator';
|
||||
import { Conversation, Message } from '@prisma/client';
|
||||
|
||||
@ApiTags('conversations')
|
||||
@ApiBearerAuth()
|
||||
@Controller('conversations')
|
||||
export class ChatController {
|
||||
constructor(private chatService: ChatService) {}
|
||||
|
||||
@Post()
|
||||
@ApiOperation({ summary: 'Create a new conversation' })
|
||||
@ApiResponse({ status: 201, description: 'Conversation created', type: ConversationResponseDto })
|
||||
@ApiResponse({ status: 401, description: 'Unauthorized' })
|
||||
@ApiResponse({ status: 404, description: 'Character not found' })
|
||||
async createConversation(
|
||||
@CurrentUser('userId') userId: string,
|
||||
@Body() createConversationDto: CreateConversationDto,
|
||||
): Promise<Conversation> {
|
||||
return this.chatService.createConversation(userId, createConversationDto);
|
||||
}
|
||||
|
||||
@Get()
|
||||
@ApiOperation({ summary: 'Get all conversations for current user' })
|
||||
@ApiResponse({ status: 200, description: 'List of conversations', type: [ConversationResponseDto] })
|
||||
@ApiResponse({ status: 401, description: 'Unauthorized' })
|
||||
async getConversations(@CurrentUser('userId') userId: string): Promise<Conversation[]> {
|
||||
return this.chatService.findConversationsByUser(userId);
|
||||
}
|
||||
|
||||
@Get(':id')
|
||||
@ApiOperation({ summary: 'Get conversation by ID with messages' })
|
||||
@ApiParam({ name: 'id', description: 'Conversation ID' })
|
||||
@ApiResponse({ status: 200, description: 'Conversation found', type: ConversationWithMessagesResponseDto })
|
||||
@ApiResponse({ status: 401, description: 'Unauthorized' })
|
||||
@ApiResponse({ status: 403, description: 'Access denied' })
|
||||
@ApiResponse({ status: 404, description: 'Conversation not found' })
|
||||
async getConversation(
|
||||
@Param('id') id: string,
|
||||
@CurrentUser('userId') userId: string,
|
||||
): Promise<Conversation & { messages: Message[] }> {
|
||||
return this.chatService.findConversationById(id, userId);
|
||||
}
|
||||
|
||||
@Delete(':id')
|
||||
@ApiOperation({ summary: 'Delete conversation' })
|
||||
@ApiParam({ name: 'id', description: 'Conversation ID' })
|
||||
@ApiResponse({ status: 200, description: 'Conversation deleted' })
|
||||
@ApiResponse({ status: 401, description: 'Unauthorized' })
|
||||
@ApiResponse({ status: 403, description: 'Access denied' })
|
||||
@ApiResponse({ status: 404, description: 'Conversation not found' })
|
||||
async deleteConversation(
|
||||
@Param('id') id: string,
|
||||
@CurrentUser('userId') userId: string,
|
||||
): Promise<{ message: string }> {
|
||||
await this.chatService.deleteConversation(id, userId);
|
||||
return { message: 'Conversation deleted successfully' };
|
||||
}
|
||||
|
||||
@Post(':id/messages')
|
||||
@HttpCode(HttpStatus.OK)
|
||||
@ApiOperation({ summary: 'Send a message in a conversation' })
|
||||
@ApiParam({ name: 'id', description: 'Conversation ID' })
|
||||
@ApiResponse({ status: 200, description: 'Message sent', type: SendMessageResponseDto })
|
||||
@ApiResponse({ status: 401, description: 'Unauthorized' })
|
||||
@ApiResponse({ status: 403, description: 'Access denied' })
|
||||
@ApiResponse({ status: 404, description: 'Conversation not found' })
|
||||
async sendMessage(
|
||||
@Param('id') conversationId: string,
|
||||
@CurrentUser('userId') userId: string,
|
||||
@Body() sendMessageDto: SendMessageDto,
|
||||
): Promise<SendMessageResponseDto> {
|
||||
return this.chatService.sendMessage(conversationId, userId, sendMessageDto);
|
||||
}
|
||||
}
|
||||
143
apps/backend/src/chat/chat.gateway.ts
Normal file
143
apps/backend/src/chat/chat.gateway.ts
Normal file
@@ -0,0 +1,143 @@
|
||||
import {
|
||||
WebSocketGateway,
|
||||
WebSocketServer,
|
||||
SubscribeMessage,
|
||||
MessageBody,
|
||||
ConnectedSocket,
|
||||
OnGatewayConnection,
|
||||
OnGatewayDisconnect,
|
||||
} from '@nestjs/websockets';
|
||||
import { Server, Socket } from 'socket.io';
|
||||
import { UseGuards } from '@nestjs/common';
|
||||
import { ChatService } from './chat.service';
|
||||
import { JwtService } from '@nestjs/jwt';
|
||||
|
||||
interface AuthenticatedSocket extends Socket {
|
||||
userId?: string;
|
||||
}
|
||||
|
||||
@WebSocketGateway({
|
||||
cors: {
|
||||
origin: ['http://localhost:5173'],
|
||||
credentials: true,
|
||||
},
|
||||
namespace: '/chat',
|
||||
})
|
||||
export class ChatGateway implements OnGatewayConnection, OnGatewayDisconnect {
|
||||
@WebSocketServer()
|
||||
server: Server;
|
||||
|
||||
constructor(
|
||||
private chatService: ChatService,
|
||||
private jwtService: JwtService,
|
||||
) {}
|
||||
|
||||
async handleConnection(client: AuthenticatedSocket) {
|
||||
try {
|
||||
const token = client.handshake.auth.token as string;
|
||||
if (!token) {
|
||||
client.disconnect();
|
||||
return;
|
||||
}
|
||||
|
||||
const payload = this.jwtService.verify(token.replace('Bearer ', ''));
|
||||
client.userId = payload.sub;
|
||||
|
||||
console.log(`Client connected: ${client.id}, user: ${client.userId}`);
|
||||
} catch {
|
||||
client.disconnect();
|
||||
}
|
||||
}
|
||||
|
||||
handleDisconnect(client: AuthenticatedSocket) {
|
||||
console.log(`Client disconnected: ${client.id}`);
|
||||
}
|
||||
|
||||
@SubscribeMessage('join_conversation')
|
||||
async handleJoinConversation(
|
||||
@MessageBody() data: { conversationId: string },
|
||||
@ConnectedSocket() client: AuthenticatedSocket,
|
||||
) {
|
||||
if (!client.userId) return;
|
||||
|
||||
const room = `conversation:${data.conversationId}`;
|
||||
await client.join(room);
|
||||
client.emit('joined', { conversationId: data.conversationId });
|
||||
}
|
||||
|
||||
@SubscribeMessage('leave_conversation')
|
||||
async handleLeaveConversation(
|
||||
@MessageBody() data: { conversationId: string },
|
||||
@ConnectedSocket() client: AuthenticatedSocket,
|
||||
) {
|
||||
const room = `conversation:${data.conversationId}`;
|
||||
await client.leave(room);
|
||||
client.emit('left', { conversationId: data.conversationId });
|
||||
}
|
||||
|
||||
@SubscribeMessage('send_message')
|
||||
async handleSendMessage(
|
||||
@MessageBody()
|
||||
data: { conversationId: string; content: string },
|
||||
@ConnectedSocket() client: AuthenticatedSocket,
|
||||
) {
|
||||
if (!client.userId) return;
|
||||
|
||||
const room = `conversation:${data.conversationId}`;
|
||||
|
||||
try {
|
||||
// Stream the response
|
||||
const stream = this.chatService.streamMessage(
|
||||
data.conversationId,
|
||||
client.userId,
|
||||
{ content: data.content },
|
||||
);
|
||||
|
||||
let assistantMessage: any = null;
|
||||
|
||||
for await (const event of stream) {
|
||||
if (event.type === 'chunk') {
|
||||
// Broadcast chunk to all clients in the room
|
||||
this.server.to(room).emit('message_chunk', {
|
||||
conversationId: data.conversationId,
|
||||
chunk: event.data,
|
||||
});
|
||||
} else if (event.type === 'message') {
|
||||
if (event.data.assistantMessage) {
|
||||
assistantMessage = event.data.assistantMessage;
|
||||
}
|
||||
// Broadcast the full message
|
||||
this.server.to(room).emit('message', {
|
||||
conversationId: data.conversationId,
|
||||
message: event.data,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Signal completion
|
||||
this.server.to(room).emit('message_complete', {
|
||||
conversationId: data.conversationId,
|
||||
assistantMessage,
|
||||
});
|
||||
} catch (error) {
|
||||
client.emit('error', {
|
||||
conversationId: data.conversationId,
|
||||
message: error instanceof Error ? error.message : 'Unknown error',
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@SubscribeMessage('typing')
|
||||
async handleTyping(
|
||||
@MessageBody() data: { conversationId: string; isTyping: boolean },
|
||||
@ConnectedSocket() client: AuthenticatedSocket,
|
||||
) {
|
||||
if (!client.userId) return;
|
||||
|
||||
const room = `conversation:${data.conversationId}`;
|
||||
client.to(room).emit('user_typing', {
|
||||
conversationId: data.conversationId,
|
||||
isTyping: data.isTyping,
|
||||
});
|
||||
}
|
||||
}
|
||||
16
apps/backend/src/chat/chat.module.ts
Normal file
16
apps/backend/src/chat/chat.module.ts
Normal file
@@ -0,0 +1,16 @@
|
||||
import { Module } from '@nestjs/common';
|
||||
import { JwtModule } from '@nestjs/jwt';
|
||||
import { ChatService } from './chat.service';
|
||||
import { ChatController } from './chat.controller';
|
||||
import { ChatGateway } from './chat.gateway';
|
||||
import { LLMModule } from '../llm/llm.module';
|
||||
import { VectorModule } from '../vector/vector.module';
|
||||
import { CharacterModule } from '../character/character.module';
|
||||
|
||||
@Module({
|
||||
imports: [LLMModule, VectorModule, CharacterModule, JwtModule],
|
||||
providers: [ChatService, ChatGateway],
|
||||
controllers: [ChatController],
|
||||
exports: [ChatService],
|
||||
})
|
||||
export class ChatModule {}
|
||||
288
apps/backend/src/chat/chat.service.ts
Normal file
288
apps/backend/src/chat/chat.service.ts
Normal file
@@ -0,0 +1,288 @@
|
||||
import { Injectable, NotFoundException, ForbiddenException } from '@nestjs/common';
|
||||
import { PrismaService } from '../prisma/prisma.service';
|
||||
import { LLMService } from '../llm/llm.service';
|
||||
import { MemoryService } from '../vector/memory.service';
|
||||
import { CharacterService } from '../character/character.service';
|
||||
import { CreateConversationDto, SendMessageDto } from './dto/chat.dto';
|
||||
import { Conversation, Message, MessageRole } from '@prisma/client';
|
||||
|
||||
@Injectable()
|
||||
export class ChatService {
|
||||
constructor(
|
||||
private prisma: PrismaService,
|
||||
private llmService: LLMService,
|
||||
private memoryService: MemoryService,
|
||||
private characterService: CharacterService,
|
||||
) {}
|
||||
|
||||
async createConversation(
|
||||
userId: string,
|
||||
createConversationDto: CreateConversationDto,
|
||||
): Promise<Conversation> {
|
||||
// Verify character exists and user has access
|
||||
const character = await this.characterService.findById(
|
||||
createConversationDto.characterId,
|
||||
userId,
|
||||
);
|
||||
|
||||
return this.prisma.conversation.create({
|
||||
data: {
|
||||
userId,
|
||||
characterId: createConversationDto.characterId,
|
||||
title: createConversationDto.title || `Chat with ${character.name}`,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
async findConversationsByUser(userId: string): Promise<Conversation[]> {
|
||||
return this.prisma.conversation.findMany({
|
||||
where: { userId },
|
||||
include: {
|
||||
character: {
|
||||
select: {
|
||||
id: true,
|
||||
name: true,
|
||||
avatarUrl: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
orderBy: { updatedAt: 'desc' },
|
||||
});
|
||||
}
|
||||
|
||||
async findConversationById(
|
||||
id: string,
|
||||
userId: string,
|
||||
): Promise<Conversation & { messages: Message[]; character: { id: string; name: string; personalityPrompt: string } }> {
|
||||
const conversation = await this.prisma.conversation.findUnique({
|
||||
where: { id },
|
||||
include: {
|
||||
messages: {
|
||||
orderBy: { createdAt: 'asc' },
|
||||
},
|
||||
character: {
|
||||
select: {
|
||||
id: true,
|
||||
name: true,
|
||||
personalityPrompt: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
if (!conversation) {
|
||||
throw new NotFoundException('Conversation not found');
|
||||
}
|
||||
|
||||
if (conversation.userId !== userId) {
|
||||
throw new ForbiddenException('You do not have access to this conversation');
|
||||
}
|
||||
|
||||
return conversation;
|
||||
}
|
||||
|
||||
async deleteConversation(id: string, userId: string): Promise<void> {
|
||||
const conversation = await this.findConversationById(id, userId);
|
||||
|
||||
if (conversation.userId !== userId) {
|
||||
throw new ForbiddenException('You can only delete your own conversations');
|
||||
}
|
||||
|
||||
// Delete related vector memories first
|
||||
await this.memoryService['vectorStore'].deleteByConversation(id);
|
||||
|
||||
await this.prisma.conversation.delete({
|
||||
where: { id },
|
||||
});
|
||||
}
|
||||
|
||||
async sendMessage(
|
||||
conversationId: string,
|
||||
userId: string,
|
||||
sendMessageDto: SendMessageDto,
|
||||
): Promise<{ userMessage: Message; assistantMessage: Message }> {
|
||||
const conversation = await this.findConversationById(conversationId, userId);
|
||||
|
||||
// Create user message
|
||||
const userMessage = await this.prisma.message.create({
|
||||
data: {
|
||||
conversationId,
|
||||
role: 'user',
|
||||
content: sendMessageDto.content,
|
||||
},
|
||||
});
|
||||
|
||||
// Store user message in vector memory
|
||||
await this.memoryService.storeConversationMessage(
|
||||
`User: ${sendMessageDto.content}`,
|
||||
conversationId,
|
||||
{ messageId: userMessage.id },
|
||||
);
|
||||
|
||||
// Generate context from memory
|
||||
const memoryContext = await this.memoryService.buildContextForConversation(
|
||||
conversationId,
|
||||
sendMessageDto.content,
|
||||
conversation.characterId,
|
||||
);
|
||||
|
||||
// Build messages for LLM
|
||||
const messages = this.buildLLMMessages(
|
||||
conversation.character.personalityPrompt,
|
||||
conversation.messages,
|
||||
sendMessageDto.content,
|
||||
memoryContext,
|
||||
);
|
||||
|
||||
// Generate response
|
||||
const response = await this.llmService.generateCompletion(messages, {
|
||||
temperature: 0.7,
|
||||
maxTokens: 2000,
|
||||
});
|
||||
|
||||
// Create assistant message
|
||||
const assistantMessage = await this.prisma.message.create({
|
||||
data: {
|
||||
conversationId,
|
||||
role: 'assistant',
|
||||
content: response.content,
|
||||
tokensUsed: response.tokensUsed,
|
||||
model: response.model,
|
||||
},
|
||||
});
|
||||
|
||||
// Update conversation stats
|
||||
await this.prisma.conversation.update({
|
||||
where: { id: conversationId },
|
||||
data: {
|
||||
messageCount: { increment: 2 },
|
||||
totalTokens: { increment: response.tokensUsed },
|
||||
},
|
||||
});
|
||||
|
||||
// Store assistant response in vector memory
|
||||
await this.memoryService.storeConversationMessage(
|
||||
`${conversation.character.name}: ${response.content}`,
|
||||
conversationId,
|
||||
{ messageId: assistantMessage.id },
|
||||
);
|
||||
|
||||
return { userMessage, assistantMessage };
|
||||
}
|
||||
|
||||
async *streamMessage(
|
||||
conversationId: string,
|
||||
userId: string,
|
||||
sendMessageDto: SendMessageDto,
|
||||
): AsyncGenerator<{ type: 'chunk' | 'message'; data: any }> {
|
||||
const conversation = await this.findConversationById(conversationId, userId);
|
||||
|
||||
// Create user message
|
||||
const userMessage = await this.prisma.message.create({
|
||||
data: {
|
||||
conversationId,
|
||||
role: 'user',
|
||||
content: sendMessageDto.content,
|
||||
},
|
||||
});
|
||||
|
||||
yield { type: 'message', data: { userMessage } };
|
||||
|
||||
// Store user message in vector memory
|
||||
await this.memoryService.storeConversationMessage(
|
||||
`User: ${sendMessageDto.content}`,
|
||||
conversationId,
|
||||
{ messageId: userMessage.id },
|
||||
);
|
||||
|
||||
// Generate context from memory
|
||||
const memoryContext = await this.memoryService.buildContextForConversation(
|
||||
conversationId,
|
||||
sendMessageDto.content,
|
||||
conversation.characterId,
|
||||
);
|
||||
|
||||
// Build messages for LLM
|
||||
const messages = this.buildLLMMessages(
|
||||
conversation.character.personalityPrompt,
|
||||
conversation.messages,
|
||||
sendMessageDto.content,
|
||||
memoryContext,
|
||||
);
|
||||
|
||||
// Generate streaming response
|
||||
let fullContent = '';
|
||||
const stream = this.llmService.generateStream(messages, {
|
||||
temperature: 0.7,
|
||||
maxTokens: 2000,
|
||||
});
|
||||
|
||||
for await (const chunk of stream) {
|
||||
fullContent += chunk.content;
|
||||
yield { type: 'chunk', data: chunk };
|
||||
}
|
||||
|
||||
// Create assistant message
|
||||
const assistantMessage = await this.prisma.message.create({
|
||||
data: {
|
||||
conversationId,
|
||||
role: 'assistant',
|
||||
content: fullContent,
|
||||
model: process.env.LLM_MODEL || 'openai/gpt-4o',
|
||||
},
|
||||
});
|
||||
|
||||
// Update conversation stats
|
||||
const tokensUsed = this.llmService.countTokens([
|
||||
...messages,
|
||||
{ role: 'assistant', content: fullContent },
|
||||
]);
|
||||
|
||||
await this.prisma.conversation.update({
|
||||
where: { id: conversationId },
|
||||
data: {
|
||||
messageCount: { increment: 2 },
|
||||
totalTokens: { increment: tokensUsed },
|
||||
},
|
||||
});
|
||||
|
||||
// Store assistant response in vector memory
|
||||
await this.memoryService.storeConversationMessage(
|
||||
`${conversation.character.name}: ${fullContent}`,
|
||||
conversationId,
|
||||
{ messageId: assistantMessage.id },
|
||||
);
|
||||
|
||||
yield { type: 'message', data: { assistantMessage } };
|
||||
}
|
||||
|
||||
private buildLLMMessages(
|
||||
personalityPrompt: string | null,
|
||||
history: Message[],
|
||||
currentMessage: string,
|
||||
memoryContext: string,
|
||||
): Array<{ role: 'system' | 'user' | 'assistant'; content: string }> {
|
||||
const messages: Array<{ role: 'system' | 'user' | 'assistant'; content: string }> = [];
|
||||
|
||||
// Add system message with personality and context
|
||||
let systemContent = personalityPrompt;
|
||||
if (memoryContext) {
|
||||
systemContent += `\n\nUse the following context to inform your responses:\n${memoryContext}`;
|
||||
}
|
||||
messages.push({ role: 'system', content: systemContent });
|
||||
|
||||
// Add recent conversation history (last 10 messages)
|
||||
const recentHistory = history.slice(-10);
|
||||
for (const message of recentHistory) {
|
||||
messages.push({
|
||||
role: message.role.toLowerCase() as 'user' | 'assistant',
|
||||
content: message.content,
|
||||
});
|
||||
}
|
||||
|
||||
// Add current message
|
||||
messages.push({ role: 'user', content: currentMessage });
|
||||
|
||||
return messages;
|
||||
}
|
||||
}
|
||||
25
apps/backend/src/chat/dto/chat.dto.ts
Normal file
25
apps/backend/src/chat/dto/chat.dto.ts
Normal file
@@ -0,0 +1,25 @@
|
||||
import { IsString, IsOptional } from 'class-validator';
|
||||
import { ApiProperty, ApiPropertyOptional } from '@nestjs/swagger';
|
||||
|
||||
export class CreateConversationDto {
|
||||
@ApiProperty({ description: 'Character ID to chat with', example: '550e8400-e29b-41d4-a716-446655440000' })
|
||||
@IsString()
|
||||
characterId: string;
|
||||
|
||||
@ApiPropertyOptional({ description: 'Conversation title', example: 'My Adventure with Alice' })
|
||||
@IsOptional()
|
||||
@IsString()
|
||||
title?: string;
|
||||
}
|
||||
|
||||
export class SendMessageDto {
|
||||
@ApiProperty({ description: 'Message content', example: 'Hello! Tell me about yourself.' })
|
||||
@IsString()
|
||||
content: string;
|
||||
}
|
||||
|
||||
export class UpdateMessageDto {
|
||||
@ApiProperty({ description: 'Updated message content' })
|
||||
@IsString()
|
||||
content: string;
|
||||
}
|
||||
72
apps/backend/src/chat/dto/conversation-response.dto.ts
Normal file
72
apps/backend/src/chat/dto/conversation-response.dto.ts
Normal file
@@ -0,0 +1,72 @@
|
||||
import { ApiProperty, ApiPropertyOptional } from '@nestjs/swagger';
|
||||
import { MessageRole } from '@prisma/client';
|
||||
|
||||
export class MessageResponseDto {
|
||||
@ApiProperty({ description: 'Message ID' })
|
||||
id: string;
|
||||
|
||||
@ApiProperty({ description: 'Message role', enum: MessageRole })
|
||||
role: MessageRole;
|
||||
|
||||
@ApiProperty({ description: 'Message content' })
|
||||
content: string;
|
||||
|
||||
@ApiPropertyOptional({ description: 'Tokens used' })
|
||||
tokensUsed: number | null;
|
||||
|
||||
@ApiPropertyOptional({ description: 'Model used' })
|
||||
model: string | null;
|
||||
|
||||
@ApiProperty({ description: 'Creation date' })
|
||||
createdAt: Date;
|
||||
}
|
||||
|
||||
export class CharacterSummaryDto {
|
||||
@ApiProperty({ description: 'Character ID' })
|
||||
id: string;
|
||||
|
||||
@ApiProperty({ description: 'Character name' })
|
||||
name: string;
|
||||
|
||||
@ApiPropertyOptional({ description: 'Avatar URL' })
|
||||
avatarUrl: string | null;
|
||||
}
|
||||
|
||||
export class ConversationResponseDto {
|
||||
@ApiProperty({ description: 'Conversation ID' })
|
||||
id: string;
|
||||
|
||||
@ApiPropertyOptional({ description: 'Conversation title' })
|
||||
title: string | null;
|
||||
|
||||
@ApiProperty({ description: 'Character ID' })
|
||||
characterId: string;
|
||||
|
||||
@ApiProperty({ description: 'Number of messages' })
|
||||
messageCount: number;
|
||||
|
||||
@ApiProperty({ description: 'Total tokens used' })
|
||||
totalTokens: number;
|
||||
|
||||
@ApiProperty({ description: 'Creation date' })
|
||||
createdAt: Date;
|
||||
|
||||
@ApiProperty({ description: 'Last update date' })
|
||||
updatedAt: Date;
|
||||
|
||||
@ApiPropertyOptional({ description: 'Character info', type: CharacterSummaryDto })
|
||||
character?: CharacterSummaryDto;
|
||||
}
|
||||
|
||||
export class ConversationWithMessagesResponseDto extends ConversationResponseDto {
|
||||
@ApiProperty({ description: 'Messages in conversation', type: [MessageResponseDto] })
|
||||
messages: MessageResponseDto[];
|
||||
}
|
||||
|
||||
export class SendMessageResponseDto {
|
||||
@ApiProperty({ description: 'User message', type: MessageResponseDto })
|
||||
userMessage: MessageResponseDto;
|
||||
|
||||
@ApiProperty({ description: 'Assistant response', type: MessageResponseDto })
|
||||
assistantMessage: MessageResponseDto;
|
||||
}
|
||||
10
apps/backend/src/common/decorators/current-user.decorator.ts
Normal file
10
apps/backend/src/common/decorators/current-user.decorator.ts
Normal file
@@ -0,0 +1,10 @@
|
||||
import { createParamDecorator, ExecutionContext } from '@nestjs/common';
|
||||
|
||||
export const CurrentUser = createParamDecorator(
|
||||
(data: keyof any | undefined, ctx: ExecutionContext) => {
|
||||
const request = ctx.switchToHttp().getRequest();
|
||||
const user = request.user;
|
||||
|
||||
return data ? user?.[data] : user;
|
||||
},
|
||||
);
|
||||
4
apps/backend/src/common/decorators/public.decorator.ts
Normal file
4
apps/backend/src/common/decorators/public.decorator.ts
Normal file
@@ -0,0 +1,4 @@
|
||||
import { SetMetadata } from '@nestjs/common';
|
||||
|
||||
export const IS_PUBLIC_KEY = 'isPublic';
|
||||
export const Public = () => SetMetadata(IS_PUBLIC_KEY, true);
|
||||
40
apps/backend/src/import/adapters/text-file.adapter.ts
Normal file
40
apps/backend/src/import/adapters/text-file.adapter.ts
Normal file
@@ -0,0 +1,40 @@
|
||||
import { IImportAdapter, ImportResult } from '../interfaces/import-adapter.interface';
|
||||
|
||||
export class TextFileAdapter implements IImportAdapter {
|
||||
private readonly supportedMimeTypes = [
|
||||
'text/plain',
|
||||
'text/markdown',
|
||||
'text/x-markdown',
|
||||
'application/octet-stream',
|
||||
];
|
||||
|
||||
private readonly supportedExtensions = ['.txt', '.md', '.markdown'];
|
||||
|
||||
canHandle(file: Express.Multer.File): boolean {
|
||||
// Check MIME type
|
||||
if (this.supportedMimeTypes.includes(file.mimetype)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Check file extension
|
||||
const ext = file.originalname.toLowerCase().slice(file.originalname.lastIndexOf('.'));
|
||||
if (this.supportedExtensions.includes(ext)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
async parse(file: Express.Multer.File): Promise<ImportResult> {
|
||||
const content = file.buffer.toString('utf-8');
|
||||
|
||||
return {
|
||||
content,
|
||||
metadata: {
|
||||
sourceName: file.originalname,
|
||||
mimeType: file.mimetype,
|
||||
fileSize: file.size,
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
101
apps/backend/src/import/import.controller.ts
Normal file
101
apps/backend/src/import/import.controller.ts
Normal file
@@ -0,0 +1,101 @@
|
||||
import {
|
||||
Controller,
|
||||
Post,
|
||||
Get,
|
||||
Delete,
|
||||
Param,
|
||||
UploadedFile,
|
||||
UseInterceptors,
|
||||
BadRequestException,
|
||||
} from '@nestjs/common';
|
||||
import { FileInterceptor } from '@nestjs/platform-express';
|
||||
import { ApiTags, ApiOperation, ApiResponse, ApiBearerAuth, ApiParam, ApiConsumes, ApiBody, ApiProperty } from '@nestjs/swagger';
|
||||
import { ImportService } from './import.service';
|
||||
import { CurrentUser } from '../common/decorators/current-user.decorator';
|
||||
import { CharacterKnowledge } from '@prisma/client';
|
||||
|
||||
class UploadResponseDto {
|
||||
@ApiProperty({ description: 'Knowledge ID' })
|
||||
knowledgeId: string;
|
||||
|
||||
@ApiProperty({ description: 'Status message' })
|
||||
message: string;
|
||||
}
|
||||
|
||||
@ApiTags('import')
|
||||
@ApiBearerAuth()
|
||||
@Controller('import')
|
||||
export class ImportController {
|
||||
constructor(private importService: ImportService) {}
|
||||
|
||||
@Post('characters/:characterId/files')
|
||||
@ApiOperation({ summary: 'Upload a file for character knowledge' })
|
||||
@ApiParam({ name: 'characterId', description: 'Character ID' })
|
||||
@ApiConsumes('multipart/form-data')
|
||||
@ApiBody({
|
||||
schema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
file: {
|
||||
type: 'string',
|
||||
format: 'binary',
|
||||
description: 'File to upload (.txt, .md)',
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
@ApiResponse({ status: 201, description: 'File uploaded and processing' })
|
||||
@ApiResponse({ status: 400, description: 'Invalid file type' })
|
||||
@ApiResponse({ status: 401, description: 'Unauthorized' })
|
||||
@UseInterceptors(FileInterceptor('file'))
|
||||
async uploadFile(
|
||||
@Param('characterId') characterId: string,
|
||||
@UploadedFile() file: Express.Multer.File,
|
||||
@CurrentUser('userId') userId: string,
|
||||
): Promise<UploadResponseDto> {
|
||||
if (!file) {
|
||||
throw new BadRequestException('No file uploaded');
|
||||
}
|
||||
|
||||
return this.importService.uploadFile(file, characterId, userId);
|
||||
}
|
||||
|
||||
@Get('knowledge/:knowledgeId/status')
|
||||
@ApiOperation({ summary: 'Get knowledge processing status' })
|
||||
@ApiParam({ name: 'knowledgeId', description: 'Knowledge ID' })
|
||||
@ApiResponse({ status: 200, description: 'Knowledge status' })
|
||||
@ApiResponse({ status: 401, description: 'Unauthorized' })
|
||||
@ApiResponse({ status: 404, description: 'Knowledge not found' })
|
||||
async getKnowledgeStatus(
|
||||
@Param('knowledgeId') knowledgeId: string,
|
||||
@CurrentUser('userId') userId: string,
|
||||
): Promise<CharacterKnowledge> {
|
||||
return this.importService.getKnowledgeStatus(knowledgeId, userId);
|
||||
}
|
||||
|
||||
@Get('characters/:characterId/knowledge')
|
||||
@ApiOperation({ summary: 'Get all knowledge for a character' })
|
||||
@ApiParam({ name: 'characterId', description: 'Character ID' })
|
||||
@ApiResponse({ status: 200, description: 'List of knowledge' })
|
||||
@ApiResponse({ status: 401, description: 'Unauthorized' })
|
||||
async getCharacterKnowledge(
|
||||
@Param('characterId') characterId: string,
|
||||
@CurrentUser('userId') userId: string,
|
||||
): Promise<CharacterKnowledge[]> {
|
||||
return this.importService.getCharacterKnowledge(characterId, userId);
|
||||
}
|
||||
|
||||
@Delete('knowledge/:knowledgeId')
|
||||
@ApiOperation({ summary: 'Delete knowledge' })
|
||||
@ApiParam({ name: 'knowledgeId', description: 'Knowledge ID' })
|
||||
@ApiResponse({ status: 200, description: 'Knowledge deleted' })
|
||||
@ApiResponse({ status: 401, description: 'Unauthorized' })
|
||||
@ApiResponse({ status: 404, description: 'Knowledge not found' })
|
||||
async deleteKnowledge(
|
||||
@Param('knowledgeId') knowledgeId: string,
|
||||
@CurrentUser('userId') userId: string,
|
||||
): Promise<{ message: string }> {
|
||||
await this.importService.deleteKnowledge(knowledgeId, userId);
|
||||
return { message: 'Knowledge deleted successfully' };
|
||||
}
|
||||
}
|
||||
12
apps/backend/src/import/import.module.ts
Normal file
12
apps/backend/src/import/import.module.ts
Normal file
@@ -0,0 +1,12 @@
|
||||
import { Module } from '@nestjs/common';
|
||||
import { ImportService } from './import.service';
|
||||
import { ImportController } from './import.controller';
|
||||
import { VectorModule } from '../vector/vector.module';
|
||||
|
||||
@Module({
|
||||
imports: [VectorModule],
|
||||
providers: [ImportService],
|
||||
controllers: [ImportController],
|
||||
exports: [ImportService],
|
||||
})
|
||||
export class ImportModule {}
|
||||
214
apps/backend/src/import/import.service.ts
Normal file
214
apps/backend/src/import/import.service.ts
Normal file
@@ -0,0 +1,214 @@
|
||||
import { Injectable, BadRequestException } from '@nestjs/common';
|
||||
import { PrismaService } from '../prisma/prisma.service';
|
||||
import { MemoryService } from '../vector/memory.service';
|
||||
import { TextFileAdapter } from './adapters/text-file.adapter';
|
||||
import { IImportAdapter, ImportResult } from './interfaces/import-adapter.interface';
|
||||
import { ImportStatus } from '@prisma/client';
|
||||
|
||||
@Injectable()
|
||||
export class ImportService {
|
||||
private adapters: IImportAdapter[];
|
||||
|
||||
constructor(
|
||||
private prisma: PrismaService,
|
||||
private memoryService: MemoryService,
|
||||
) {
|
||||
this.adapters = [new TextFileAdapter()];
|
||||
}
|
||||
|
||||
async uploadFile(
|
||||
file: Express.Multer.File,
|
||||
characterId: string,
|
||||
userId: string,
|
||||
): Promise<{ knowledgeId: string; message: string }> {
|
||||
// Find appropriate adapter
|
||||
const adapter = this.adapters.find((a) => a.canHandle(file));
|
||||
|
||||
if (!adapter) {
|
||||
throw new BadRequestException(
|
||||
`Unsupported file type: ${file.mimetype}. Supported types: .txt, .md`,
|
||||
);
|
||||
}
|
||||
|
||||
// Parse the file
|
||||
const result = await adapter.parse(file);
|
||||
|
||||
// Create knowledge entry
|
||||
const knowledge = await this.prisma.characterKnowledge.create({
|
||||
data: {
|
||||
name: file.originalname,
|
||||
sourceType: 'file',
|
||||
sourceName: file.originalname,
|
||||
mimeType: file.mimetype,
|
||||
fileSize: BigInt(file.size),
|
||||
rawContent: result.content,
|
||||
status: 'processing',
|
||||
processingInfo: result.metadata,
|
||||
characterId,
|
||||
},
|
||||
});
|
||||
|
||||
// Process the content in the background
|
||||
this.processContent(knowledge.id, characterId, result).catch((error) => {
|
||||
console.error('Error processing import:', error);
|
||||
this.prisma.characterKnowledge.update({
|
||||
where: { id: knowledge.id },
|
||||
data: {
|
||||
status: 'failed',
|
||||
processingInfo: {
|
||||
...result.metadata,
|
||||
error: error instanceof Error ? error.message : 'Unknown error',
|
||||
},
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
return {
|
||||
knowledgeId: knowledge.id,
|
||||
message: 'File uploaded and is being processed',
|
||||
};
|
||||
}
|
||||
|
||||
async getKnowledgeStatus(knowledgeId: string, userId: string) {
|
||||
const knowledge = await this.prisma.characterKnowledge.findFirst({
|
||||
where: {
|
||||
id: knowledgeId,
|
||||
character: {
|
||||
userId,
|
||||
},
|
||||
},
|
||||
include: {
|
||||
character: {
|
||||
select: {
|
||||
id: true,
|
||||
name: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
if (!knowledge) {
|
||||
throw new BadRequestException('Knowledge not found');
|
||||
}
|
||||
|
||||
return knowledge;
|
||||
}
|
||||
|
||||
async deleteKnowledge(knowledgeId: string, userId: string): Promise<void> {
|
||||
const knowledge = await this.prisma.characterKnowledge.findFirst({
|
||||
where: {
|
||||
id: knowledgeId,
|
||||
character: {
|
||||
userId,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
if (!knowledge) {
|
||||
throw new BadRequestException('Knowledge not found');
|
||||
}
|
||||
|
||||
// Delete associated vector memories
|
||||
await this.memoryService['vectorStore'].deleteByKnowledge(knowledgeId);
|
||||
|
||||
// Delete the knowledge entry
|
||||
await this.prisma.characterKnowledge.delete({
|
||||
where: { id: knowledgeId },
|
||||
});
|
||||
}
|
||||
|
||||
async getCharacterKnowledge(characterId: string, userId: string) {
|
||||
// Verify user owns the character
|
||||
const character = await this.prisma.character.findFirst({
|
||||
where: {
|
||||
id: characterId,
|
||||
userId,
|
||||
},
|
||||
});
|
||||
|
||||
if (!character) {
|
||||
throw new BadRequestException('Character not found');
|
||||
}
|
||||
|
||||
return this.prisma.characterKnowledge.findMany({
|
||||
where: { characterId },
|
||||
orderBy: { createdAt: 'desc' },
|
||||
});
|
||||
}
|
||||
|
||||
private async processContent(
|
||||
knowledgeId: string,
|
||||
characterId: string,
|
||||
result: ImportResult,
|
||||
): Promise<void> {
|
||||
try {
|
||||
// Chunk the content into smaller pieces
|
||||
const chunks = this.chunkContent(result.content, 1000, 200);
|
||||
|
||||
// Store each chunk in vector memory
|
||||
for (let i = 0; i < chunks.length; i++) {
|
||||
await this.memoryService.storeCharacterKnowledge(
|
||||
chunks[i],
|
||||
characterId,
|
||||
knowledgeId,
|
||||
{
|
||||
...result.metadata,
|
||||
chunkIndex: i,
|
||||
totalChunks: chunks.length,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
// Update status to completed
|
||||
await this.prisma.characterKnowledge.update({
|
||||
where: { id: knowledgeId },
|
||||
data: {
|
||||
status: 'completed',
|
||||
processingInfo: {
|
||||
...result.metadata,
|
||||
chunksProcessed: chunks.length,
|
||||
},
|
||||
},
|
||||
});
|
||||
} catch (error) {
|
||||
// Update status to failed
|
||||
await this.prisma.characterKnowledge.update({
|
||||
where: { id: knowledgeId },
|
||||
data: {
|
||||
status: 'failed',
|
||||
processingInfo: {
|
||||
...result.metadata,
|
||||
error: error instanceof Error ? error.message : 'Unknown error',
|
||||
},
|
||||
},
|
||||
});
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
private chunkContent(content: string, chunkSize: number, overlap: number): string[] {
|
||||
const chunks: string[] = [];
|
||||
let start = 0;
|
||||
|
||||
while (start < content.length) {
|
||||
const end = Math.min(start + chunkSize, content.length);
|
||||
let chunk = content.slice(start, end);
|
||||
|
||||
// Try to break at a sentence boundary
|
||||
if (end < content.length) {
|
||||
const lastPeriod = chunk.lastIndexOf('.');
|
||||
const lastNewline = chunk.lastIndexOf('\n');
|
||||
const breakPoint = Math.max(lastPeriod, lastNewline);
|
||||
|
||||
if (breakPoint > chunkSize * 0.5) {
|
||||
chunk = chunk.slice(0, breakPoint + 1);
|
||||
}
|
||||
}
|
||||
|
||||
chunks.push(chunk.trim());
|
||||
start += chunk.length - overlap;
|
||||
}
|
||||
|
||||
return chunks;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,14 @@
|
||||
export interface ImportResult {
|
||||
content: string;
|
||||
metadata: {
|
||||
sourceName: string;
|
||||
mimeType: string;
|
||||
fileSize?: number;
|
||||
[key: string]: any;
|
||||
};
|
||||
}
|
||||
|
||||
export interface IImportAdapter {
|
||||
canHandle(file: Express.Multer.File): boolean;
|
||||
parse(file: Express.Multer.File): Promise<ImportResult>;
|
||||
}
|
||||
40
apps/backend/src/llm/interfaces/llm-provider.interface.ts
Normal file
40
apps/backend/src/llm/interfaces/llm-provider.interface.ts
Normal file
@@ -0,0 +1,40 @@
|
||||
export interface LLMMessage {
|
||||
role: 'system' | 'user' | 'assistant';
|
||||
content: string;
|
||||
}
|
||||
|
||||
export interface LLMCompletionOptions {
|
||||
model?: string;
|
||||
temperature?: number;
|
||||
maxTokens?: number;
|
||||
topP?: number;
|
||||
frequencyPenalty?: number;
|
||||
presencePenalty?: number;
|
||||
stop?: string[];
|
||||
}
|
||||
|
||||
export interface LLMCompletionResponse {
|
||||
content: string;
|
||||
model: string;
|
||||
tokensUsed: number;
|
||||
finishReason: string;
|
||||
}
|
||||
|
||||
export interface LLMStreamChunk {
|
||||
content: string;
|
||||
isDone: boolean;
|
||||
}
|
||||
|
||||
export interface ILLMProvider {
|
||||
generateCompletion(
|
||||
messages: LLMMessage[],
|
||||
options?: LLMCompletionOptions,
|
||||
): Promise<LLMCompletionResponse>;
|
||||
|
||||
generateStream(
|
||||
messages: LLMMessage[],
|
||||
options?: LLMCompletionOptions,
|
||||
): AsyncGenerator<LLMStreamChunk>;
|
||||
|
||||
countTokens(messages: LLMMessage[]): number;
|
||||
}
|
||||
8
apps/backend/src/llm/llm.module.ts
Normal file
8
apps/backend/src/llm/llm.module.ts
Normal file
@@ -0,0 +1,8 @@
|
||||
import { Module } from '@nestjs/common';
|
||||
import { LLMService } from './llm.service';
|
||||
|
||||
@Module({
|
||||
providers: [LLMService],
|
||||
exports: [LLMService],
|
||||
})
|
||||
export class LLMModule {}
|
||||
38
apps/backend/src/llm/llm.service.ts
Normal file
38
apps/backend/src/llm/llm.service.ts
Normal file
@@ -0,0 +1,38 @@
|
||||
import { Injectable } from '@nestjs/common';
|
||||
import {
|
||||
ILLMProvider,
|
||||
LLMMessage,
|
||||
LLMCompletionOptions,
|
||||
LLMCompletionResponse,
|
||||
LLMStreamChunk,
|
||||
} from './interfaces/llm-provider.interface';
|
||||
import { OpenRouterProvider } from './providers/openrouter.provider';
|
||||
|
||||
@Injectable()
|
||||
export class LLMService {
|
||||
private provider: ILLMProvider;
|
||||
|
||||
constructor() {
|
||||
// For now, only OpenRouter is supported
|
||||
// In the future, this could be configurable
|
||||
this.provider = new OpenRouterProvider();
|
||||
}
|
||||
|
||||
async generateCompletion(
|
||||
messages: LLMMessage[],
|
||||
options?: LLMCompletionOptions,
|
||||
): Promise<LLMCompletionResponse> {
|
||||
return this.provider.generateCompletion(messages, options);
|
||||
}
|
||||
|
||||
async *generateStream(
|
||||
messages: LLMMessage[],
|
||||
options?: LLMCompletionOptions,
|
||||
): AsyncGenerator<LLMStreamChunk> {
|
||||
yield* this.provider.generateStream(messages, options);
|
||||
}
|
||||
|
||||
countTokens(messages: LLMMessage[]): number {
|
||||
return this.provider.countTokens(messages);
|
||||
}
|
||||
}
|
||||
140
apps/backend/src/llm/providers/openrouter.provider.ts
Normal file
140
apps/backend/src/llm/providers/openrouter.provider.ts
Normal file
@@ -0,0 +1,140 @@
|
||||
import { Injectable } from '@nestjs/common';
|
||||
import {
|
||||
ILLMProvider,
|
||||
LLMMessage,
|
||||
LLMCompletionOptions,
|
||||
LLMCompletionResponse,
|
||||
LLMStreamChunk,
|
||||
} from '../interfaces/llm-provider.interface';
|
||||
|
||||
@Injectable()
|
||||
export class OpenRouterProvider implements ILLMProvider {
|
||||
private readonly apiKey: string;
|
||||
private readonly baseUrl = 'https://openrouter.ai/api/v1';
|
||||
private readonly defaultModel: string;
|
||||
|
||||
constructor() {
|
||||
this.apiKey = process.env.LLM_API_KEY || '';
|
||||
this.defaultModel = process.env.LLM_MODEL || 'openai/gpt-4o';
|
||||
|
||||
if (!this.apiKey) {
|
||||
console.warn('LLM_API_KEY not set. OpenRouter provider will not work.');
|
||||
}
|
||||
}
|
||||
|
||||
async generateCompletion(
|
||||
messages: LLMMessage[],
|
||||
options?: LLMCompletionOptions,
|
||||
): Promise<LLMCompletionResponse> {
|
||||
const response = await fetch(`${this.baseUrl}/chat/completions`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': `Bearer ${this.apiKey}`,
|
||||
'HTTP-Referer': process.env.APP_URL || 'http://localhost:3000',
|
||||
'X-Title': 'DreamChat',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
model: options?.model || this.defaultModel,
|
||||
messages,
|
||||
temperature: options?.temperature ?? 0.7,
|
||||
max_tokens: options?.maxTokens,
|
||||
top_p: options?.topP,
|
||||
frequency_penalty: options?.frequencyPenalty,
|
||||
presence_penalty: options?.presencePenalty,
|
||||
stop: options?.stop,
|
||||
}),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const error = await response.text();
|
||||
throw new Error(`OpenRouter API error: ${error}`);
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
const choice = data.choices[0];
|
||||
|
||||
return {
|
||||
content: choice.message.content,
|
||||
model: data.model,
|
||||
tokensUsed: data.usage?.total_tokens || 0,
|
||||
finishReason: choice.finish_reason,
|
||||
};
|
||||
}
|
||||
|
||||
async *generateStream(
|
||||
messages: LLMMessage[],
|
||||
options?: LLMCompletionOptions,
|
||||
): AsyncGenerator<LLMStreamChunk> {
|
||||
const response = await fetch(`${this.baseUrl}/chat/completions`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': `Bearer ${this.apiKey}`,
|
||||
'HTTP-Referer': process.env.APP_URL || 'http://localhost:3000',
|
||||
'X-Title': 'DreamChat',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
model: options?.model || this.defaultModel,
|
||||
messages,
|
||||
temperature: options?.temperature ?? 0.7,
|
||||
max_tokens: options?.maxTokens,
|
||||
top_p: options?.topP,
|
||||
frequency_penalty: options?.frequencyPenalty,
|
||||
presence_penalty: options?.presencePenalty,
|
||||
stop: options?.stop,
|
||||
stream: true,
|
||||
}),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const error = await response.text();
|
||||
throw new Error(`OpenRouter API error: ${error}`);
|
||||
}
|
||||
|
||||
const reader = response.body?.getReader();
|
||||
if (!reader) {
|
||||
throw new Error('No response body');
|
||||
}
|
||||
|
||||
const decoder = new TextDecoder();
|
||||
let buffer = '';
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) break;
|
||||
|
||||
buffer += decoder.decode(value, { stream: true });
|
||||
const lines = buffer.split('\n');
|
||||
buffer = lines.pop() || '';
|
||||
|
||||
for (const line of lines) {
|
||||
if (line.startsWith('data: ')) {
|
||||
const data = line.slice(6);
|
||||
if (data === '[DONE]') {
|
||||
yield { content: '', isDone: true };
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const parsed = JSON.parse(data);
|
||||
const content = parsed.choices[0]?.delta?.content || '';
|
||||
const isDone = parsed.choices[0]?.finish_reason != null;
|
||||
|
||||
yield { content, isDone };
|
||||
} catch {
|
||||
// Skip invalid JSON
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
yield { content: '', isDone: true };
|
||||
}
|
||||
|
||||
countTokens(messages: LLMMessage[]): number {
|
||||
// Simple estimation: ~4 characters per token on average
|
||||
const totalChars = messages.reduce((sum, msg) => sum + msg.content.length, 0);
|
||||
return Math.ceil(totalChars / 4);
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,6 @@
|
||||
import 'dotenv/config';
|
||||
import { NestFactory } from '@nestjs/core';
|
||||
import { SwaggerModule, DocumentBuilder } from '@nestjs/swagger';
|
||||
import { AppModule } from './app.module';
|
||||
|
||||
async function bootstrap() {
|
||||
@@ -10,11 +12,40 @@ async function bootstrap() {
|
||||
});
|
||||
|
||||
app.setGlobalPrefix('api');
|
||||
|
||||
// Swagger/OpenAPI setup
|
||||
const config = new DocumentBuilder()
|
||||
.setTitle('DreamChat API')
|
||||
.setDescription('The DreamChat API documentation')
|
||||
.setVersion('1.0.0')
|
||||
.addBearerAuth()
|
||||
.build();
|
||||
|
||||
const document = SwaggerModule.createDocument(app, config);
|
||||
SwaggerModule.setup('api/docs', app, document);
|
||||
|
||||
// Also export the OpenAPI spec as JSON
|
||||
const fs = await import('fs');
|
||||
const path = await import('path');
|
||||
|
||||
// Ensure the output directory exists
|
||||
const outputDir = path.join(process.cwd(), '..', '..', 'openapi');
|
||||
if (!fs.existsSync(outputDir)) {
|
||||
fs.mkdirSync(outputDir, { recursive: true });
|
||||
}
|
||||
|
||||
// Write the spec file
|
||||
fs.writeFileSync(
|
||||
path.join(outputDir, 'openapi.json'),
|
||||
JSON.stringify(document, null, 2),
|
||||
);
|
||||
console.log(`📄 OpenAPI spec written to: ${path.join(outputDir, 'openapi.json')}`);
|
||||
|
||||
const port = process.env.PORT || 3000;
|
||||
await app.listen(port);
|
||||
|
||||
console.log(`🚀 Backend running on: http://localhost:${port}/api`);
|
||||
console.log(`📚 API docs available at: http://localhost:${port}/api/docs`);
|
||||
}
|
||||
|
||||
bootstrap();
|
||||
|
||||
9
apps/backend/src/prisma/prisma.module.ts
Normal file
9
apps/backend/src/prisma/prisma.module.ts
Normal file
@@ -0,0 +1,9 @@
|
||||
import { Module, Global } from '@nestjs/common';
|
||||
import { PrismaService } from './prisma.service';
|
||||
|
||||
@Global()
|
||||
@Module({
|
||||
providers: [PrismaService],
|
||||
exports: [PrismaService],
|
||||
})
|
||||
export class PrismaModule {}
|
||||
20
apps/backend/src/prisma/prisma.service.ts
Normal file
20
apps/backend/src/prisma/prisma.service.ts
Normal file
@@ -0,0 +1,20 @@
|
||||
import { Injectable, OnModuleInit, OnModuleDestroy } from '@nestjs/common';
|
||||
import { PrismaPg } from '@prisma/adapter-pg';
|
||||
import { PrismaClient } from '@prisma/client';
|
||||
|
||||
@Injectable()
|
||||
export class PrismaService extends PrismaClient implements OnModuleInit, OnModuleDestroy {
|
||||
constructor() {
|
||||
const adapter = new PrismaPg({
|
||||
connectionString: process.env.DATABASE_URL!,
|
||||
});
|
||||
super({ adapter: adapter });
|
||||
}
|
||||
async onModuleInit() {
|
||||
await this.$connect();
|
||||
}
|
||||
|
||||
async onModuleDestroy() {
|
||||
await this.$disconnect();
|
||||
}
|
||||
}
|
||||
27
apps/backend/src/user/dto/update-user.dto.ts
Normal file
27
apps/backend/src/user/dto/update-user.dto.ts
Normal file
@@ -0,0 +1,27 @@
|
||||
import { IsString, IsEmail, MinLength, IsOptional } from 'class-validator';
|
||||
import { ApiPropertyOptional } from '@nestjs/swagger';
|
||||
|
||||
export class UpdateUserDto {
|
||||
@ApiPropertyOptional({ description: 'New email address', example: 'newemail@example.com' })
|
||||
@IsOptional()
|
||||
@IsEmail()
|
||||
email?: string;
|
||||
|
||||
@ApiPropertyOptional({ description: 'New username', example: 'newusername' })
|
||||
@IsOptional()
|
||||
@IsString()
|
||||
@MinLength(3)
|
||||
username?: string;
|
||||
}
|
||||
|
||||
export class UpdatePasswordDto {
|
||||
@ApiPropertyOptional({ description: 'Current password', example: 'oldpassword123' })
|
||||
@IsString()
|
||||
@MinLength(6)
|
||||
currentPassword: string;
|
||||
|
||||
@ApiPropertyOptional({ description: 'New password', example: 'newpassword123' })
|
||||
@IsString()
|
||||
@MinLength(6)
|
||||
newPassword: string;
|
||||
}
|
||||
54
apps/backend/src/user/user.controller.ts
Normal file
54
apps/backend/src/user/user.controller.ts
Normal file
@@ -0,0 +1,54 @@
|
||||
import { Controller, Get, Put, Delete, Body } from '@nestjs/common';
|
||||
import { ApiTags, ApiOperation, ApiResponse, ApiBearerAuth } from '@nestjs/swagger';
|
||||
import { UserService } from './user.service';
|
||||
import { UpdateUserDto, UpdatePasswordDto } from './dto/update-user.dto';
|
||||
import { CurrentUser } from '../common/decorators/current-user.decorator';
|
||||
import { User } from '@prisma/client';
|
||||
|
||||
@ApiTags('users')
|
||||
@ApiBearerAuth()
|
||||
@Controller('users')
|
||||
export class UserController {
|
||||
constructor(private userService: UserService) {}
|
||||
|
||||
@Get('me')
|
||||
@ApiOperation({ summary: 'Get current user profile' })
|
||||
@ApiResponse({ status: 200, description: 'User profile retrieved' })
|
||||
@ApiResponse({ status: 401, description: 'Unauthorized' })
|
||||
async getProfile(@CurrentUser('userId') userId: string): Promise<Omit<User, 'passwordHash'>> {
|
||||
return this.userService.findById(userId);
|
||||
}
|
||||
|
||||
@Put('me')
|
||||
@ApiOperation({ summary: 'Update current user profile' })
|
||||
@ApiResponse({ status: 200, description: 'Profile updated' })
|
||||
@ApiResponse({ status: 401, description: 'Unauthorized' })
|
||||
@ApiResponse({ status: 409, description: 'Email or username already exists' })
|
||||
async updateProfile(
|
||||
@CurrentUser('userId') userId: string,
|
||||
@Body() updateUserDto: UpdateUserDto,
|
||||
): Promise<Omit<User, 'passwordHash'>> {
|
||||
return this.userService.update(userId, updateUserDto);
|
||||
}
|
||||
|
||||
@Put('me/password')
|
||||
@ApiOperation({ summary: 'Update user password' })
|
||||
@ApiResponse({ status: 200, description: 'Password updated successfully' })
|
||||
@ApiResponse({ status: 401, description: 'Current password incorrect' })
|
||||
async updatePassword(
|
||||
@CurrentUser('userId') userId: string,
|
||||
@Body() updatePasswordDto: UpdatePasswordDto,
|
||||
): Promise<{ message: string }> {
|
||||
await this.userService.updatePassword(userId, updatePasswordDto);
|
||||
return { message: 'Password updated successfully' };
|
||||
}
|
||||
|
||||
@Delete('me')
|
||||
@ApiOperation({ summary: 'Delete user account' })
|
||||
@ApiResponse({ status: 200, description: 'Account deleted' })
|
||||
@ApiResponse({ status: 401, description: 'Unauthorized' })
|
||||
async deleteAccount(@CurrentUser('userId') userId: string): Promise<{ message: string }> {
|
||||
await this.userService.delete(userId);
|
||||
return { message: 'Account deleted successfully' };
|
||||
}
|
||||
}
|
||||
10
apps/backend/src/user/user.module.ts
Normal file
10
apps/backend/src/user/user.module.ts
Normal file
@@ -0,0 +1,10 @@
|
||||
import { Module } from '@nestjs/common';
|
||||
import { UserService } from './user.service';
|
||||
import { UserController } from './user.controller';
|
||||
|
||||
@Module({
|
||||
providers: [UserService],
|
||||
controllers: [UserController],
|
||||
exports: [UserService],
|
||||
})
|
||||
export class UserModule {}
|
||||
95
apps/backend/src/user/user.service.ts
Normal file
95
apps/backend/src/user/user.service.ts
Normal file
@@ -0,0 +1,95 @@
|
||||
import { Injectable, NotFoundException, ConflictException, UnauthorizedException } from '@nestjs/common';
|
||||
import { PrismaService } from '../prisma/prisma.service';
|
||||
import { UpdateUserDto, UpdatePasswordDto } from './dto/update-user.dto';
|
||||
import * as bcrypt from 'bcrypt';
|
||||
import { User } from '@prisma/client';
|
||||
|
||||
@Injectable()
|
||||
export class UserService {
|
||||
constructor(private prisma: PrismaService) {}
|
||||
|
||||
async findById(id: string): Promise<Omit<User, 'passwordHash'>> {
|
||||
const user = await this.prisma.user.findUnique({
|
||||
where: { id },
|
||||
});
|
||||
|
||||
if (!user) {
|
||||
throw new NotFoundException('User not found');
|
||||
}
|
||||
|
||||
const { passwordHash, ...userWithoutPassword } = user;
|
||||
return userWithoutPassword;
|
||||
}
|
||||
|
||||
async update(id: string, updateUserDto: UpdateUserDto): Promise<Omit<User, 'passwordHash'>> {
|
||||
const existingUser = await this.prisma.user.findUnique({
|
||||
where: { id },
|
||||
});
|
||||
|
||||
if (!existingUser) {
|
||||
throw new NotFoundException('User not found');
|
||||
}
|
||||
|
||||
if (updateUserDto.email && updateUserDto.email !== existingUser.email) {
|
||||
const emailExists = await this.prisma.user.findUnique({
|
||||
where: { email: updateUserDto.email },
|
||||
});
|
||||
if (emailExists) {
|
||||
throw new ConflictException('Email already in use');
|
||||
}
|
||||
}
|
||||
|
||||
if (updateUserDto.username && updateUserDto.username !== existingUser.username) {
|
||||
const usernameExists = await this.prisma.user.findUnique({
|
||||
where: { username: updateUserDto.username },
|
||||
});
|
||||
if (usernameExists) {
|
||||
throw new ConflictException('Username already in use');
|
||||
}
|
||||
}
|
||||
|
||||
const user = await this.prisma.user.update({
|
||||
where: { id },
|
||||
data: updateUserDto,
|
||||
});
|
||||
|
||||
const { passwordHash, ...userWithoutPassword } = user;
|
||||
return userWithoutPassword;
|
||||
}
|
||||
|
||||
async updatePassword(id: string, updatePasswordDto: UpdatePasswordDto): Promise<void> {
|
||||
const user = await this.prisma.user.findUnique({
|
||||
where: { id },
|
||||
});
|
||||
|
||||
if (!user || !user.passwordHash) {
|
||||
throw new NotFoundException('User not found');
|
||||
}
|
||||
|
||||
const isMatch = await bcrypt.compare(updatePasswordDto.currentPassword, user.passwordHash);
|
||||
if (!isMatch) {
|
||||
throw new UnauthorizedException('Current password is incorrect');
|
||||
}
|
||||
|
||||
const hashedPassword = await bcrypt.hash(updatePasswordDto.newPassword, 10);
|
||||
|
||||
await this.prisma.user.update({
|
||||
where: { id },
|
||||
data: { passwordHash: hashedPassword },
|
||||
});
|
||||
}
|
||||
|
||||
async delete(id: string): Promise<void> {
|
||||
const user = await this.prisma.user.findUnique({
|
||||
where: { id },
|
||||
});
|
||||
|
||||
if (!user) {
|
||||
throw new NotFoundException('User not found');
|
||||
}
|
||||
|
||||
await this.prisma.user.delete({
|
||||
where: { id },
|
||||
});
|
||||
}
|
||||
}
|
||||
32
apps/backend/src/vector/embedding.service.ts
Normal file
32
apps/backend/src/vector/embedding.service.ts
Normal file
@@ -0,0 +1,32 @@
|
||||
import { Injectable } from '@nestjs/common';
|
||||
import { IEmbeddingProvider } from './interfaces/embedding-provider.interface';
|
||||
import { LocalEmbeddingProvider } from './providers/local-embedding.provider';
|
||||
|
||||
@Injectable()
|
||||
export class EmbeddingService {
|
||||
private provider: IEmbeddingProvider;
|
||||
|
||||
constructor() {
|
||||
const providerType = process.env.EMBEDDING_PROVIDER || 'local';
|
||||
|
||||
switch (providerType) {
|
||||
case 'local':
|
||||
this.provider = new LocalEmbeddingProvider();
|
||||
break;
|
||||
default:
|
||||
throw new Error(`Unknown embedding provider: ${providerType}`);
|
||||
}
|
||||
}
|
||||
|
||||
async embed(text: string): Promise<number[]> {
|
||||
return this.provider.embed(text);
|
||||
}
|
||||
|
||||
async embedBatch(texts: string[]): Promise<number[][]> {
|
||||
return this.provider.embedBatch(texts);
|
||||
}
|
||||
|
||||
getDimension(): number {
|
||||
return this.provider.getDimension();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
export interface IEmbeddingProvider {
|
||||
embed(text: string): Promise<number[]>;
|
||||
embedBatch(texts: string[]): Promise<number[][]>;
|
||||
getDimension(): number;
|
||||
}
|
||||
122
apps/backend/src/vector/memory.service.ts
Normal file
122
apps/backend/src/vector/memory.service.ts
Normal file
@@ -0,0 +1,122 @@
|
||||
import { Injectable } from '@nestjs/common';
|
||||
import { EmbeddingService } from './embedding.service';
|
||||
import { VectorStoreService, SearchResult } from './vector-store.service';
|
||||
import { MemoryType } from '@prisma/client';
|
||||
|
||||
export interface MemoryContext {
|
||||
content: string;
|
||||
metadata: any;
|
||||
similarity: number;
|
||||
}
|
||||
|
||||
@Injectable()
|
||||
export class MemoryService {
|
||||
constructor(
|
||||
private embeddingService: EmbeddingService,
|
||||
private vectorStore: VectorStoreService,
|
||||
) {}
|
||||
|
||||
async addMemory(
|
||||
content: string,
|
||||
memoryType: MemoryType,
|
||||
options: {
|
||||
conversationId?: string;
|
||||
characterId?: string;
|
||||
knowledgeId?: string;
|
||||
metadata?: any;
|
||||
},
|
||||
): Promise<void> {
|
||||
const embedding = await this.embeddingService.embed(content);
|
||||
await this.vectorStore.store(content, embedding, memoryType, options);
|
||||
}
|
||||
|
||||
async retrieveRelevantMemories(
|
||||
query: string,
|
||||
options: {
|
||||
limit?: number;
|
||||
threshold?: number;
|
||||
conversationId?: string;
|
||||
characterId?: string;
|
||||
memoryType?: MemoryType;
|
||||
},
|
||||
): Promise<MemoryContext[]> {
|
||||
const embedding = await this.embeddingService.embed(query);
|
||||
const results = await this.vectorStore.searchSimilar(embedding, options);
|
||||
|
||||
return results.map((result) => ({
|
||||
content: result.content,
|
||||
metadata: result.metadata,
|
||||
similarity: result.similarity,
|
||||
}));
|
||||
}
|
||||
|
||||
async buildContextForConversation(
|
||||
conversationId: string,
|
||||
currentMessage: string,
|
||||
characterId: string,
|
||||
): Promise<string> {
|
||||
// Retrieve recent conversation memories
|
||||
const conversationMemories = await this.retrieveRelevantMemories(
|
||||
currentMessage,
|
||||
{
|
||||
limit: 3,
|
||||
threshold: 0.6,
|
||||
conversationId,
|
||||
memoryType: 'conversation',
|
||||
},
|
||||
);
|
||||
|
||||
// Retrieve character knowledge
|
||||
const characterMemories = await this.retrieveRelevantMemories(
|
||||
currentMessage,
|
||||
{
|
||||
limit: 3,
|
||||
threshold: 0.7,
|
||||
characterId,
|
||||
memoryType: 'character',
|
||||
},
|
||||
);
|
||||
|
||||
const contextParts: string[] = [];
|
||||
|
||||
if (characterMemories.length > 0) {
|
||||
contextParts.push('Relevant character knowledge:');
|
||||
characterMemories.forEach((memory) => {
|
||||
contextParts.push(`- ${memory.content}`);
|
||||
});
|
||||
}
|
||||
|
||||
if (conversationMemories.length > 0) {
|
||||
contextParts.push('\nRelevant conversation history:');
|
||||
conversationMemories.forEach((memory) => {
|
||||
contextParts.push(`- ${memory.content}`);
|
||||
});
|
||||
}
|
||||
|
||||
return contextParts.join('\n');
|
||||
}
|
||||
|
||||
async storeConversationMessage(
|
||||
content: string,
|
||||
conversationId: string,
|
||||
metadata?: any,
|
||||
): Promise<void> {
|
||||
await this.addMemory(content, 'conversation', {
|
||||
conversationId,
|
||||
metadata,
|
||||
});
|
||||
}
|
||||
|
||||
async storeCharacterKnowledge(
|
||||
content: string,
|
||||
characterId: string,
|
||||
knowledgeId: string,
|
||||
metadata?: any,
|
||||
): Promise<void> {
|
||||
await this.addMemory(content, 'character', {
|
||||
characterId,
|
||||
knowledgeId,
|
||||
metadata,
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
import { Injectable, OnModuleInit } from '@nestjs/common';
|
||||
import { IEmbeddingProvider } from '../interfaces/embedding-provider.interface';
|
||||
import { pipeline, FeatureExtractionPipeline } from '@xenova/transformers';
|
||||
|
||||
@Injectable()
|
||||
export class LocalEmbeddingProvider implements IEmbeddingProvider, OnModuleInit {
|
||||
private extractor: FeatureExtractionPipeline | null = null;
|
||||
private readonly modelName: string;
|
||||
private readonly dimension: number;
|
||||
|
||||
constructor() {
|
||||
this.modelName = process.env.EMBEDDING_MODEL || 'Xenova/all-MiniLM-L6-v2';
|
||||
this.dimension = parseInt(process.env.EMBEDDING_DIMENSION || '384', 10);
|
||||
}
|
||||
|
||||
async onModuleInit() {
|
||||
// Lazy initialization - model will be loaded on first use
|
||||
}
|
||||
|
||||
private async getExtractor(): Promise<FeatureExtractionPipeline> {
|
||||
if (!this.extractor) {
|
||||
this.extractor = await pipeline('feature-extraction', this.modelName, {
|
||||
quantized: false, // Use full precision for better quality
|
||||
});
|
||||
}
|
||||
return this.extractor;
|
||||
}
|
||||
|
||||
async embed(text: string): Promise<number[]> {
|
||||
const extractor = await this.getExtractor();
|
||||
const output = await extractor(text, { pooling: 'mean', normalize: true });
|
||||
return Array.from(output.data as Float32Array);
|
||||
}
|
||||
|
||||
async embedBatch(texts: string[]): Promise<number[][]> {
|
||||
const extractor = await this.getExtractor();
|
||||
const outputs = await Promise.all(
|
||||
texts.map((text) =>
|
||||
extractor(text, { pooling: 'mean', normalize: true }),
|
||||
),
|
||||
);
|
||||
return outputs.map((output) => Array.from(output.data as Float32Array));
|
||||
}
|
||||
|
||||
getDimension(): number {
|
||||
return this.dimension;
|
||||
}
|
||||
}
|
||||
116
apps/backend/src/vector/vector-store.service.ts
Normal file
116
apps/backend/src/vector/vector-store.service.ts
Normal file
@@ -0,0 +1,116 @@
|
||||
import { Injectable } from '@nestjs/common';
|
||||
import { PrismaService } from '../prisma/prisma.service';
|
||||
import { MemoryType, VectorMemory } from '@prisma/client';
|
||||
|
||||
export interface SearchResult {
|
||||
id: string;
|
||||
content: string;
|
||||
memoryType: MemoryType;
|
||||
metadata: any;
|
||||
similarity: number;
|
||||
}
|
||||
|
||||
@Injectable()
|
||||
export class VectorStoreService {
|
||||
constructor(private prisma: PrismaService) {}
|
||||
|
||||
async store(
|
||||
content: string,
|
||||
embedding: number[],
|
||||
memoryType: MemoryType,
|
||||
options: {
|
||||
conversationId?: string;
|
||||
characterId?: string;
|
||||
knowledgeId?: string;
|
||||
metadata?: any;
|
||||
},
|
||||
): Promise<VectorMemory> {
|
||||
const vectorString = `[${embedding.join(',')}]`;
|
||||
|
||||
return this.prisma.$queryRaw<VectorMemory[]>`
|
||||
INSERT INTO "VectorMemory" (id, content, embedding, "memoryType", metadata, "conversationId", "characterId", "knowledgeId", "createdAt")
|
||||
VALUES (
|
||||
gen_random_uuid(),
|
||||
${content},
|
||||
${vectorString}::vector,
|
||||
${memoryType},
|
||||
${options.metadata ? JSON.stringify(options.metadata) : null}::jsonb,
|
||||
${options.conversationId || null},
|
||||
${options.characterId || null},
|
||||
${options.knowledgeId || null},
|
||||
NOW()
|
||||
)
|
||||
RETURNING *
|
||||
`.then((results) => results[0]);
|
||||
}
|
||||
|
||||
async searchSimilar(
|
||||
embedding: number[],
|
||||
options: {
|
||||
limit?: number;
|
||||
threshold?: number;
|
||||
conversationId?: string;
|
||||
characterId?: string;
|
||||
memoryType?: MemoryType;
|
||||
},
|
||||
): Promise<SearchResult[]> {
|
||||
const { limit = 5, threshold = 0.7 } = options;
|
||||
const vectorString = `[${embedding.join(',')}]`;
|
||||
|
||||
let whereClause = '';
|
||||
const params: any[] = [vectorString, threshold, limit];
|
||||
let paramIndex = 4;
|
||||
|
||||
if (options.conversationId) {
|
||||
whereClause += ` AND "conversationId" = $${paramIndex}`;
|
||||
params.push(options.conversationId);
|
||||
paramIndex++;
|
||||
}
|
||||
|
||||
if (options.characterId) {
|
||||
whereClause += ` AND "characterId" = $${paramIndex}`;
|
||||
params.push(options.characterId);
|
||||
paramIndex++;
|
||||
}
|
||||
|
||||
if (options.memoryType) {
|
||||
whereClause += ` AND "memoryType" = $${paramIndex}`;
|
||||
params.push(options.memoryType);
|
||||
paramIndex++;
|
||||
}
|
||||
|
||||
const query = `
|
||||
SELECT
|
||||
id,
|
||||
content,
|
||||
"memoryType",
|
||||
metadata,
|
||||
1 - (embedding <=> $1::vector) as similarity
|
||||
FROM "VectorMemory"
|
||||
WHERE 1 - (embedding <=> $1::vector) >= $2
|
||||
${whereClause}
|
||||
ORDER BY embedding <=> $1::vector
|
||||
LIMIT $3
|
||||
`;
|
||||
|
||||
return this.prisma.$queryRawUnsafe<SearchResult[]>(query, ...params);
|
||||
}
|
||||
|
||||
async deleteByConversation(conversationId: string): Promise<void> {
|
||||
await this.prisma.vectorMemory.deleteMany({
|
||||
where: { conversationId },
|
||||
});
|
||||
}
|
||||
|
||||
async deleteByCharacter(characterId: string): Promise<void> {
|
||||
await this.prisma.vectorMemory.deleteMany({
|
||||
where: { characterId },
|
||||
});
|
||||
}
|
||||
|
||||
async deleteByKnowledge(knowledgeId: string): Promise<void> {
|
||||
await this.prisma.vectorMemory.deleteMany({
|
||||
where: { knowledgeId },
|
||||
});
|
||||
}
|
||||
}
|
||||
10
apps/backend/src/vector/vector.module.ts
Normal file
10
apps/backend/src/vector/vector.module.ts
Normal file
@@ -0,0 +1,10 @@
|
||||
import { Module } from '@nestjs/common';
|
||||
import { EmbeddingService } from './embedding.service';
|
||||
import { VectorStoreService } from './vector-store.service';
|
||||
import { MemoryService } from './memory.service';
|
||||
|
||||
@Module({
|
||||
providers: [EmbeddingService, VectorStoreService, MemoryService],
|
||||
exports: [EmbeddingService, VectorStoreService, MemoryService],
|
||||
})
|
||||
export class VectorModule {}
|
||||
Reference in New Issue
Block a user