Skip to content

Commit

Permalink
Merge branch 'development' into vartiukhov/feature/show-parsable-test…
Browse files Browse the repository at this point in the history
…-suite
  • Loading branch information
arksap2002 authored Sep 30, 2024
2 parents 6beec89 + b991154 commit 228a643
Show file tree
Hide file tree
Showing 19 changed files with 457 additions and 144 deletions.
2 changes: 1 addition & 1 deletion build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ dependencies {
implementation("org.junit.jupiter:junit-jupiter-engine:5.10.0")

// https://mvnrepository.com/artifact/org.jacoco/org.jacoco.core
implementation("org.jacoco:org.jacoco.core:0.8.8")
implementation("org.jacoco:org.jacoco.core:0.8.12")
// https://mvnrepository.com/artifact/com.github.javaparser/javaparser-core
implementation("com.github.javaparser:javaparser-symbol-solver-core:3.24.2")

Expand Down
2 changes: 1 addition & 1 deletion core/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ publishing {
create<MavenPublication>("maven") {
groupId = group as String
artifactId = "testspark-core"
version = "2.0.5"
version = "3.0.0"
from(components["java"])

artifact(tasks["sourcesJar"])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ class LLMWithFeedbackCycle(
var executionResult = FeedbackCycleExecutionResult.OK
val compilableTestCases: MutableSet<TestCaseGeneratedByLLM> = mutableSetOf()

// collect imports from all responses
val imports: MutableSet<String> = mutableSetOf()

var generatedTestSuite: TestSuiteGeneratedByLLM? = null

while (!generatedTestsArePassing) {
Expand Down Expand Up @@ -186,6 +189,9 @@ class LLMWithFeedbackCycle(

generatedTestSuite = response.testSuite

// update imports list
imports.addAll(generatedTestSuite.imports)

// Process stopped checking
if (indicator.isCanceled()) {
executionResult = FeedbackCycleExecutionResult.CANCELED
Expand Down Expand Up @@ -267,6 +273,8 @@ class LLMWithFeedbackCycle(

log.info { "Result is compilable" }

generatedTestSuite.imports.addAll(imports)

generatedTestsArePassing = true

recordReport(report, testCases)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,13 @@ fun executeTestCaseModificationRequest(
errorMonitor: ErrorMonitor = DefaultErrorMonitor(),
): TestSuiteGeneratedByLLM? {
// Update Token information
val prompt = "For this test:\n ```\n $testCase\n ```\nPerform the following task: $task"
val prompt = buildString {
append("For this test:\n ```\n ")
append(testCase)
append("\n```\nGenerate a SINGLE test method. Do not change class and method names.")
append("\nPerform the following task:\n")
append(task)
}

val packageName = getPackageFromTestSuiteCode(testCase, language)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,149 +2,168 @@ package org.jetbrains.research.testspark.core.generation.llm.prompt

import org.jetbrains.research.testspark.core.data.ClassType
import org.jetbrains.research.testspark.core.generation.llm.prompt.configuration.ClassRepresentation
import java.util.EnumMap

/**
* Builds prompts by populating a template with keyword values
* and validates that all mandatory keywords are provided.
*
* @property promptTemplate The template string for the prompt.
*/
class PromptBuilder(private val promptTemplate: String) {
private val insertedKeywordValues: EnumMap<PromptKeyword, String> = EnumMap(PromptKeyword::class.java)

// collect all the keywords present in the prompt template
private val templateKeywords: List<PromptKeyword> = buildList {
for (keyword in PromptKeyword.entries) {
if (promptTemplate.contains(keyword.variable)) {
add(keyword)
}
}
}

internal class PromptBuilder(private var prompt: String) {
private fun isPromptValid(
keyword: PromptKeyword,
prompt: String,
): Boolean {
val keywordText = keyword.text
val isMandatory = keyword.mandatory
/**
* Builds the prompt by populating the template with the inserted values
* and validating that all mandatory keywords were provided.
*
* @return The built prompt.
* @throws IllegalStateException if a mandatory keyword is not present in the template.
*/
fun build(): String {
var populatedPrompt = promptTemplate

// populate the template with the inserted values
for ((keyword, value) in insertedKeywordValues.entries) {
populatedPrompt = populatedPrompt.replace(keyword.variable, value, ignoreCase = false)
}

return (prompt.contains(keywordText) || !isMandatory)
// validate that all mandatory keywords were provided
for (keyword in templateKeywords) {
if (!insertedKeywordValues.contains(keyword) && keyword.mandatory) {
throw IllegalStateException("The prompt must contain ${keyword.name} keyword")
}
}

return populatedPrompt
}

fun insertLanguage(language: String) = apply {
if (isPromptValid(PromptKeyword.LANGUAGE, prompt)) {
val keyword = "\$${PromptKeyword.LANGUAGE.text}"
prompt = prompt.replace(keyword, language, ignoreCase = false)
} else {
throw IllegalStateException("The prompt must contain ${PromptKeyword.LANGUAGE.text}")
/**
* Inserts a keyword and its corresponding value into the prompt template.
* If the keyword is marked as mandatory and not present in the template, an IllegalArgumentException is thrown.
*
* @param keyword The keyword to be inserted.
* @param value The value corresponding to the keyword.
* @throws IllegalArgumentException if a mandatory keyword is not present in the template.
*/
private fun insert(keyword: PromptKeyword, value: String) {
if (!templateKeywords.contains(keyword) && keyword.mandatory) {
throw IllegalArgumentException("Prompt template does not contain mandatory ${keyword.name}")
}
insertedKeywordValues[keyword] = value
}

fun insertLanguage(language: String) = apply {
insert(PromptKeyword.LANGUAGE, language)
}

fun insertName(classDisplayName: String) = apply {
if (isPromptValid(PromptKeyword.NAME, prompt)) {
val keyword = "\$${PromptKeyword.NAME.text}"
prompt = prompt.replace(keyword, classDisplayName, ignoreCase = false)
} else {
throw IllegalStateException("The prompt must contain ${PromptKeyword.NAME.text}")
}
insert(PromptKeyword.NAME, classDisplayName)
}

fun insertTestingPlatform(testingPlatformName: String) = apply {
if (isPromptValid(PromptKeyword.TESTING_PLATFORM, prompt)) {
val keyword = "\$${PromptKeyword.TESTING_PLATFORM.text}"
prompt = prompt.replace(keyword, testingPlatformName, ignoreCase = false)
} else {
throw IllegalStateException("The prompt must contain ${PromptKeyword.TESTING_PLATFORM.text}")
}
insert(PromptKeyword.TESTING_PLATFORM, testingPlatformName)
}

fun insertMockingFramework(mockingFrameworkName: String) = apply {
if (isPromptValid(PromptKeyword.MOCKING_FRAMEWORK, prompt)) {
val keyword = "\$${PromptKeyword.MOCKING_FRAMEWORK.text}"
prompt = prompt.replace(keyword, mockingFrameworkName, ignoreCase = false)
} else {
throw IllegalStateException("The prompt must contain ${PromptKeyword.MOCKING_FRAMEWORK.text}")
}
insert(PromptKeyword.MOCKING_FRAMEWORK, mockingFrameworkName)
}

fun insertCodeUnderTest(classFullText: String, classesToTest: List<ClassRepresentation>) = apply {
if (isPromptValid(PromptKeyword.CODE, prompt)) {
val keyword = "\$${PromptKeyword.CODE.text}"
var fullText = "```\n${classFullText}\n```\n"

for (i in 2..classesToTest.size) {
val subClass = classesToTest[i - 2]
val superClass = classesToTest[i - 1]

fullText += "${subClass.qualifiedName} extends ${superClass.qualifiedName}. " +
"The source code of ${superClass.qualifiedName} is:\n```\n${superClass.fullText}\n" +
"```\n"
}
prompt = prompt.replace(keyword, fullText, ignoreCase = false)
} else {
throw IllegalStateException("The prompt must contain ${PromptKeyword.CODE.text}")
/**
* Inserts the code under test and its related superclass code into the prompt template.
*
* @param codeFullText The full text of the code under test.
* @param classesToTest The list of ClassRepresentation objects representing the classes involved in the code under test.
* @return The modified prompt builder.
*/
fun insertCodeUnderTest(codeFullText: String, classesToTest: List<ClassRepresentation>) = apply {
val fullText = StringBuilder("```\n${codeFullText}\n```\n")

for (i in 2..classesToTest.size) {
val subClass = classesToTest[i - 2]
val superClass = classesToTest[i - 1]

fullText.append("${subClass.qualifiedName} extends ${superClass.qualifiedName}. ")
.append("The source code of ${superClass.qualifiedName} is:\n```\n${superClass.fullText}\n")
.append("```\n")
}

insert(PromptKeyword.CODE, fullText.toString())
}

fun insertMethodsSignatures(interestingClasses: List<ClassRepresentation>) = apply {
val keyword = "\$${PromptKeyword.METHODS.text}"
val fullText = StringBuilder()

if (isPromptValid(PromptKeyword.METHODS, prompt)) {
var fullText = ""
if (interestingClasses.isNotEmpty()) {
fullText += "Here are some information about other methods and classes used by the class under test. Only use them for creating objects, not your own ideas.\n"
}
for (interestingClass in interestingClasses) {
if (interestingClass.qualifiedName.startsWith("java") || interestingClass.qualifiedName.startsWith("kotlin")) {
continue
}
if (interestingClasses.isNotEmpty()) {
fullText.append("Here are some information about other methods and classes used by the class under test. Only use them for creating objects, not your own ideas.\n")
}

fullText += "=== methods in ${interestingClass.qualifiedName}:\n"
for (interestingClass in interestingClasses) {
if (interestingClass.qualifiedName.startsWith("java") ||
interestingClass.qualifiedName.startsWith("kotlin")
) {
continue
}

for (method in interestingClass.allMethods) {
// Skip java methods
// TODO: checks for java methods should be done by a caller to make
// this class as abstract and language agnostic as possible.
if (method.containingClassQualifiedName.startsWith("java") ||
method.containingClassQualifiedName.startsWith("kotlin")
) {
continue
}
fullText.append("=== methods in ${interestingClass.qualifiedName}:\n")

fullText += " - ${method.signature}\n"
for (method in interestingClass.allMethods) {
// TODO: checks for java methods should be done by a caller to make
// this class as abstract and language agnostic as possible.
if (method.containingClassQualifiedName.startsWith("java") ||
method.containingClassQualifiedName.startsWith("kotlin")
) {
continue
}

fullText.append(" - ${method.signature}\n")
}
prompt = prompt.replace(keyword, fullText, ignoreCase = false)
} else {
throw IllegalStateException("The prompt must contain ${PromptKeyword.METHODS.text}")
}

insert(PromptKeyword.METHODS, fullText.toString())
}

fun insertPolymorphismRelations(
polymorphismRelations: Map<ClassRepresentation, List<ClassRepresentation>>,
) = apply {
val keyword = "\$${PromptKeyword.POLYMORPHISM.text}"
if (isPromptValid(PromptKeyword.POLYMORPHISM, prompt)) {
// If polymorphismRelations is not empty, we add an instruction to avoid mocking classes if an instantiation of a sub-class is applicable
var fullText = when {
polymorphismRelations.isNotEmpty() -> "Use the following polymorphic relationships of classes present in the project. Use them for instantiation when necessary. Do not mock classes if an instantiation of a sub-class is applicable"
else -> ""
}
polymorphismRelations.forEach { entry ->
for (currentSubClass in entry.value) {
val subClassTypeName = when (currentSubClass.classType) {
ClassType.INTERFACE -> "an interface implementing"
ClassType.ABSTRACT_CLASS -> "an abstract sub-class of"
ClassType.CLASS -> "a sub-class of"
ClassType.DATA_CLASS -> "a sub data class of"
ClassType.INLINE_VALUE_CLASS -> "a sub inline value class class of"
ClassType.OBJECT -> "a sub object of"
}
fullText += "${currentSubClass.qualifiedName} is $subClassTypeName ${entry.key.qualifiedName}.\n"
val fullText = StringBuilder()

if (polymorphismRelations.isNotEmpty()) {
fullText.append("Use the following polymorphic relationships of classes present in the project. Use them for instantiation when necessary. Do not mock classes if an instantiation of a sub-class is applicable.\n\n")
}

for (entry in polymorphismRelations) {
for (currentSubClass in entry.value) {
val subClassTypeName = when (currentSubClass.classType) {
ClassType.INTERFACE -> "an interface implementing"
ClassType.ABSTRACT_CLASS -> "an abstract sub-class of"
ClassType.CLASS -> "a sub-class of"
ClassType.DATA_CLASS -> "a sub data class of"
ClassType.INLINE_VALUE_CLASS -> "a sub inline value class class of"
ClassType.OBJECT -> "a sub object of"
}
fullText.append("${currentSubClass.qualifiedName} is $subClassTypeName ${entry.key.qualifiedName}.\n")
}
prompt = prompt.replace(keyword, fullText, ignoreCase = false)
} else {
throw IllegalStateException("The prompt must contain ${PromptKeyword.POLYMORPHISM.text}")
}

insert(PromptKeyword.POLYMORPHISM, fullText.toString())
}

fun insertTestSample(testSamplesCode: String) = apply {
val keyword = "\$${PromptKeyword.TEST_SAMPLE.text}"

if (isPromptValid(PromptKeyword.TEST_SAMPLE, prompt)) {
var fullText = testSamplesCode
if (fullText.isNotBlank()) {
fullText = "Use this test samples:\n$fullText\n"
}
prompt = prompt.replace(keyword, fullText, ignoreCase = false)
} else {
throw IllegalStateException("The prompt must contain ${PromptKeyword.TEST_SAMPLE.text}")
var fullText = testSamplesCode
if (fullText.isNotBlank()) {
fullText = "Use this test samples:\n$fullText\n"
}
}

fun build(): String = prompt
insert(PromptKeyword.TEST_SAMPLE, fullText)
}
}
Original file line number Diff line number Diff line change
@@ -1,26 +1,24 @@
package org.jetbrains.research.testspark.core.generation.llm.prompt

enum class PromptKeyword(val text: String, val description: String, val mandatory: Boolean) {
NAME("NAME", "The name of the code under test (Class name, method name, line number)", true),
CODE("CODE", "The code under test (Class, method, or line)", true),
LANGUAGE("LANGUAGE", "Programming language of the project under test (only Java supported at this point)", true),
enum class PromptKeyword(val description: String, val mandatory: Boolean) {
NAME("The name of the code under test (Class name, method name, line number)", true),
CODE("The code under test (Class, method, or line)", true),
LANGUAGE("Programming language of the project under test (only Java supported at this point)", true),
TESTING_PLATFORM(
"TESTING_PLATFORM",
"Testing platform used in the project (Only JUnit 4 is supported at this point)",
true,
),
MOCKING_FRAMEWORK(
"MOCKING_FRAMEWORK",
"Mock framework that can be used in generated test (Only Mockito is supported at this point)",
false,
),
METHODS("METHODS", "Signature of methods used in the code under tests", false),
POLYMORPHISM("POLYMORPHISM", "Polymorphism relations between classes involved in the code under test", false),
TEST_SAMPLE("TEST_SAMPLE", "Test samples for LLM for test generation", false),
METHODS("Signature of methods used in the code under tests", false),
POLYMORPHISM("Polymorphism relations between classes involved in the code under test", false),
TEST_SAMPLE("Test samples for LLM for test generation", false),
;

fun getOffsets(prompt: String): Pair<Int, Int>? {
val textToHighlight = "\$$text"
val textToHighlight = variable
if (!prompt.contains(textToHighlight)) {
return null
}
Expand All @@ -29,4 +27,13 @@ enum class PromptKeyword(val text: String, val description: String, val mandator
val endOffset = startOffset + textToHighlight.length
return Pair(startOffset, endOffset)
}

/**
* Returns a keyword's text (i.e., its name) with a `$` attached at the start.
*
* Inside a prompt template every keyword is used as `$KEYWORD_NAME`.
* Therefore, this property encapsulates the keyword's representation in a prompt.
*/
val variable: String
get() = "\$${this.name}"
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ package org.jetbrains.research.testspark.core.test.data
* @property testCases The list of test cases in the test suite.
*/
data class TestSuiteGeneratedByLLM(
var imports: Set<String> = emptySet(),
var imports: MutableSet<String> = mutableSetOf(),
var packageName: String = "",
var runWith: String = "",
var otherInfo: String = "",
Expand Down
Loading

0 comments on commit 228a643

Please sign in to comment.