@@ -5,11 +5,13 @@ import { sum } from "$lib/utils/sum";
55import {
66 embeddingEndpoints ,
77 embeddingEndpointSchema ,
8- type EmbeddingEndpoint ,
98} from "$lib/server/embeddingEndpoints/embeddingEndpoints" ;
109import { embeddingEndpointTransformersJS } from "$lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints" ;
1110
1211import JSON5 from "json5" ;
12+ import type { EmbeddingModel } from "$lib/types/EmbeddingModel" ;
13+ import { collections } from "./database" ;
14+ import { ObjectId } from "mongodb" ;
1315
1416const modelConfig = z . object ( {
1517 /** Used as an identifier in DB */
@@ -42,67 +44,77 @@ const rawEmbeddingModelJSON =
4244
4345const embeddingModelsRaw = z . array ( modelConfig ) . parse ( JSON5 . parse ( rawEmbeddingModelJSON ) ) ;
4446
45- const processEmbeddingModel = async ( m : z . infer < typeof modelConfig > ) => ( {
46- ...m ,
47- id : m . id || m . name ,
47+ const embeddingModels = embeddingModelsRaw . map ( ( rawEmbeddingModel ) => {
48+ const embeddingModel : EmbeddingModel = {
49+ name : rawEmbeddingModel . name ,
50+ description : rawEmbeddingModel . description ,
51+ websiteUrl : rawEmbeddingModel . websiteUrl ,
52+ modelUrl : rawEmbeddingModel . modelUrl ,
53+ chunkCharLength : rawEmbeddingModel . chunkCharLength ,
54+ maxBatchSize : rawEmbeddingModel . maxBatchSize ,
55+ preQuery : rawEmbeddingModel . preQuery ,
56+ prePassage : rawEmbeddingModel . prePassage ,
57+ _id : new ObjectId ( ) ,
58+ createdAt : new Date ( ) ,
59+ updatedAt : new Date ( ) ,
60+ endpoints : rawEmbeddingModel . endpoints ,
61+ } ;
62+
63+ return embeddingModel ;
4864} ) ;
4965
50- const addEndpoint = ( m : Awaited < ReturnType < typeof processEmbeddingModel > > ) => ( {
51- ...m ,
52- getEndpoint : async ( ) : Promise < EmbeddingEndpoint > => {
53- if ( ! m . endpoints ) {
54- return embeddingEndpointTransformersJS ( {
55- type : "transformersjs" ,
56- weight : 1 ,
57- model : m ,
58- } ) ;
59- }
66+ export const getEmbeddingEndpoint = async ( embeddingModel : EmbeddingModel ) => {
67+ if ( ! embeddingModel . endpoints ) {
68+ return embeddingEndpointTransformersJS ( {
69+ type : "transformersjs" ,
70+ weight : 1 ,
71+ model : embeddingModel ,
72+ } ) ;
73+ }
6074
61- const totalWeight = sum ( m . endpoints . map ( ( e ) => e . weight ) ) ;
62-
63- let random = Math . random ( ) * totalWeight ;
64-
65- for ( const endpoint of m . endpoints ) {
66- if ( random < endpoint . weight ) {
67- const args = { ...endpoint , model : m } ;
68-
69- switch ( args . type ) {
70- case "tei" :
71- return embeddingEndpoints . tei ( args ) ;
72- case "transformersjs" :
73- return embeddingEndpoints . transformersjs ( args ) ;
74- case "openai" :
75- return embeddingEndpoints . openai ( args ) ;
76- case "hfapi" :
77- return embeddingEndpoints . hfapi ( args ) ;
78- default :
79- throw new Error ( `Unknown endpoint type: ${ args } ` ) ;
80- }
75+ const totalWeight = sum ( embeddingModel . endpoints . map ( ( e ) => e . weight ) ) ;
76+
77+ let random = Math . random ( ) * totalWeight ;
78+
79+ for ( const endpoint of embeddingModel . endpoints ) {
80+ if ( random < endpoint . weight ) {
81+ const args = { ...endpoint , model : embeddingModel } ;
82+ console . log ( args . type ) ;
83+
84+ switch ( args . type ) {
85+ case " tei" :
86+ return embeddingEndpoints . tei ( args ) ;
87+ case " transformersjs" :
88+ return embeddingEndpoints . transformersjs ( args ) ;
89+ case " openai" :
90+ return embeddingEndpoints . openai ( args ) ;
91+ case " hfapi" :
92+ return embeddingEndpoints . hfapi ( args ) ;
93+ default :
94+ throw new Error ( `Unknown endpoint type: ${ args } ` ) ;
8195 }
82-
83- random -= endpoint . weight ;
8496 }
8597
86- throw new Error ( `Failed to select embedding endpoint` ) ;
87- } ,
88- } ) ;
89-
90- export const embeddingModels = await Promise . all (
91- embeddingModelsRaw . map ( ( e ) => processEmbeddingModel ( e ) . then ( addEndpoint ) )
92- ) ;
93-
94- export const defaultEmbeddingModel = embeddingModels [ 0 ] ;
98+ random -= endpoint . weight ;
99+ }
95100
96- const validateEmbeddingModel = ( _models : EmbeddingBackendModel [ ] , key : "id" | "name" ) => {
97- return z . enum ( [ _models [ 0 ] [ key ] , ..._models . slice ( 1 ) . map ( ( m ) => m [ key ] ) ] ) ;
101+ throw new Error ( `Failed to select embedding endpoint` ) ;
98102} ;
99103
100- export const validateEmbeddingModelById = ( _models : EmbeddingBackendModel [ ] ) => {
101- return validateEmbeddingModel ( _models , "id" ) ;
102- } ;
104+ export const getDefaultEmbeddingModel = async ( ) : Promise < EmbeddingModel > => {
105+ if ( ! embeddingModels [ 0 ] ) {
106+ throw new Error ( `Failed to find default embedding endpoint` ) ;
107+ }
108+
109+ const defaultModel = await collections . embeddingModels . findOne ( {
110+ _id : embeddingModels [ 0 ] . _id ,
111+ } ) ;
103112
104- export const validateEmbeddingModelByName = ( _models : EmbeddingBackendModel [ ] ) => {
105- return validateEmbeddingModel ( _models , "name" ) ;
113+ return defaultModel ? defaultModel : embeddingModels [ 0 ] ;
106114} ;
107115
108- export type EmbeddingBackendModel = typeof defaultEmbeddingModel ;
116+ // to mimic current behaivor with creating embedding models from scratch during server start
117+ export async function pupulateEmbeddingModel ( ) {
118+ await collections . embeddingModels . deleteMany ( { } ) ;
119+ await collections . embeddingModels . insertMany ( embeddingModels ) ;
120+ }
0 commit comments