From 0639a9c23d120f744cf67d84b81e71a862a5c840 Mon Sep 17 00:00:00 2001 From: Vladislav Artiukhov Date: Mon, 8 Jul 2024 22:14:57 +0200 Subject: [PATCH 01/16] feat(PromptBuilder): re-implement `PromptBuilder` Make `PromptBuilder` assert that the client inserts all mandatory keywords. --- .../generation/llm/prompt/PromptBuilder.kt | 184 ++++++++++++++---- 1 file changed, 151 insertions(+), 33 deletions(-) diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt index 036e87a0d..763b048d7 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt @@ -2,56 +2,116 @@ package org.jetbrains.research.testspark.core.generation.llm.prompt import org.jetbrains.research.testspark.core.data.ClassType import org.jetbrains.research.testspark.core.generation.llm.prompt.configuration.ClassRepresentation +import java.util.EnumMap -internal class PromptBuilder(private var prompt: String) { - private fun isPromptValid( - keyword: PromptKeyword, - prompt: String, - ): Boolean { - val keywordText = keyword.text - val isMandatory = keyword.mandatory - return (prompt.contains(keywordText) || !isMandatory) +/** + * Provides variables from the underlying keyword. + */ +private val PromptKeyword.variable: String + get() = "\$${this.text}" + + +/** + * Populates variables within the prompt template with the provided values. + * Adheres to the Builder Pattern. + */ +class PromptBuilder(private val promptTemplate: String) { + private val insertedKeywordValues: EnumMap = EnumMap(PromptKeyword::class.java) + private val templateKeywords: List + + init { + // collect all the keywords present in the prompt template + templateKeywords = mutableListOf() + + for (keyword in PromptKeyword.entries) { + if (containsPromptKeyword(keyword)) { + templateKeywords.add(keyword) + } + } + } + + private fun containsPromptKeyword(keyword: PromptKeyword): Boolean = promptTemplate.contains(keyword.variable) + + private fun validatePromptKeyword(keyword: PromptKeyword) { + if (!insertedKeywordValues.contains(keyword) && keyword.mandatory) { + throw IllegalStateException("The prompt must contain ${PromptKeyword.LANGUAGE.text}") + } + } + + fun build(): String { + var populatedPrompt = promptTemplate + + // populate the template with the inserted values + for ((keyword, value) in insertedKeywordValues.entries) { + populatedPrompt = populatedPrompt.replace(keyword.variable, value, ignoreCase = false) + } + + // validate that all mandatory keywords were provided + for (keyword in templateKeywords) { + validatePromptKeyword(keyword) + } + + return populatedPrompt + } + + private fun insert(keyword: PromptKeyword, value: String) { + insertedKeywordValues[keyword] = value } fun insertLanguage(language: String) = apply { - if (isPromptValid(PromptKeyword.LANGUAGE, prompt)) { + insert(PromptKeyword.LANGUAGE, language) + /*if (requiresPromptKeyword(PromptKeyword.LANGUAGE, promptTemplate)) { val keyword = "\$${PromptKeyword.LANGUAGE.text}" - prompt = prompt.replace(keyword, language, ignoreCase = false) + promptTemplate = promptTemplate.replace(keyword, language, ignoreCase = false) } else { throw IllegalStateException("The prompt must contain ${PromptKeyword.LANGUAGE.text}") - } + }*/ } fun insertName(classDisplayName: String) = apply { - if (isPromptValid(PromptKeyword.NAME, prompt)) { + insert(PromptKeyword.NAME, classDisplayName) + /*if (requiresPromptKeyword(PromptKeyword.NAME, promptTemplate)) { val keyword = "\$${PromptKeyword.NAME.text}" - prompt = prompt.replace(keyword, classDisplayName, ignoreCase = false) + promptTemplate = promptTemplate.replace(keyword, classDisplayName, ignoreCase = false) } else { throw IllegalStateException("The prompt must contain ${PromptKeyword.NAME.text}") - } + }*/ } fun insertTestingPlatform(testingPlatformName: String) = apply { - if (isPromptValid(PromptKeyword.TESTING_PLATFORM, prompt)) { + insert(PromptKeyword.TESTING_PLATFORM, testingPlatformName) + /*if (requiresPromptKeyword(PromptKeyword.TESTING_PLATFORM, promptTemplate)) { val keyword = "\$${PromptKeyword.TESTING_PLATFORM.text}" - prompt = prompt.replace(keyword, testingPlatformName, ignoreCase = false) + promptTemplate = promptTemplate.replace(keyword, testingPlatformName, ignoreCase = false) } else { throw IllegalStateException("The prompt must contain ${PromptKeyword.TESTING_PLATFORM.text}") - } + }*/ } fun insertMockingFramework(mockingFrameworkName: String) = apply { - if (isPromptValid(PromptKeyword.MOCKING_FRAMEWORK, prompt)) { + insert(PromptKeyword.MOCKING_FRAMEWORK, mockingFrameworkName) + /*if (requiresPromptKeyword(PromptKeyword.MOCKING_FRAMEWORK, promptTemplate)) { val keyword = "\$${PromptKeyword.MOCKING_FRAMEWORK.text}" - prompt = prompt.replace(keyword, mockingFrameworkName, ignoreCase = false) + promptTemplate = promptTemplate.replace(keyword, mockingFrameworkName, ignoreCase = false) } else { throw IllegalStateException("The prompt must contain ${PromptKeyword.MOCKING_FRAMEWORK.text}") - } + }*/ } fun insertCodeUnderTest(classFullText: String, classesToTest: List) = apply { - if (isPromptValid(PromptKeyword.CODE, prompt)) { + var fullText = "```\n${classFullText}\n```\n" + for (i in 2..classesToTest.size) { + val subClass = classesToTest[i - 2] + val superClass = classesToTest[i - 1] + + fullText += "${subClass.qualifiedName} extends ${superClass.qualifiedName}. " + + "The source code of ${superClass.qualifiedName} is:\n```\n${superClass.fullText}\n" + + "```\n" + } + + insert(PromptKeyword.CODE, fullText) + /*if (requiresPromptKeyword(PromptKeyword.CODE, promptTemplate)) { val keyword = "\$${PromptKeyword.CODE.text}" var fullText = "```\n${classFullText}\n```\n" @@ -63,16 +123,43 @@ internal class PromptBuilder(private var prompt: String) { "The source code of ${superClass.qualifiedName} is:\n```\n${superClass.fullText}\n" + "```\n" } - prompt = prompt.replace(keyword, fullText, ignoreCase = false) + promptTemplate = promptTemplate.replace(keyword, fullText, ignoreCase = false) } else { throw IllegalStateException("The prompt must contain ${PromptKeyword.CODE.text}") - } + }*/ } fun insertMethodsSignatures(interestingClasses: List) = apply { + var fullText = "" + if (interestingClasses.isNotEmpty()) { + fullText += "Here are some information about other methods and classes used by the class under test. Only use them for creating objects, not your own ideas.\n" + } + + for (interestingClass in interestingClasses) { + if (interestingClass.qualifiedName.startsWith("java") || interestingClass.qualifiedName.startsWith("kotlin")) { + continue + } + + fullText += "=== methods in ${interestingClass.qualifiedName}:\n" + + for (method in interestingClass.allMethods) { + // Skip java methods + // TODO: checks for java methods should be done by a caller to make + // this class as abstract and language agnostic as possible. + if (method.containingClassQualifiedName.startsWith("java") || + method.containingClassQualifiedName.startsWith("kotlin")) { + continue + } + + fullText += " - ${method.signature}\n" + } + } + + insert(PromptKeyword.METHODS, fullText) + /* val keyword = "\$${PromptKeyword.METHODS.text}" - if (isPromptValid(PromptKeyword.METHODS, prompt)) { + if (requiresPromptKeyword(PromptKeyword.METHODS, promptTemplate)) { var fullText = "" if (interestingClasses.isNotEmpty()) { fullText += "Here are some information about other methods and classes used by the class under test. Only use them for creating objects, not your own ideas.\n" @@ -97,16 +184,41 @@ internal class PromptBuilder(private var prompt: String) { fullText += " - ${method.signature}\n" } } - prompt = prompt.replace(keyword, fullText, ignoreCase = false) + promptTemplate = promptTemplate.replace(keyword, fullText, ignoreCase = false) } else { throw IllegalStateException("The prompt must contain ${PromptKeyword.METHODS.text}") - } + }*/ } fun insertPolymorphismRelations( polymorphismRelations: Map>, ) = apply { + var fullText = when { + polymorphismRelations.isNotEmpty() -> "Use the following polymorphic relationships of classes present in the project. Use them for instantiation when necessary. Do not mock classes if an instantiation of a sub-class is applicable" + else -> "" + } + + polymorphismRelations.forEach { entry -> + for (currentSubClass in entry.value) { + val subClassTypeName = when (currentSubClass.classType) { + ClassType.INTERFACE -> "an interface implementing" + ClassType.ABSTRACT_CLASS -> "an abstract sub-class of" + ClassType.CLASS -> "a sub-class of" + ClassType.DATA_CLASS -> "a sub data class of" + ClassType.INLINE_VALUE_CLASS -> "a sub inline value class class of" + ClassType.OBJECT -> "a sub object of" + } + fullText += "${currentSubClass.qualifiedName} is $subClassTypeName ${entry.key.qualifiedName}.\n" + } + } + + insert(PromptKeyword.POLYMORPHISM, fullText) + + /* val keyword = "\$${PromptKeyword.POLYMORPHISM.text}" + if (requiresPromptKeyword(PromptKeyword.POLYMORPHISM, promptTemplate)) { + var fullText = "" + if (isPromptValid(PromptKeyword.POLYMORPHISM, prompt)) { // If polymorphismRelations is not empty, we add an instruction to avoid mocking classes if an instantiation of a sub-class is applicable var fullText = when { @@ -126,25 +238,31 @@ internal class PromptBuilder(private var prompt: String) { fullText += "${currentSubClass.qualifiedName} is $subClassTypeName ${entry.key.qualifiedName}.\n" } } - prompt = prompt.replace(keyword, fullText, ignoreCase = false) + promptTemplate = promptTemplate.replace(keyword, fullText, ignoreCase = false) } else { throw IllegalStateException("The prompt must contain ${PromptKeyword.POLYMORPHISM.text}") - } + }*/ } fun insertTestSample(testSamplesCode: String) = apply { + var fullText = testSamplesCode + if (fullText.isNotBlank()) { + fullText = "Use this test samples:\n$fullText\n" + } + + insert(PromptKeyword.TEST_SAMPLE, fullText) + + /* val keyword = "\$${PromptKeyword.TEST_SAMPLE.text}" - if (isPromptValid(PromptKeyword.TEST_SAMPLE, prompt)) { + if (requiresPromptKeyword(PromptKeyword.TEST_SAMPLE, promptTemplate)) { var fullText = testSamplesCode if (fullText.isNotBlank()) { fullText = "Use this test samples:\n$fullText\n" } - prompt = prompt.replace(keyword, fullText, ignoreCase = false) + promptTemplate = promptTemplate.replace(keyword, fullText, ignoreCase = false) } else { throw IllegalStateException("The prompt must contain ${PromptKeyword.TEST_SAMPLE.text}") - } + }*/ } - - fun build(): String = prompt } From b28a8f10c63ee645794edec84702a6a4654b8029 Mon Sep 17 00:00:00 2001 From: Vladislav Artiukhov Date: Mon, 8 Jul 2024 23:05:08 +0200 Subject: [PATCH 02/16] Remove commented code & move `PromptKeyword`'s prop into its file --- .../generation/llm/prompt/PromptBuilder.kt | 71 +++---------------- .../generation/llm/prompt/PromptKeyword.kt | 7 ++ 2 files changed, 15 insertions(+), 63 deletions(-) diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt index 763b048d7..07ff7d51c 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt @@ -5,16 +5,9 @@ import org.jetbrains.research.testspark.core.generation.llm.prompt.configuration import java.util.EnumMap -/** - * Provides variables from the underlying keyword. - */ -private val PromptKeyword.variable: String - get() = "\$${this.text}" - - /** * Populates variables within the prompt template with the provided values. - * Adheres to the Builder Pattern. + * Adheres to the **Builder Pattern**. */ class PromptBuilder(private val promptTemplate: String) { private val insertedKeywordValues: EnumMap = EnumMap(PromptKeyword::class.java) @@ -31,14 +24,19 @@ class PromptBuilder(private val promptTemplate: String) { } } - private fun containsPromptKeyword(keyword: PromptKeyword): Boolean = promptTemplate.contains(keyword.variable) + private fun containsPromptKeyword(keyword: PromptKeyword): Boolean = + promptTemplate.contains(keyword.variable) private fun validatePromptKeyword(keyword: PromptKeyword) { if (!insertedKeywordValues.contains(keyword) && keyword.mandatory) { - throw IllegalStateException("The prompt must contain ${PromptKeyword.LANGUAGE.text}") + throw IllegalStateException("The prompt must contain ${keyword.text}") } } + /** + * Populates the `promptTemplate` with the values for keywords present in `insertedKeywordValues`. + * Validates that all mandatory fields are filled. + */ fun build(): String { var populatedPrompt = promptTemplate @@ -61,42 +59,18 @@ class PromptBuilder(private val promptTemplate: String) { fun insertLanguage(language: String) = apply { insert(PromptKeyword.LANGUAGE, language) - /*if (requiresPromptKeyword(PromptKeyword.LANGUAGE, promptTemplate)) { - val keyword = "\$${PromptKeyword.LANGUAGE.text}" - promptTemplate = promptTemplate.replace(keyword, language, ignoreCase = false) - } else { - throw IllegalStateException("The prompt must contain ${PromptKeyword.LANGUAGE.text}") - }*/ } fun insertName(classDisplayName: String) = apply { insert(PromptKeyword.NAME, classDisplayName) - /*if (requiresPromptKeyword(PromptKeyword.NAME, promptTemplate)) { - val keyword = "\$${PromptKeyword.NAME.text}" - promptTemplate = promptTemplate.replace(keyword, classDisplayName, ignoreCase = false) - } else { - throw IllegalStateException("The prompt must contain ${PromptKeyword.NAME.text}") - }*/ } fun insertTestingPlatform(testingPlatformName: String) = apply { insert(PromptKeyword.TESTING_PLATFORM, testingPlatformName) - /*if (requiresPromptKeyword(PromptKeyword.TESTING_PLATFORM, promptTemplate)) { - val keyword = "\$${PromptKeyword.TESTING_PLATFORM.text}" - promptTemplate = promptTemplate.replace(keyword, testingPlatformName, ignoreCase = false) - } else { - throw IllegalStateException("The prompt must contain ${PromptKeyword.TESTING_PLATFORM.text}") - }*/ } fun insertMockingFramework(mockingFrameworkName: String) = apply { insert(PromptKeyword.MOCKING_FRAMEWORK, mockingFrameworkName) - /*if (requiresPromptKeyword(PromptKeyword.MOCKING_FRAMEWORK, promptTemplate)) { - val keyword = "\$${PromptKeyword.MOCKING_FRAMEWORK.text}" - promptTemplate = promptTemplate.replace(keyword, mockingFrameworkName, ignoreCase = false) - } else { - throw IllegalStateException("The prompt must contain ${PromptKeyword.MOCKING_FRAMEWORK.text}") - }*/ } fun insertCodeUnderTest(classFullText: String, classesToTest: List) = apply { @@ -111,22 +85,6 @@ class PromptBuilder(private val promptTemplate: String) { } insert(PromptKeyword.CODE, fullText) - /*if (requiresPromptKeyword(PromptKeyword.CODE, promptTemplate)) { - val keyword = "\$${PromptKeyword.CODE.text}" - var fullText = "```\n${classFullText}\n```\n" - - for (i in 2..classesToTest.size) { - val subClass = classesToTest[i - 2] - val superClass = classesToTest[i - 1] - - fullText += "${subClass.qualifiedName} extends ${superClass.qualifiedName}. " + - "The source code of ${superClass.qualifiedName} is:\n```\n${superClass.fullText}\n" + - "```\n" - } - promptTemplate = promptTemplate.replace(keyword, fullText, ignoreCase = false) - } else { - throw IllegalStateException("The prompt must contain ${PromptKeyword.CODE.text}") - }*/ } fun insertMethodsSignatures(interestingClasses: List) = apply { @@ -251,18 +209,5 @@ class PromptBuilder(private val promptTemplate: String) { } insert(PromptKeyword.TEST_SAMPLE, fullText) - - /* - val keyword = "\$${PromptKeyword.TEST_SAMPLE.text}" - - if (requiresPromptKeyword(PromptKeyword.TEST_SAMPLE, promptTemplate)) { - var fullText = testSamplesCode - if (fullText.isNotBlank()) { - fullText = "Use this test samples:\n$fullText\n" - } - promptTemplate = promptTemplate.replace(keyword, fullText, ignoreCase = false) - } else { - throw IllegalStateException("The prompt must contain ${PromptKeyword.TEST_SAMPLE.text}") - }*/ } } diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptKeyword.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptKeyword.kt index f1d46eece..209b25d0a 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptKeyword.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptKeyword.kt @@ -29,4 +29,11 @@ enum class PromptKeyword(val text: String, val description: String, val mandator val endOffset = startOffset + textToHighlight.length return Pair(startOffset, endOffset) } + + // TODO: replace all "\$$" with use of this `PromptKeyword.variable` + /** + * Provides variables from the underlying keyword. + */ + val variable: String + get() = "\$${this.text}" } From 5c5a31b7535ca75228dbcd98804bef78ff10cde5 Mon Sep 17 00:00:00 2001 From: Vladislav Artiukhov Date: Mon, 8 Jul 2024 23:09:33 +0200 Subject: [PATCH 03/16] Write tests for `PromptBuilder` --- .../llm/prompt/PromptBuilderTest.kt | 165 ++++++++++++++++++ 1 file changed, 165 insertions(+) create mode 100644 src/test/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilderTest.kt diff --git a/src/test/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilderTest.kt b/src/test/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilderTest.kt new file mode 100644 index 000000000..306d3d369 --- /dev/null +++ b/src/test/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilderTest.kt @@ -0,0 +1,165 @@ +package org.jetbrains.research.testspark.core.generation.llm.prompt + +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test + +import org.junit.jupiter.api.Assertions.* +import org.junit.jupiter.api.assertThrows +import kotlin.test.assertContains + +class PromptBuilderTest { + @BeforeEach + fun setUp() { + + } + + @Test + fun insertLanguage() { + val (keyword, value) = PromptKeyword.LANGUAGE to "Java" + + val prompt = PromptBuilder("Language: ${keyword.variable}") + .insertLanguage(value) + .build() + + assertEquals("Language: Java", prompt) + } + + @Test + fun insertName() { + val (keyword, value) = PromptKeyword.NAME to "MyClass" + + val prompt = PromptBuilder("Name: ${keyword.variable}") + .insertName(value) + .build() + + assertEquals("Name: MyClass", prompt) + } + + @Test + fun insertTestingPlatform() { + val (keyword, value) = PromptKeyword.TESTING_PLATFORM to "JUnit4" + + val prompt = PromptBuilder("Testing platform: ${keyword.variable}") + .insertTestingPlatform(value) + .build() + + assertEquals("Testing platform: JUnit4", prompt) + } + + @Test + fun insertMockingFramework() { + val (keyword, value) = PromptKeyword.MOCKING_FRAMEWORK to "Mockito" + + val prompt = PromptBuilder("Mocking framework: ${keyword.variable}") + .insertMockingFramework(value) + .build() + + assertEquals("Mocking framework: Mockito", prompt) + } + + @Test + fun insertCodeUnderTest() { + val keyword = PromptKeyword.CODE + val code = """ + class MyClass() { + fun f() { println("Hello world!") } + } + """.trimIndent() + + val prompt = PromptBuilder("Code:\n${keyword.variable}") + .insertCodeUnderTest(code, emptyList()) + .build() + + assertContains(prompt, code, message = "Code under test should be inserted into prompt template") + } + + @Test + fun lastInsertionPrevails() { + val keyword = PromptKeyword.LANGUAGE + + val prompt = PromptBuilder("Language: ${keyword.variable}") + .insertLanguage("Java") + .insertLanguage("Kotlin") + .build() + + assertEquals("Language: Kotlin", prompt) + } + + @Test + fun throwsOnMissingMandatoryKeyword() { + val keywords: List = PromptKeyword.entries + + for (keyword in keywords) { + val promptTemplate = "My variable is: ${keyword.variable}" + + if (keyword.mandatory) { + val exception = assertThrows { PromptBuilder(promptTemplate).build() } + assertEquals("The prompt must contain ${keyword.text}", exception.message) + } + else { + assertDoesNotThrow { PromptBuilder(promptTemplate).build() } + } + } + } + + @Test + fun testPopulateMultipleVariableEntries() { + val keyword = PromptKeyword.LANGUAGE + val template = """ + Language1: '${keyword.variable}' + Language2: \\${keyword.variable}\\ + Language3: `${keyword.variable}` + """.trimIndent() + + val prompt = PromptBuilder(template) + .insertLanguage("Java") + .build() + + assertEquals(""" + Language1: 'Java' + Language2: \\Java\\ + Language3: `Java` + """.trimIndent(), prompt) + } + + @Test + fun insertMultipleVariables() { + val template = """ + language: ${PromptKeyword.LANGUAGE.variable} + name: ${PromptKeyword.NAME.variable} + testing platform: ${PromptKeyword.TESTING_PLATFORM.variable} + mocking framework: ${PromptKeyword.MOCKING_FRAMEWORK.variable} + """.trimIndent() + + val prompt = PromptBuilder(template) + .insertLanguage("Java") + .insertName("org.pkg.MyClass") + .insertTestingPlatform("JUnit4") + .insertMockingFramework("Mockito") + .build() + + val expected = """ + language: Java + name: org.pkg.MyClass + testing platform: JUnit4 + mocking framework: Mockito + """.trimIndent() + + assertEquals(expected, prompt) + } + + @Test + fun insertMethodsSignatures() { + // TODO: finish + } + + @Test + fun insertPolymorphismRelations() { + // TODO: finish + } + + @Test + fun insertTestSample() { + // TODO: finish + } +} \ No newline at end of file From e67f5371200a62069c3a40077384807a7725b989 Mon Sep 17 00:00:00 2001 From: Vladislav Artiukhov Date: Mon, 8 Jul 2024 23:23:29 +0200 Subject: [PATCH 04/16] Check the inserted keyword present in the prompt template + test it --- .../core/generation/llm/prompt/PromptBuilder.kt | 3 +++ .../core/generation/llm/prompt/PromptBuilderTest.kt | 9 +++++++++ 2 files changed, 12 insertions(+) diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt index 07ff7d51c..70d48d284 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt @@ -54,6 +54,9 @@ class PromptBuilder(private val promptTemplate: String) { } private fun insert(keyword: PromptKeyword, value: String) { + if (!templateKeywords.contains(keyword)) { + throw IllegalArgumentException("Prompt template does not contain ${keyword.text}") + } insertedKeywordValues[keyword] = value } diff --git a/src/test/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilderTest.kt b/src/test/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilderTest.kt index 306d3d369..b821a8321 100644 --- a/src/test/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilderTest.kt +++ b/src/test/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilderTest.kt @@ -148,6 +148,15 @@ class PromptBuilderTest { assertEquals(expected, prompt) } + @Test + fun throwsOnNonExistentKeywordInsertion() { + val template = "Language: ${PromptKeyword.LANGUAGE.variable}" + val exception = assertThrows { + PromptBuilder(template).insertName("Name") + } + assertEquals("Prompt template does not contain ${PromptKeyword.NAME.text}", exception.message) + } + @Test fun insertMethodsSignatures() { // TODO: finish From a56f3ec277d991b148eccde0733b96e0b2350f5f Mon Sep 17 00:00:00 2001 From: Vladislav Artiukhov Date: Mon, 8 Jul 2024 23:27:47 +0200 Subject: [PATCH 05/16] feat(publish): Publish the updated version of `PromptBuilder` --- core/build.gradle.kts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/build.gradle.kts b/core/build.gradle.kts index 8b5bc0eb2..799e3638c 100644 --- a/core/build.gradle.kts +++ b/core/build.gradle.kts @@ -42,7 +42,7 @@ publishing { create("maven") { groupId = group as String artifactId = "testspark-core" - version = "2.0.5" + version = "2.0.6" from(components["java"]) artifact(tasks["sourcesJar"]) From 7dcf8427885e36b9e22a104972ba13c74c836ff6 Mon Sep 17 00:00:00 2001 From: Vladislav Artiukhov Date: Thu, 11 Jul 2024 09:48:22 +0200 Subject: [PATCH 06/16] refactor: Update check condition in `PromptBuilder.insert` method --- .../testspark/core/generation/llm/prompt/PromptBuilder.kt | 5 +++-- .../core/generation/llm/prompt/PromptBuilderTest.kt | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt index 70d48d284..369db9a38 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt @@ -54,8 +54,8 @@ class PromptBuilder(private val promptTemplate: String) { } private fun insert(keyword: PromptKeyword, value: String) { - if (!templateKeywords.contains(keyword)) { - throw IllegalArgumentException("Prompt template does not contain ${keyword.text}") + if (!templateKeywords.contains(keyword) && keyword.mandatory) { + throw IllegalArgumentException("Prompt template does not contain mandatory ${keyword.text}") } insertedKeywordValues[keyword] = value } @@ -76,6 +76,7 @@ class PromptBuilder(private val promptTemplate: String) { insert(PromptKeyword.MOCKING_FRAMEWORK, mockingFrameworkName) } + // TODO: rename variables (not class but code construct) fun insertCodeUnderTest(classFullText: String, classesToTest: List) = apply { var fullText = "```\n${classFullText}\n```\n" for (i in 2..classesToTest.size) { diff --git a/src/test/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilderTest.kt b/src/test/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilderTest.kt index b821a8321..702de8ee7 100644 --- a/src/test/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilderTest.kt +++ b/src/test/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilderTest.kt @@ -154,7 +154,7 @@ class PromptBuilderTest { val exception = assertThrows { PromptBuilder(template).insertName("Name") } - assertEquals("Prompt template does not contain ${PromptKeyword.NAME.text}", exception.message) + assertEquals("Prompt template does not contain mandatory ${PromptKeyword.NAME.text}", exception.message) } @Test From 55c93dcdc6217f791dd426b992c09604ea291e9d Mon Sep 17 00:00:00 2001 From: Vladislav Artiukhov Date: Thu, 11 Jul 2024 09:50:01 +0200 Subject: [PATCH 07/16] publish: Publish core module with version 2.0.7 This version includes all the changes of `PromptBuilder` on commit 1284ab4b. --- core/build.gradle.kts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/build.gradle.kts b/core/build.gradle.kts index 799e3638c..aa83a840b 100644 --- a/core/build.gradle.kts +++ b/core/build.gradle.kts @@ -42,7 +42,7 @@ publishing { create("maven") { groupId = group as String artifactId = "testspark-core" - version = "2.0.6" + version = "2.0.7" from(components["java"]) artifact(tasks["sourcesJar"]) From 15bd46728c3789354cb330b47254f6fcb5b6c997 Mon Sep 17 00:00:00 2001 From: Vladislav Artiukhov Date: Wed, 25 Sep 2024 15:23:24 +0200 Subject: [PATCH 08/16] refactor: remove commented and redundant code in `PromptBuilder` --- .../generation/llm/prompt/PromptBuilder.kt | 88 ++----------------- .../llm/prompt/PromptBuilderTest.kt | 2 +- 2 files changed, 9 insertions(+), 81 deletions(-) diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt index 369db9a38..bec33f4be 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt @@ -11,28 +11,15 @@ import java.util.EnumMap */ class PromptBuilder(private val promptTemplate: String) { private val insertedKeywordValues: EnumMap = EnumMap(PromptKeyword::class.java) - private val templateKeywords: List - - init { - // collect all the keywords present in the prompt template - templateKeywords = mutableListOf() - + // collect all the keywords present in the prompt template + private val templateKeywords: List = buildList { for (keyword in PromptKeyword.entries) { - if (containsPromptKeyword(keyword)) { - templateKeywords.add(keyword) + if (promptTemplate.contains(keyword.variable)) { + add(keyword) } } } - private fun containsPromptKeyword(keyword: PromptKeyword): Boolean = - promptTemplate.contains(keyword.variable) - - private fun validatePromptKeyword(keyword: PromptKeyword) { - if (!insertedKeywordValues.contains(keyword) && keyword.mandatory) { - throw IllegalStateException("The prompt must contain ${keyword.text}") - } - } - /** * Populates the `promptTemplate` with the values for keywords present in `insertedKeywordValues`. * Validates that all mandatory fields are filled. @@ -47,7 +34,9 @@ class PromptBuilder(private val promptTemplate: String) { // validate that all mandatory keywords were provided for (keyword in templateKeywords) { - validatePromptKeyword(keyword) + if (!insertedKeywordValues.contains(keyword) && keyword.mandatory) { + throw IllegalStateException("The prompt must contain ${keyword.text} keyword") + } } return populatedPrompt @@ -118,38 +107,6 @@ class PromptBuilder(private val promptTemplate: String) { } insert(PromptKeyword.METHODS, fullText) - /* - val keyword = "\$${PromptKeyword.METHODS.text}" - - if (requiresPromptKeyword(PromptKeyword.METHODS, promptTemplate)) { - var fullText = "" - if (interestingClasses.isNotEmpty()) { - fullText += "Here are some information about other methods and classes used by the class under test. Only use them for creating objects, not your own ideas.\n" - } - for (interestingClass in interestingClasses) { - if (interestingClass.qualifiedName.startsWith("java") || interestingClass.qualifiedName.startsWith("kotlin")) { - continue - } - - fullText += "=== methods in ${interestingClass.qualifiedName}:\n" - - for (method in interestingClass.allMethods) { - // Skip java methods - // TODO: checks for java methods should be done by a caller to make - // this class as abstract and language agnostic as possible. - if (method.containingClassQualifiedName.startsWith("java") || - method.containingClassQualifiedName.startsWith("kotlin") - ) { - continue - } - - fullText += " - ${method.signature}\n" - } - } - promptTemplate = promptTemplate.replace(keyword, fullText, ignoreCase = false) - } else { - throw IllegalStateException("The prompt must contain ${PromptKeyword.METHODS.text}") - }*/ } fun insertPolymorphismRelations( @@ -160,7 +117,7 @@ class PromptBuilder(private val promptTemplate: String) { else -> "" } - polymorphismRelations.forEach { entry -> + for (entry in polymorphismRelations) { for (currentSubClass in entry.value) { val subClassTypeName = when (currentSubClass.classType) { ClassType.INTERFACE -> "an interface implementing" @@ -175,35 +132,6 @@ class PromptBuilder(private val promptTemplate: String) { } insert(PromptKeyword.POLYMORPHISM, fullText) - - /* - val keyword = "\$${PromptKeyword.POLYMORPHISM.text}" - if (requiresPromptKeyword(PromptKeyword.POLYMORPHISM, promptTemplate)) { - var fullText = "" - - if (isPromptValid(PromptKeyword.POLYMORPHISM, prompt)) { - // If polymorphismRelations is not empty, we add an instruction to avoid mocking classes if an instantiation of a sub-class is applicable - var fullText = when { - polymorphismRelations.isNotEmpty() -> "Use the following polymorphic relationships of classes present in the project. Use them for instantiation when necessary. Do not mock classes if an instantiation of a sub-class is applicable" - else -> "" - } - polymorphismRelations.forEach { entry -> - for (currentSubClass in entry.value) { - val subClassTypeName = when (currentSubClass.classType) { - ClassType.INTERFACE -> "an interface implementing" - ClassType.ABSTRACT_CLASS -> "an abstract sub-class of" - ClassType.CLASS -> "a sub-class of" - ClassType.DATA_CLASS -> "a sub data class of" - ClassType.INLINE_VALUE_CLASS -> "a sub inline value class class of" - ClassType.OBJECT -> "a sub object of" - } - fullText += "${currentSubClass.qualifiedName} is $subClassTypeName ${entry.key.qualifiedName}.\n" - } - } - promptTemplate = promptTemplate.replace(keyword, fullText, ignoreCase = false) - } else { - throw IllegalStateException("The prompt must contain ${PromptKeyword.POLYMORPHISM.text}") - }*/ } fun insertTestSample(testSamplesCode: String) = apply { diff --git a/src/test/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilderTest.kt b/src/test/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilderTest.kt index 702de8ee7..49728d2c0 100644 --- a/src/test/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilderTest.kt +++ b/src/test/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilderTest.kt @@ -94,7 +94,7 @@ class PromptBuilderTest { if (keyword.mandatory) { val exception = assertThrows { PromptBuilder(promptTemplate).build() } - assertEquals("The prompt must contain ${keyword.text}", exception.message) + assertEquals("The prompt must contain ${keyword.text} keyword", exception.message) } else { assertDoesNotThrow { PromptBuilder(promptTemplate).build() } From 1d2f30c899e9d9c56f3d0dc1bd043b8b55b91f2a Mon Sep 17 00:00:00 2001 From: Vladislav Artiukhov Date: Wed, 25 Sep 2024 15:40:20 +0200 Subject: [PATCH 09/16] refactor: replace manual keyword variable creation with a prop --- .../testspark/core/generation/llm/prompt/PromptKeyword.kt | 8 +++++--- .../research/testspark/helpers/PromptParserHelper.kt | 3 +-- .../testspark/settings/llm/LLMSettingsComponent.kt | 6 +++--- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptKeyword.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptKeyword.kt index 209b25d0a..5779d36f6 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptKeyword.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptKeyword.kt @@ -20,7 +20,7 @@ enum class PromptKeyword(val text: String, val description: String, val mandator ; fun getOffsets(prompt: String): Pair? { - val textToHighlight = "\$$text" + val textToHighlight = variable if (!prompt.contains(textToHighlight)) { return null } @@ -30,9 +30,11 @@ enum class PromptKeyword(val text: String, val description: String, val mandator return Pair(startOffset, endOffset) } - // TODO: replace all "\$$" with use of this `PromptKeyword.variable` /** - * Provides variables from the underlying keyword. + * Returns a keyword's text (i.e., its name) with a `$` attached at the start. + * + * Inside a prompt template every keyword is used as `$KEYWORD_NAME`. + * Therefore, this property encapsulates keyword's prompt representation. */ val variable: String get() = "\$${this.text}" diff --git a/src/main/kotlin/org/jetbrains/research/testspark/helpers/PromptParserHelper.kt b/src/main/kotlin/org/jetbrains/research/testspark/helpers/PromptParserHelper.kt index 11d6cd872..91c061938 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/helpers/PromptParserHelper.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/helpers/PromptParserHelper.kt @@ -60,8 +60,7 @@ object PromptParserHelper { fun isPromptValid(prompt: String): Boolean { PromptKeyword.entries.forEach { if (it.mandatory) { - val text = "\$${it.text}" - if (!prompt.contains(text)) { + if (!prompt.contains(it.variable)) { return false } } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/settings/llm/LLMSettingsComponent.kt b/src/main/kotlin/org/jetbrains/research/testspark/settings/llm/LLMSettingsComponent.kt index d27fa7936..61d28476f 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/settings/llm/LLMSettingsComponent.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/settings/llm/LLMSettingsComponent.kt @@ -328,8 +328,8 @@ class LLMSettingsComponent(private val project: Project) : SettingsComponent { private fun createButtonPanel(keyword: PromptKeyword, panel: JPanel): JPanel { val buttonPanel = JPanel(FlowLayout(FlowLayout.LEFT)) val editorTextField = panel.getComponent(1) as EditorTextField - val button = JButton("\$${keyword.text}") - button.setForeground(JBColor.ORANGE) + val button = JButton(keyword.variable) + button.foreground = JBColor.ORANGE button.font = Font("Monochrome", Font.BOLD, 12) // add actionListener for button @@ -340,7 +340,7 @@ class LLMSettingsComponent(private val project: Project) : SettingsComponent { val offset = e.caretModel.offset val document = editorTextField.document WriteCommandAction.runWriteCommandAction(e.project) { - document.insertString(offset, "\$${keyword.text}") + document.insertString(offset, keyword.variable) } } } From a9e53c375c5c5abc4e7782ed8714a0eb1ab24dad Mon Sep 17 00:00:00 2001 From: Vladislav Artiukhov Date: Wed, 25 Sep 2024 15:46:40 +0200 Subject: [PATCH 10/16] refactor: remove `text` prop of `PromptKeyword` since it duplicates `name` --- .../generation/llm/prompt/PromptBuilder.kt | 4 ++-- .../generation/llm/prompt/PromptKeyword.kt | 20 +++++++++---------- .../llm/prompt/PromptBuilderTest.kt | 4 ++-- 3 files changed, 13 insertions(+), 15 deletions(-) diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt index bec33f4be..8f5d00002 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt @@ -35,7 +35,7 @@ class PromptBuilder(private val promptTemplate: String) { // validate that all mandatory keywords were provided for (keyword in templateKeywords) { if (!insertedKeywordValues.contains(keyword) && keyword.mandatory) { - throw IllegalStateException("The prompt must contain ${keyword.text} keyword") + throw IllegalStateException("The prompt must contain ${keyword.name} keyword") } } @@ -44,7 +44,7 @@ class PromptBuilder(private val promptTemplate: String) { private fun insert(keyword: PromptKeyword, value: String) { if (!templateKeywords.contains(keyword) && keyword.mandatory) { - throw IllegalArgumentException("Prompt template does not contain mandatory ${keyword.text}") + throw IllegalArgumentException("Prompt template does not contain mandatory ${keyword.name}") } insertedKeywordValues[keyword] = value } diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptKeyword.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptKeyword.kt index 5779d36f6..98a968d23 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptKeyword.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptKeyword.kt @@ -1,22 +1,20 @@ package org.jetbrains.research.testspark.core.generation.llm.prompt -enum class PromptKeyword(val text: String, val description: String, val mandatory: Boolean) { - NAME("NAME", "The name of the code under test (Class name, method name, line number)", true), - CODE("CODE", "The code under test (Class, method, or line)", true), - LANGUAGE("LANGUAGE", "Programming language of the project under test (only Java supported at this point)", true), +enum class PromptKeyword(val description: String, val mandatory: Boolean) { + NAME("The name of the code under test (Class name, method name, line number)", true), + CODE("The code under test (Class, method, or line)", true), + LANGUAGE("Programming language of the project under test (only Java supported at this point)", true), TESTING_PLATFORM( - "TESTING_PLATFORM", "Testing platform used in the project (Only JUnit 4 is supported at this point)", true, ), MOCKING_FRAMEWORK( - "MOCKING_FRAMEWORK", "Mock framework that can be used in generated test (Only Mockito is supported at this point)", false, ), - METHODS("METHODS", "Signature of methods used in the code under tests", false), - POLYMORPHISM("POLYMORPHISM", "Polymorphism relations between classes involved in the code under test", false), - TEST_SAMPLE("TEST_SAMPLE", "Test samples for LLM for test generation", false), + METHODS("Signature of methods used in the code under tests", false), + POLYMORPHISM("Polymorphism relations between classes involved in the code under test", false), + TEST_SAMPLE("Test samples for LLM for test generation", false), ; fun getOffsets(prompt: String): Pair? { @@ -34,8 +32,8 @@ enum class PromptKeyword(val text: String, val description: String, val mandator * Returns a keyword's text (i.e., its name) with a `$` attached at the start. * * Inside a prompt template every keyword is used as `$KEYWORD_NAME`. - * Therefore, this property encapsulates keyword's prompt representation. + * Therefore, this property encapsulates the keyword's representation in a prompt. */ val variable: String - get() = "\$${this.text}" + get() = "\$${this.name}" } diff --git a/src/test/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilderTest.kt b/src/test/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilderTest.kt index 49728d2c0..374dc9c59 100644 --- a/src/test/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilderTest.kt +++ b/src/test/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilderTest.kt @@ -94,7 +94,7 @@ class PromptBuilderTest { if (keyword.mandatory) { val exception = assertThrows { PromptBuilder(promptTemplate).build() } - assertEquals("The prompt must contain ${keyword.text} keyword", exception.message) + assertEquals("The prompt must contain ${keyword.name} keyword", exception.message) } else { assertDoesNotThrow { PromptBuilder(promptTemplate).build() } @@ -154,7 +154,7 @@ class PromptBuilderTest { val exception = assertThrows { PromptBuilder(template).insertName("Name") } - assertEquals("Prompt template does not contain mandatory ${PromptKeyword.NAME.text}", exception.message) + assertEquals("Prompt template does not contain mandatory ${PromptKeyword.NAME.name}", exception.message) } @Test From 125c8a1b8c2f9af6ebbab088c7dc6be418fe6788 Mon Sep 17 00:00:00 2001 From: Vladislav Artiukhov Date: Wed, 25 Sep 2024 15:56:47 +0200 Subject: [PATCH 11/16] feat: generate docs via AI for `PromptBuilder` --- .../generation/llm/prompt/PromptBuilder.kt | 40 ++++++++++++++----- 1 file changed, 31 insertions(+), 9 deletions(-) diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt index 8f5d00002..e4c11717d 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt @@ -5,9 +5,12 @@ import org.jetbrains.research.testspark.core.generation.llm.prompt.configuration import java.util.EnumMap + /** - * Populates variables within the prompt template with the provided values. - * Adheres to the **Builder Pattern**. + * Builds prompts by populating a template with keyword values, + * and validates that all mandatory keywords are provided. + * + * @property promptTemplate The template string for the prompt. */ class PromptBuilder(private val promptTemplate: String) { private val insertedKeywordValues: EnumMap = EnumMap(PromptKeyword::class.java) @@ -21,8 +24,11 @@ class PromptBuilder(private val promptTemplate: String) { } /** - * Populates the `promptTemplate` with the values for keywords present in `insertedKeywordValues`. - * Validates that all mandatory fields are filled. + * Builds the prompt by populating the template with the inserted values + * and validating that all mandatory keywords were provided. + * + * @return The built prompt. + * @throws IllegalStateException if a mandatory keyword is not present in the template. */ fun build(): String { var populatedPrompt = promptTemplate @@ -42,6 +48,14 @@ class PromptBuilder(private val promptTemplate: String) { return populatedPrompt } + /** + * Inserts a keyword and its corresponding value into the prompt template. + * If the keyword is marked as mandatory and not present in the template, an IllegalArgumentException is thrown. + * + * @param keyword The keyword to be inserted. + * @param value The value corresponding to the keyword. + * @throws IllegalArgumentException if a mandatory keyword is not present in the template. + */ private fun insert(keyword: PromptKeyword, value: String) { if (!templateKeywords.contains(keyword) && keyword.mandatory) { throw IllegalArgumentException("Prompt template does not contain mandatory ${keyword.name}") @@ -65,9 +79,16 @@ class PromptBuilder(private val promptTemplate: String) { insert(PromptKeyword.MOCKING_FRAMEWORK, mockingFrameworkName) } - // TODO: rename variables (not class but code construct) - fun insertCodeUnderTest(classFullText: String, classesToTest: List) = apply { - var fullText = "```\n${classFullText}\n```\n" + /** + * Inserts the code under test and its related superclass code into the prompt template. + * + * @param codeFullText The full text of the code under test. + * @param classesToTest The list of ClassRepresentation objects representing the classes involved in the code under test. + * @return The modified prompt builder. + */ + fun insertCodeUnderTest(codeFullText: String, classesToTest: List) = apply { + var fullText = "```\n${codeFullText}\n```\n" + for (i in 2..classesToTest.size) { val subClass = classesToTest[i - 2] val superClass = classesToTest[i - 1] @@ -82,19 +103,20 @@ class PromptBuilder(private val promptTemplate: String) { fun insertMethodsSignatures(interestingClasses: List) = apply { var fullText = "" + if (interestingClasses.isNotEmpty()) { fullText += "Here are some information about other methods and classes used by the class under test. Only use them for creating objects, not your own ideas.\n" } for (interestingClass in interestingClasses) { - if (interestingClass.qualifiedName.startsWith("java") || interestingClass.qualifiedName.startsWith("kotlin")) { + if (interestingClass.qualifiedName.startsWith("java") || + interestingClass.qualifiedName.startsWith("kotlin")) { continue } fullText += "=== methods in ${interestingClass.qualifiedName}:\n" for (method in interestingClass.allMethods) { - // Skip java methods // TODO: checks for java methods should be done by a caller to make // this class as abstract and language agnostic as possible. if (method.containingClassQualifiedName.startsWith("java") || From 4a7579166d9fa5de879d3b557d48815dd64ad0ea Mon Sep 17 00:00:00 2001 From: Vladislav Artiukhov Date: Wed, 25 Sep 2024 16:55:28 +0200 Subject: [PATCH 12/16] refactor: add newline into the prompt in `PromptBuilder` --- .../testspark/core/generation/llm/prompt/PromptBuilder.kt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt index e4c11717d..cbcd774a3 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt @@ -135,7 +135,7 @@ class PromptBuilder(private val promptTemplate: String) { polymorphismRelations: Map>, ) = apply { var fullText = when { - polymorphismRelations.isNotEmpty() -> "Use the following polymorphic relationships of classes present in the project. Use them for instantiation when necessary. Do not mock classes if an instantiation of a sub-class is applicable" + polymorphismRelations.isNotEmpty() -> "Use the following polymorphic relationships of classes present in the project. Use them for instantiation when necessary. Do not mock classes if an instantiation of a sub-class is applicable.\n\n" else -> "" } From 3756c8a606a49d3b5e47b3fcd876e2fce230e911 Mon Sep 17 00:00:00 2001 From: Vladislav Artiukhov Date: Wed, 25 Sep 2024 16:56:32 +0200 Subject: [PATCH 13/16] feat: cover method info and polymorphism insertions with tests --- .../llm/prompt/PromptBuilderTest.kt | 109 ++++++++++++++++-- 1 file changed, 101 insertions(+), 8 deletions(-) diff --git a/src/test/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilderTest.kt b/src/test/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilderTest.kt index 374dc9c59..9c5b77c0f 100644 --- a/src/test/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilderTest.kt +++ b/src/test/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilderTest.kt @@ -1,5 +1,8 @@ package org.jetbrains.research.testspark.core.generation.llm.prompt +import org.jetbrains.research.testspark.core.data.ClassType +import org.jetbrains.research.testspark.core.generation.llm.prompt.configuration.ClassRepresentation +import org.jetbrains.research.testspark.core.generation.llm.prompt.configuration.MethodRepresentation import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test @@ -123,7 +126,7 @@ class PromptBuilderTest { } @Test - fun insertMultipleVariables() { + fun testInsertMultipleVariables() { val template = """ language: ${PromptKeyword.LANGUAGE.variable} name: ${PromptKeyword.NAME.variable} @@ -149,7 +152,7 @@ class PromptBuilderTest { } @Test - fun throwsOnNonExistentKeywordInsertion() { + fun testThrowsOnNonExistentKeywordInsertion() { val template = "Language: ${PromptKeyword.LANGUAGE.variable}" val exception = assertThrows { PromptBuilder(template).insertName("Name") @@ -158,17 +161,107 @@ class PromptBuilderTest { } @Test - fun insertMethodsSignatures() { - // TODO: finish + fun testInsertMethodsSignatures() { + val keyword = PromptKeyword.METHODS + + val method1 = MethodRepresentation( + signature = "method1():Boolean", + name = "method1", + text = "fun method1(): Boolean { return true }", + containingClassQualifiedName = "MyClass" + ) + val method2 = MethodRepresentation( + signature = "method2():Boolean", + name = "method2", + text = "fun method2(): Boolean { return false }", + containingClassQualifiedName = "MyClass" + ) + val myClass = ClassRepresentation( + qualifiedName = "MyClass", + fullText = """ + class MyClass { + fun method1(): Boolean { return true } + fun method2(): Boolean { return false } + } + """.trimIndent(), + allMethods = listOf(method1, method2), + classType = ClassType.CLASS, + ) + + val interestingClasses = listOf(myClass) + + val expectedMethodsText = """ + Methods: + Here are some information about other methods and classes used by the class under test. Only use them for creating objects, not your own ideas. + === methods in MyClass: + - method1():Boolean + - method2():Boolean +""".trimIndent() + + val builder = PromptBuilder("Methods:\n${keyword.variable}") + builder.insertMethodsSignatures(interestingClasses) + + val prompt = builder.build() + + assertEquals( + expectedMethodsText + "\n", + prompt, + "Methods' signatures should be inserted into prompt template correctly" + ) } @Test - fun insertPolymorphismRelations() { - // TODO: finish + fun testInsertPolymorphismRelations() { + val myInterface = ClassRepresentation( + qualifiedName = "MyInterface", + fullText = """ + class MyInterface { + } + """.trimIndent(), + allMethods = emptyList(), + classType = ClassType.INTERFACE, + ) + val mySubClass = ClassRepresentation( + qualifiedName = "MySubClass", + fullText = """ + class MySubClass : MyInterface { + } + """.trimIndent(), + allMethods = emptyList(), + classType = ClassType.CLASS, + ) + val polymorphicRelations = mapOf(myInterface to listOf(mySubClass)) + + val prompt = PromptBuilder(PromptKeyword.POLYMORPHISM.variable) + .insertPolymorphismRelations(polymorphicRelations) + .build() + + println("'$prompt'") + + val expected = """ + Use the following polymorphic relationships of classes present in the project. Use them for instantiation when necessary. Do not mock classes if an instantiation of a sub-class is applicable. + + MySubClass is a sub-class of MyInterface. + """.trimIndent() + + assertEquals(expected + "\n", prompt) } @Test - fun insertTestSample() { - // TODO: finish + fun testInsertTestSample() { + val testSamplesCode = """ + @Test + fun testMethod() { + assertEquals(4, 2+2) + } + """.trimIndent() + + val prompt = PromptBuilder(PromptKeyword.TEST_SAMPLE.variable) + .insertTestSample(testSamplesCode) + .build() + + val expected = "Use this test samples:\n" + testSamplesCode + "\n" + + assertEquals(expected, prompt) } } \ No newline at end of file From 462c6ddf1195c434d8bafb85b37212f38b3f7149 Mon Sep 17 00:00:00 2001 From: Vladislav Artiukhov Date: Wed, 25 Sep 2024 17:16:05 +0200 Subject: [PATCH 14/16] fix: apply klint --- .../generation/llm/prompt/PromptBuilder.kt | 13 ++++--- .../llm/prompt/PromptBuilderTest.kt | 38 +++++++++---------- 2 files changed, 24 insertions(+), 27 deletions(-) diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt index cbcd774a3..065214b15 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt @@ -4,8 +4,6 @@ import org.jetbrains.research.testspark.core.data.ClassType import org.jetbrains.research.testspark.core.generation.llm.prompt.configuration.ClassRepresentation import java.util.EnumMap - - /** * Builds prompts by populating a template with keyword values, * and validates that all mandatory keywords are provided. @@ -14,6 +12,7 @@ import java.util.EnumMap */ class PromptBuilder(private val promptTemplate: String) { private val insertedKeywordValues: EnumMap = EnumMap(PromptKeyword::class.java) + // collect all the keywords present in the prompt template private val templateKeywords: List = buildList { for (keyword in PromptKeyword.entries) { @@ -94,8 +93,8 @@ class PromptBuilder(private val promptTemplate: String) { val superClass = classesToTest[i - 1] fullText += "${subClass.qualifiedName} extends ${superClass.qualifiedName}. " + - "The source code of ${superClass.qualifiedName} is:\n```\n${superClass.fullText}\n" + - "```\n" + "The source code of ${superClass.qualifiedName} is:\n```\n${superClass.fullText}\n" + + "```\n" } insert(PromptKeyword.CODE, fullText) @@ -110,7 +109,8 @@ class PromptBuilder(private val promptTemplate: String) { for (interestingClass in interestingClasses) { if (interestingClass.qualifiedName.startsWith("java") || - interestingClass.qualifiedName.startsWith("kotlin")) { + interestingClass.qualifiedName.startsWith("kotlin") + ) { continue } @@ -120,7 +120,8 @@ class PromptBuilder(private val promptTemplate: String) { // TODO: checks for java methods should be done by a caller to make // this class as abstract and language agnostic as possible. if (method.containingClassQualifiedName.startsWith("java") || - method.containingClassQualifiedName.startsWith("kotlin")) { + method.containingClassQualifiedName.startsWith("kotlin") + ) { continue } diff --git a/src/test/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilderTest.kt b/src/test/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilderTest.kt index 9c5b77c0f..97094bb75 100644 --- a/src/test/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilderTest.kt +++ b/src/test/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilderTest.kt @@ -3,19 +3,13 @@ package org.jetbrains.research.testspark.core.generation.llm.prompt import org.jetbrains.research.testspark.core.data.ClassType import org.jetbrains.research.testspark.core.generation.llm.prompt.configuration.ClassRepresentation import org.jetbrains.research.testspark.core.generation.llm.prompt.configuration.MethodRepresentation -import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Assertions.assertDoesNotThrow +import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Test - -import org.junit.jupiter.api.Assertions.* import org.junit.jupiter.api.assertThrows import kotlin.test.assertContains class PromptBuilderTest { - @BeforeEach - fun setUp() { - - } - @Test fun insertLanguage() { val (keyword, value) = PromptKeyword.LANGUAGE to "Java" @@ -98,8 +92,7 @@ class PromptBuilderTest { if (keyword.mandatory) { val exception = assertThrows { PromptBuilder(promptTemplate).build() } assertEquals("The prompt must contain ${keyword.name} keyword", exception.message) - } - else { + } else { assertDoesNotThrow { PromptBuilder(promptTemplate).build() } } } @@ -118,11 +111,14 @@ class PromptBuilderTest { .insertLanguage("Java") .build() - assertEquals(""" + assertEquals( + """ Language1: 'Java' Language2: \\Java\\ Language3: `Java` - """.trimIndent(), prompt) + """.trimIndent(), + prompt, + ) } @Test @@ -168,13 +164,13 @@ class PromptBuilderTest { signature = "method1():Boolean", name = "method1", text = "fun method1(): Boolean { return true }", - containingClassQualifiedName = "MyClass" + containingClassQualifiedName = "MyClass", ) val method2 = MethodRepresentation( signature = "method2():Boolean", name = "method2", text = "fun method2(): Boolean { return false }", - containingClassQualifiedName = "MyClass" + containingClassQualifiedName = "MyClass", ) val myClass = ClassRepresentation( qualifiedName = "MyClass", @@ -183,7 +179,7 @@ class PromptBuilderTest { fun method1(): Boolean { return true } fun method2(): Boolean { return false } } - """.trimIndent(), + """.trimIndent(), allMethods = listOf(method1, method2), classType = ClassType.CLASS, ) @@ -196,7 +192,7 @@ class PromptBuilderTest { === methods in MyClass: - method1():Boolean - method2():Boolean -""".trimIndent() + """.trimIndent() val builder = PromptBuilder("Methods:\n${keyword.variable}") builder.insertMethodsSignatures(interestingClasses) @@ -206,7 +202,7 @@ class PromptBuilderTest { assertEquals( expectedMethodsText + "\n", prompt, - "Methods' signatures should be inserted into prompt template correctly" + "Methods' signatures should be inserted into prompt template correctly", ) } @@ -217,7 +213,7 @@ class PromptBuilderTest { fullText = """ class MyInterface { } - """.trimIndent(), + """.trimIndent(), allMethods = emptyList(), classType = ClassType.INTERFACE, ) @@ -226,7 +222,7 @@ class PromptBuilderTest { fullText = """ class MySubClass : MyInterface { } - """.trimIndent(), + """.trimIndent(), allMethods = emptyList(), classType = ClassType.CLASS, ) @@ -254,7 +250,7 @@ class PromptBuilderTest { fun testMethod() { assertEquals(4, 2+2) } - """.trimIndent() + """.trimIndent() val prompt = PromptBuilder(PromptKeyword.TEST_SAMPLE.variable) .insertTestSample(testSamplesCode) @@ -264,4 +260,4 @@ class PromptBuilderTest { assertEquals(expected, prompt) } -} \ No newline at end of file +} From 41224158593e9987bedf4ea2ab83633c40df527e Mon Sep 17 00:00:00 2001 From: Vladislav Artiukhov Date: Fri, 27 Sep 2024 13:17:51 +0200 Subject: [PATCH 15/16] fix: use `StringBuilder` for `fullText` assembling --- .../generation/llm/prompt/PromptBuilder.kt | 33 ++++++++++--------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt index 065214b15..7040f3e30 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt @@ -5,7 +5,7 @@ import org.jetbrains.research.testspark.core.generation.llm.prompt.configuration import java.util.EnumMap /** - * Builds prompts by populating a template with keyword values, + * Builds prompts by populating a template with keyword values * and validates that all mandatory keywords are provided. * * @property promptTemplate The template string for the prompt. @@ -86,25 +86,25 @@ class PromptBuilder(private val promptTemplate: String) { * @return The modified prompt builder. */ fun insertCodeUnderTest(codeFullText: String, classesToTest: List) = apply { - var fullText = "```\n${codeFullText}\n```\n" + val fullText = StringBuilder("```\n${codeFullText}\n```\n") for (i in 2..classesToTest.size) { val subClass = classesToTest[i - 2] val superClass = classesToTest[i - 1] - fullText += "${subClass.qualifiedName} extends ${superClass.qualifiedName}. " + - "The source code of ${superClass.qualifiedName} is:\n```\n${superClass.fullText}\n" + - "```\n" + fullText.append("${subClass.qualifiedName} extends ${superClass.qualifiedName}. ") + .append("The source code of ${superClass.qualifiedName} is:\n```\n${superClass.fullText}\n") + .append("```\n") } - insert(PromptKeyword.CODE, fullText) + insert(PromptKeyword.CODE, fullText.toString()) } fun insertMethodsSignatures(interestingClasses: List) = apply { - var fullText = "" + val fullText = StringBuilder() if (interestingClasses.isNotEmpty()) { - fullText += "Here are some information about other methods and classes used by the class under test. Only use them for creating objects, not your own ideas.\n" + fullText.append("Here are some information about other methods and classes used by the class under test. Only use them for creating objects, not your own ideas.\n") } for (interestingClass in interestingClasses) { @@ -114,7 +114,7 @@ class PromptBuilder(private val promptTemplate: String) { continue } - fullText += "=== methods in ${interestingClass.qualifiedName}:\n" + fullText.append("=== methods in ${interestingClass.qualifiedName}:\n") for (method in interestingClass.allMethods) { // TODO: checks for java methods should be done by a caller to make @@ -125,19 +125,20 @@ class PromptBuilder(private val promptTemplate: String) { continue } - fullText += " - ${method.signature}\n" + fullText.append(" - ${method.signature}\n") } } - insert(PromptKeyword.METHODS, fullText) + insert(PromptKeyword.METHODS, fullText.toString()) } fun insertPolymorphismRelations( polymorphismRelations: Map>, ) = apply { - var fullText = when { - polymorphismRelations.isNotEmpty() -> "Use the following polymorphic relationships of classes present in the project. Use them for instantiation when necessary. Do not mock classes if an instantiation of a sub-class is applicable.\n\n" - else -> "" + val fullText = StringBuilder() + + if (polymorphismRelations.isNotEmpty()) { + fullText.append("Use the following polymorphic relationships of classes present in the project. Use them for instantiation when necessary. Do not mock classes if an instantiation of a sub-class is applicable.\n\n") } for (entry in polymorphismRelations) { @@ -150,11 +151,11 @@ class PromptBuilder(private val promptTemplate: String) { ClassType.INLINE_VALUE_CLASS -> "a sub inline value class class of" ClassType.OBJECT -> "a sub object of" } - fullText += "${currentSubClass.qualifiedName} is $subClassTypeName ${entry.key.qualifiedName}.\n" + fullText.append("${currentSubClass.qualifiedName} is $subClassTypeName ${entry.key.qualifiedName}.\n") } } - insert(PromptKeyword.POLYMORPHISM, fullText) + insert(PromptKeyword.POLYMORPHISM, fullText.toString()) } fun insertTestSample(testSamplesCode: String) = apply { From 45c7bec62fb087511c837331ec88d0180d083d60 Mon Sep 17 00:00:00 2001 From: Vladislav Artiukhov Date: Fri, 27 Sep 2024 13:22:17 +0200 Subject: [PATCH 16/16] publish: publish `testspark-core` with version `3.0.0` --- core/build.gradle.kts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/build.gradle.kts b/core/build.gradle.kts index aa83a840b..813f5a5f7 100644 --- a/core/build.gradle.kts +++ b/core/build.gradle.kts @@ -42,7 +42,7 @@ publishing { create("maven") { groupId = group as String artifactId = "testspark-core" - version = "2.0.7" + version = "3.0.0" from(components["java"]) artifact(tasks["sourcesJar"])