diff --git a/core/build.gradle.kts b/core/build.gradle.kts index 8b5bc0eb2..813f5a5f7 100644 --- a/core/build.gradle.kts +++ b/core/build.gradle.kts @@ -42,7 +42,7 @@ publishing { create("maven") { groupId = group as String artifactId = "testspark-core" - version = "2.0.5" + version = "3.0.0" from(components["java"]) artifact(tasks["sourcesJar"]) diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt index 036e87a0d..7040f3e30 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilder.kt @@ -2,149 +2,168 @@ package org.jetbrains.research.testspark.core.generation.llm.prompt import org.jetbrains.research.testspark.core.data.ClassType import org.jetbrains.research.testspark.core.generation.llm.prompt.configuration.ClassRepresentation +import java.util.EnumMap + +/** + * Builds prompts by populating a template with keyword values + * and validates that all mandatory keywords are provided. + * + * @property promptTemplate The template string for the prompt. + */ +class PromptBuilder(private val promptTemplate: String) { + private val insertedKeywordValues: EnumMap = EnumMap(PromptKeyword::class.java) + + // collect all the keywords present in the prompt template + private val templateKeywords: List = buildList { + for (keyword in PromptKeyword.entries) { + if (promptTemplate.contains(keyword.variable)) { + add(keyword) + } + } + } -internal class PromptBuilder(private var prompt: String) { - private fun isPromptValid( - keyword: PromptKeyword, - prompt: String, - ): Boolean { - val keywordText = keyword.text - val isMandatory = keyword.mandatory + /** + * Builds the prompt by populating the template with the inserted values + * and validating that all mandatory keywords were provided. + * + * @return The built prompt. + * @throws IllegalStateException if a mandatory keyword is not present in the template. + */ + fun build(): String { + var populatedPrompt = promptTemplate + + // populate the template with the inserted values + for ((keyword, value) in insertedKeywordValues.entries) { + populatedPrompt = populatedPrompt.replace(keyword.variable, value, ignoreCase = false) + } - return (prompt.contains(keywordText) || !isMandatory) + // validate that all mandatory keywords were provided + for (keyword in templateKeywords) { + if (!insertedKeywordValues.contains(keyword) && keyword.mandatory) { + throw IllegalStateException("The prompt must contain ${keyword.name} keyword") + } + } + + return populatedPrompt } - fun insertLanguage(language: String) = apply { - if (isPromptValid(PromptKeyword.LANGUAGE, prompt)) { - val keyword = "\$${PromptKeyword.LANGUAGE.text}" - prompt = prompt.replace(keyword, language, ignoreCase = false) - } else { - throw IllegalStateException("The prompt must contain ${PromptKeyword.LANGUAGE.text}") + /** + * Inserts a keyword and its corresponding value into the prompt template. + * If the keyword is marked as mandatory and not present in the template, an IllegalArgumentException is thrown. + * + * @param keyword The keyword to be inserted. + * @param value The value corresponding to the keyword. + * @throws IllegalArgumentException if a mandatory keyword is not present in the template. + */ + private fun insert(keyword: PromptKeyword, value: String) { + if (!templateKeywords.contains(keyword) && keyword.mandatory) { + throw IllegalArgumentException("Prompt template does not contain mandatory ${keyword.name}") } + insertedKeywordValues[keyword] = value + } + + fun insertLanguage(language: String) = apply { + insert(PromptKeyword.LANGUAGE, language) } fun insertName(classDisplayName: String) = apply { - if (isPromptValid(PromptKeyword.NAME, prompt)) { - val keyword = "\$${PromptKeyword.NAME.text}" - prompt = prompt.replace(keyword, classDisplayName, ignoreCase = false) - } else { - throw IllegalStateException("The prompt must contain ${PromptKeyword.NAME.text}") - } + insert(PromptKeyword.NAME, classDisplayName) } fun insertTestingPlatform(testingPlatformName: String) = apply { - if (isPromptValid(PromptKeyword.TESTING_PLATFORM, prompt)) { - val keyword = "\$${PromptKeyword.TESTING_PLATFORM.text}" - prompt = prompt.replace(keyword, testingPlatformName, ignoreCase = false) - } else { - throw IllegalStateException("The prompt must contain ${PromptKeyword.TESTING_PLATFORM.text}") - } + insert(PromptKeyword.TESTING_PLATFORM, testingPlatformName) } fun insertMockingFramework(mockingFrameworkName: String) = apply { - if (isPromptValid(PromptKeyword.MOCKING_FRAMEWORK, prompt)) { - val keyword = "\$${PromptKeyword.MOCKING_FRAMEWORK.text}" - prompt = prompt.replace(keyword, mockingFrameworkName, ignoreCase = false) - } else { - throw IllegalStateException("The prompt must contain ${PromptKeyword.MOCKING_FRAMEWORK.text}") - } + insert(PromptKeyword.MOCKING_FRAMEWORK, mockingFrameworkName) } - fun insertCodeUnderTest(classFullText: String, classesToTest: List) = apply { - if (isPromptValid(PromptKeyword.CODE, prompt)) { - val keyword = "\$${PromptKeyword.CODE.text}" - var fullText = "```\n${classFullText}\n```\n" - - for (i in 2..classesToTest.size) { - val subClass = classesToTest[i - 2] - val superClass = classesToTest[i - 1] - - fullText += "${subClass.qualifiedName} extends ${superClass.qualifiedName}. " + - "The source code of ${superClass.qualifiedName} is:\n```\n${superClass.fullText}\n" + - "```\n" - } - prompt = prompt.replace(keyword, fullText, ignoreCase = false) - } else { - throw IllegalStateException("The prompt must contain ${PromptKeyword.CODE.text}") + /** + * Inserts the code under test and its related superclass code into the prompt template. + * + * @param codeFullText The full text of the code under test. + * @param classesToTest The list of ClassRepresentation objects representing the classes involved in the code under test. + * @return The modified prompt builder. + */ + fun insertCodeUnderTest(codeFullText: String, classesToTest: List) = apply { + val fullText = StringBuilder("```\n${codeFullText}\n```\n") + + for (i in 2..classesToTest.size) { + val subClass = classesToTest[i - 2] + val superClass = classesToTest[i - 1] + + fullText.append("${subClass.qualifiedName} extends ${superClass.qualifiedName}. ") + .append("The source code of ${superClass.qualifiedName} is:\n```\n${superClass.fullText}\n") + .append("```\n") } + + insert(PromptKeyword.CODE, fullText.toString()) } fun insertMethodsSignatures(interestingClasses: List) = apply { - val keyword = "\$${PromptKeyword.METHODS.text}" + val fullText = StringBuilder() - if (isPromptValid(PromptKeyword.METHODS, prompt)) { - var fullText = "" - if (interestingClasses.isNotEmpty()) { - fullText += "Here are some information about other methods and classes used by the class under test. Only use them for creating objects, not your own ideas.\n" - } - for (interestingClass in interestingClasses) { - if (interestingClass.qualifiedName.startsWith("java") || interestingClass.qualifiedName.startsWith("kotlin")) { - continue - } + if (interestingClasses.isNotEmpty()) { + fullText.append("Here are some information about other methods and classes used by the class under test. Only use them for creating objects, not your own ideas.\n") + } - fullText += "=== methods in ${interestingClass.qualifiedName}:\n" + for (interestingClass in interestingClasses) { + if (interestingClass.qualifiedName.startsWith("java") || + interestingClass.qualifiedName.startsWith("kotlin") + ) { + continue + } - for (method in interestingClass.allMethods) { - // Skip java methods - // TODO: checks for java methods should be done by a caller to make - // this class as abstract and language agnostic as possible. - if (method.containingClassQualifiedName.startsWith("java") || - method.containingClassQualifiedName.startsWith("kotlin") - ) { - continue - } + fullText.append("=== methods in ${interestingClass.qualifiedName}:\n") - fullText += " - ${method.signature}\n" + for (method in interestingClass.allMethods) { + // TODO: checks for java methods should be done by a caller to make + // this class as abstract and language agnostic as possible. + if (method.containingClassQualifiedName.startsWith("java") || + method.containingClassQualifiedName.startsWith("kotlin") + ) { + continue } + + fullText.append(" - ${method.signature}\n") } - prompt = prompt.replace(keyword, fullText, ignoreCase = false) - } else { - throw IllegalStateException("The prompt must contain ${PromptKeyword.METHODS.text}") } + + insert(PromptKeyword.METHODS, fullText.toString()) } fun insertPolymorphismRelations( polymorphismRelations: Map>, ) = apply { - val keyword = "\$${PromptKeyword.POLYMORPHISM.text}" - if (isPromptValid(PromptKeyword.POLYMORPHISM, prompt)) { - // If polymorphismRelations is not empty, we add an instruction to avoid mocking classes if an instantiation of a sub-class is applicable - var fullText = when { - polymorphismRelations.isNotEmpty() -> "Use the following polymorphic relationships of classes present in the project. Use them for instantiation when necessary. Do not mock classes if an instantiation of a sub-class is applicable" - else -> "" - } - polymorphismRelations.forEach { entry -> - for (currentSubClass in entry.value) { - val subClassTypeName = when (currentSubClass.classType) { - ClassType.INTERFACE -> "an interface implementing" - ClassType.ABSTRACT_CLASS -> "an abstract sub-class of" - ClassType.CLASS -> "a sub-class of" - ClassType.DATA_CLASS -> "a sub data class of" - ClassType.INLINE_VALUE_CLASS -> "a sub inline value class class of" - ClassType.OBJECT -> "a sub object of" - } - fullText += "${currentSubClass.qualifiedName} is $subClassTypeName ${entry.key.qualifiedName}.\n" + val fullText = StringBuilder() + + if (polymorphismRelations.isNotEmpty()) { + fullText.append("Use the following polymorphic relationships of classes present in the project. Use them for instantiation when necessary. Do not mock classes if an instantiation of a sub-class is applicable.\n\n") + } + + for (entry in polymorphismRelations) { + for (currentSubClass in entry.value) { + val subClassTypeName = when (currentSubClass.classType) { + ClassType.INTERFACE -> "an interface implementing" + ClassType.ABSTRACT_CLASS -> "an abstract sub-class of" + ClassType.CLASS -> "a sub-class of" + ClassType.DATA_CLASS -> "a sub data class of" + ClassType.INLINE_VALUE_CLASS -> "a sub inline value class class of" + ClassType.OBJECT -> "a sub object of" } + fullText.append("${currentSubClass.qualifiedName} is $subClassTypeName ${entry.key.qualifiedName}.\n") } - prompt = prompt.replace(keyword, fullText, ignoreCase = false) - } else { - throw IllegalStateException("The prompt must contain ${PromptKeyword.POLYMORPHISM.text}") } + + insert(PromptKeyword.POLYMORPHISM, fullText.toString()) } fun insertTestSample(testSamplesCode: String) = apply { - val keyword = "\$${PromptKeyword.TEST_SAMPLE.text}" - - if (isPromptValid(PromptKeyword.TEST_SAMPLE, prompt)) { - var fullText = testSamplesCode - if (fullText.isNotBlank()) { - fullText = "Use this test samples:\n$fullText\n" - } - prompt = prompt.replace(keyword, fullText, ignoreCase = false) - } else { - throw IllegalStateException("The prompt must contain ${PromptKeyword.TEST_SAMPLE.text}") + var fullText = testSamplesCode + if (fullText.isNotBlank()) { + fullText = "Use this test samples:\n$fullText\n" } - } - fun build(): String = prompt + insert(PromptKeyword.TEST_SAMPLE, fullText) + } } diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptKeyword.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptKeyword.kt index f1d46eece..98a968d23 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptKeyword.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptKeyword.kt @@ -1,26 +1,24 @@ package org.jetbrains.research.testspark.core.generation.llm.prompt -enum class PromptKeyword(val text: String, val description: String, val mandatory: Boolean) { - NAME("NAME", "The name of the code under test (Class name, method name, line number)", true), - CODE("CODE", "The code under test (Class, method, or line)", true), - LANGUAGE("LANGUAGE", "Programming language of the project under test (only Java supported at this point)", true), +enum class PromptKeyword(val description: String, val mandatory: Boolean) { + NAME("The name of the code under test (Class name, method name, line number)", true), + CODE("The code under test (Class, method, or line)", true), + LANGUAGE("Programming language of the project under test (only Java supported at this point)", true), TESTING_PLATFORM( - "TESTING_PLATFORM", "Testing platform used in the project (Only JUnit 4 is supported at this point)", true, ), MOCKING_FRAMEWORK( - "MOCKING_FRAMEWORK", "Mock framework that can be used in generated test (Only Mockito is supported at this point)", false, ), - METHODS("METHODS", "Signature of methods used in the code under tests", false), - POLYMORPHISM("POLYMORPHISM", "Polymorphism relations between classes involved in the code under test", false), - TEST_SAMPLE("TEST_SAMPLE", "Test samples for LLM for test generation", false), + METHODS("Signature of methods used in the code under tests", false), + POLYMORPHISM("Polymorphism relations between classes involved in the code under test", false), + TEST_SAMPLE("Test samples for LLM for test generation", false), ; fun getOffsets(prompt: String): Pair? { - val textToHighlight = "\$$text" + val textToHighlight = variable if (!prompt.contains(textToHighlight)) { return null } @@ -29,4 +27,13 @@ enum class PromptKeyword(val text: String, val description: String, val mandator val endOffset = startOffset + textToHighlight.length return Pair(startOffset, endOffset) } + + /** + * Returns a keyword's text (i.e., its name) with a `$` attached at the start. + * + * Inside a prompt template every keyword is used as `$KEYWORD_NAME`. + * Therefore, this property encapsulates the keyword's representation in a prompt. + */ + val variable: String + get() = "\$${this.name}" } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/helpers/PromptParserHelper.kt b/src/main/kotlin/org/jetbrains/research/testspark/helpers/PromptParserHelper.kt index 11d6cd872..91c061938 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/helpers/PromptParserHelper.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/helpers/PromptParserHelper.kt @@ -60,8 +60,7 @@ object PromptParserHelper { fun isPromptValid(prompt: String): Boolean { PromptKeyword.entries.forEach { if (it.mandatory) { - val text = "\$${it.text}" - if (!prompt.contains(text)) { + if (!prompt.contains(it.variable)) { return false } } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/settings/llm/LLMSettingsComponent.kt b/src/main/kotlin/org/jetbrains/research/testspark/settings/llm/LLMSettingsComponent.kt index d27fa7936..61d28476f 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/settings/llm/LLMSettingsComponent.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/settings/llm/LLMSettingsComponent.kt @@ -328,8 +328,8 @@ class LLMSettingsComponent(private val project: Project) : SettingsComponent { private fun createButtonPanel(keyword: PromptKeyword, panel: JPanel): JPanel { val buttonPanel = JPanel(FlowLayout(FlowLayout.LEFT)) val editorTextField = panel.getComponent(1) as EditorTextField - val button = JButton("\$${keyword.text}") - button.setForeground(JBColor.ORANGE) + val button = JButton(keyword.variable) + button.foreground = JBColor.ORANGE button.font = Font("Monochrome", Font.BOLD, 12) // add actionListener for button @@ -340,7 +340,7 @@ class LLMSettingsComponent(private val project: Project) : SettingsComponent { val offset = e.caretModel.offset val document = editorTextField.document WriteCommandAction.runWriteCommandAction(e.project) { - document.insertString(offset, "\$${keyword.text}") + document.insertString(offset, keyword.variable) } } } diff --git a/src/test/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilderTest.kt b/src/test/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilderTest.kt new file mode 100644 index 000000000..97094bb75 --- /dev/null +++ b/src/test/kotlin/org/jetbrains/research/testspark/core/generation/llm/prompt/PromptBuilderTest.kt @@ -0,0 +1,263 @@ +package org.jetbrains.research.testspark.core.generation.llm.prompt + +import org.jetbrains.research.testspark.core.data.ClassType +import org.jetbrains.research.testspark.core.generation.llm.prompt.configuration.ClassRepresentation +import org.jetbrains.research.testspark.core.generation.llm.prompt.configuration.MethodRepresentation +import org.junit.jupiter.api.Assertions.assertDoesNotThrow +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertThrows +import kotlin.test.assertContains + +class PromptBuilderTest { + @Test + fun insertLanguage() { + val (keyword, value) = PromptKeyword.LANGUAGE to "Java" + + val prompt = PromptBuilder("Language: ${keyword.variable}") + .insertLanguage(value) + .build() + + assertEquals("Language: Java", prompt) + } + + @Test + fun insertName() { + val (keyword, value) = PromptKeyword.NAME to "MyClass" + + val prompt = PromptBuilder("Name: ${keyword.variable}") + .insertName(value) + .build() + + assertEquals("Name: MyClass", prompt) + } + + @Test + fun insertTestingPlatform() { + val (keyword, value) = PromptKeyword.TESTING_PLATFORM to "JUnit4" + + val prompt = PromptBuilder("Testing platform: ${keyword.variable}") + .insertTestingPlatform(value) + .build() + + assertEquals("Testing platform: JUnit4", prompt) + } + + @Test + fun insertMockingFramework() { + val (keyword, value) = PromptKeyword.MOCKING_FRAMEWORK to "Mockito" + + val prompt = PromptBuilder("Mocking framework: ${keyword.variable}") + .insertMockingFramework(value) + .build() + + assertEquals("Mocking framework: Mockito", prompt) + } + + @Test + fun insertCodeUnderTest() { + val keyword = PromptKeyword.CODE + val code = """ + class MyClass() { + fun f() { println("Hello world!") } + } + """.trimIndent() + + val prompt = PromptBuilder("Code:\n${keyword.variable}") + .insertCodeUnderTest(code, emptyList()) + .build() + + assertContains(prompt, code, message = "Code under test should be inserted into prompt template") + } + + @Test + fun lastInsertionPrevails() { + val keyword = PromptKeyword.LANGUAGE + + val prompt = PromptBuilder("Language: ${keyword.variable}") + .insertLanguage("Java") + .insertLanguage("Kotlin") + .build() + + assertEquals("Language: Kotlin", prompt) + } + + @Test + fun throwsOnMissingMandatoryKeyword() { + val keywords: List = PromptKeyword.entries + + for (keyword in keywords) { + val promptTemplate = "My variable is: ${keyword.variable}" + + if (keyword.mandatory) { + val exception = assertThrows { PromptBuilder(promptTemplate).build() } + assertEquals("The prompt must contain ${keyword.name} keyword", exception.message) + } else { + assertDoesNotThrow { PromptBuilder(promptTemplate).build() } + } + } + } + + @Test + fun testPopulateMultipleVariableEntries() { + val keyword = PromptKeyword.LANGUAGE + val template = """ + Language1: '${keyword.variable}' + Language2: \\${keyword.variable}\\ + Language3: `${keyword.variable}` + """.trimIndent() + + val prompt = PromptBuilder(template) + .insertLanguage("Java") + .build() + + assertEquals( + """ + Language1: 'Java' + Language2: \\Java\\ + Language3: `Java` + """.trimIndent(), + prompt, + ) + } + + @Test + fun testInsertMultipleVariables() { + val template = """ + language: ${PromptKeyword.LANGUAGE.variable} + name: ${PromptKeyword.NAME.variable} + testing platform: ${PromptKeyword.TESTING_PLATFORM.variable} + mocking framework: ${PromptKeyword.MOCKING_FRAMEWORK.variable} + """.trimIndent() + + val prompt = PromptBuilder(template) + .insertLanguage("Java") + .insertName("org.pkg.MyClass") + .insertTestingPlatform("JUnit4") + .insertMockingFramework("Mockito") + .build() + + val expected = """ + language: Java + name: org.pkg.MyClass + testing platform: JUnit4 + mocking framework: Mockito + """.trimIndent() + + assertEquals(expected, prompt) + } + + @Test + fun testThrowsOnNonExistentKeywordInsertion() { + val template = "Language: ${PromptKeyword.LANGUAGE.variable}" + val exception = assertThrows { + PromptBuilder(template).insertName("Name") + } + assertEquals("Prompt template does not contain mandatory ${PromptKeyword.NAME.name}", exception.message) + } + + @Test + fun testInsertMethodsSignatures() { + val keyword = PromptKeyword.METHODS + + val method1 = MethodRepresentation( + signature = "method1():Boolean", + name = "method1", + text = "fun method1(): Boolean { return true }", + containingClassQualifiedName = "MyClass", + ) + val method2 = MethodRepresentation( + signature = "method2():Boolean", + name = "method2", + text = "fun method2(): Boolean { return false }", + containingClassQualifiedName = "MyClass", + ) + val myClass = ClassRepresentation( + qualifiedName = "MyClass", + fullText = """ + class MyClass { + fun method1(): Boolean { return true } + fun method2(): Boolean { return false } + } + """.trimIndent(), + allMethods = listOf(method1, method2), + classType = ClassType.CLASS, + ) + + val interestingClasses = listOf(myClass) + + val expectedMethodsText = """ + Methods: + Here are some information about other methods and classes used by the class under test. Only use them for creating objects, not your own ideas. + === methods in MyClass: + - method1():Boolean + - method2():Boolean + """.trimIndent() + + val builder = PromptBuilder("Methods:\n${keyword.variable}") + builder.insertMethodsSignatures(interestingClasses) + + val prompt = builder.build() + + assertEquals( + expectedMethodsText + "\n", + prompt, + "Methods' signatures should be inserted into prompt template correctly", + ) + } + + @Test + fun testInsertPolymorphismRelations() { + val myInterface = ClassRepresentation( + qualifiedName = "MyInterface", + fullText = """ + class MyInterface { + } + """.trimIndent(), + allMethods = emptyList(), + classType = ClassType.INTERFACE, + ) + val mySubClass = ClassRepresentation( + qualifiedName = "MySubClass", + fullText = """ + class MySubClass : MyInterface { + } + """.trimIndent(), + allMethods = emptyList(), + classType = ClassType.CLASS, + ) + val polymorphicRelations = mapOf(myInterface to listOf(mySubClass)) + + val prompt = PromptBuilder(PromptKeyword.POLYMORPHISM.variable) + .insertPolymorphismRelations(polymorphicRelations) + .build() + + println("'$prompt'") + + val expected = """ + Use the following polymorphic relationships of classes present in the project. Use them for instantiation when necessary. Do not mock classes if an instantiation of a sub-class is applicable. + + MySubClass is a sub-class of MyInterface. + """.trimIndent() + + assertEquals(expected + "\n", prompt) + } + + @Test + fun testInsertTestSample() { + val testSamplesCode = """ + @Test + fun testMethod() { + assertEquals(4, 2+2) + } + """.trimIndent() + + val prompt = PromptBuilder(PromptKeyword.TEST_SAMPLE.variable) + .insertTestSample(testSamplesCode) + .build() + + val expected = "Use this test samples:\n" + testSamplesCode + "\n" + + assertEquals(expected, prompt) + } +}