diff --git a/examples/semanticsearchservicedemo/src/test/kotlin/com/theagilemonkeys/ellmental/semanticsearchservicedemo/MainKtTest.kt b/examples/semanticsearchservicedemo/src/test/kotlin/com/theagilemonkeys/ellmental/semanticsearchservicedemo/MainKtTest.kt index 70f6cec9..379dd97b 100644 --- a/examples/semanticsearchservicedemo/src/test/kotlin/com/theagilemonkeys/ellmental/semanticsearchservicedemo/MainKtTest.kt +++ b/examples/semanticsearchservicedemo/src/test/kotlin/com/theagilemonkeys/ellmental/semanticsearchservicedemo/MainKtTest.kt @@ -26,7 +26,7 @@ class MainKtTest : StringSpec() { run { val learnUri = "http://localhost:${server.port()}/SemanticSearch/learn" val learnRequest = Request(Method.POST, learnUri) - .body("{\"texts\": [\"$textToLearn\"]}") + .body("{\"items\": [{\"text\": \"$textToLearn\"}]}") val learnResponse = client(learnRequest) learnResponse.bodyString() shouldContain "cannot be blank" learnResponse.status.code shouldBe 500 @@ -50,7 +50,7 @@ class MainKtTest : StringSpec() { // Run the learn operation and ensure it was OK val learnUri = "http://localhost:${server.port()}/SemanticSearch/learn" val learnRequest = Request(Method.POST, learnUri) - .body("{\"texts\": [\"$textToLearn\"]}") + .body("{\"items\": [{\"text\": \"$textToLearn\"}]}") val learnResponse = client(learnRequest) learnResponse.status.code shouldBe 200 diff --git a/modules/semanticsearch/src/main/kotlin/com/theagilemonkeys/ellmental/semanticsearch/SemanticSearch.kt b/modules/semanticsearch/src/main/kotlin/com/theagilemonkeys/ellmental/semanticsearch/SemanticSearch.kt index f423a079..f746e42e 100644 --- a/modules/semanticsearch/src/main/kotlin/com/theagilemonkeys/ellmental/semanticsearch/SemanticSearch.kt +++ b/modules/semanticsearch/src/main/kotlin/com/theagilemonkeys/ellmental/semanticsearch/SemanticSearch.kt @@ -20,13 +20,15 @@ class SemanticSearch { * [EmbeddingsModel] to calculate text embeddings for each piece of text. Then it uses the * [VectorStore] to persist them. * - * @param input A list of texts to be learned. + * @param input A list of LearnInputItems to be learned. Each LearnInputItem may contain the following: + * - text: The text to be learned. + * - metadata: A map of metadata to be associated with the text. */ suspend fun learn(input: LearnInput) = - input.texts.forEach { text -> - check(text.isNotBlank()) { "Text cannot be blank" } - val embedding = embed(text) - val semanticEntry = SemanticEntry(content = text, embedding = embedding) + input.items.forEach { item -> + check(item.text.isNotBlank()) { "Text cannot be blank" } + val embedding = embed(item.text) + val semanticEntry = SemanticEntry(content = item.text, embedding = embedding, metadata = item.metadata) upsert(semanticEntry) } @@ -48,7 +50,15 @@ class SemanticSearch { } @Serializable -data class LearnInput(val texts: List) +data class LearnInputItem ( + val text: String, + val metadata: Map? = emptyMap() +) + +@Serializable +data class LearnInput( + val items: List +) @Serializable data class SearchOutput(val entries: List)