From 62c5d72ab1a264b37374812c880d5456ffed14eb Mon Sep 17 00:00:00 2001 From: Arkadii Sapozhnikov Date: Tue, 3 Sep 2024 16:01:17 +0200 Subject: [PATCH] fix incorrect state after indicator cancellation --- .../core/generation/llm/LLMWithFeedbackCycle.kt | 12 ++++++++++++ .../jetbrains/research/testspark/tools/Pipeline.kt | 3 +++ .../jetbrains/research/testspark/tools/ToolUtils.kt | 5 +++-- .../llm/generation/openai/OpenAIRequestManager.kt | 5 +++-- 4 files changed, 21 insertions(+), 4 deletions(-) diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/LLMWithFeedbackCycle.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/LLMWithFeedbackCycle.kt index 973b26e7a..7af1d6496 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/LLMWithFeedbackCycle.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/LLMWithFeedbackCycle.kt @@ -110,6 +110,12 @@ class LLMWithFeedbackCycle( errorMonitor, ) + // Process stopped checking + if (indicator.isCanceled()) { + executionResult = FeedbackCycleExecutionResult.CANCELED + break + } + when (response.errorCode) { ResponseErrorCode.OK -> { log.info { "Test suite generated successfully: ${response.testSuite!!}" } @@ -222,6 +228,12 @@ class LLMWithFeedbackCycle( // saving the compilable test cases compilableTestCases.addAll(testCasesCompilationResult.compilableTestCases) + // Process stopped checking + if (indicator.isCanceled()) { + executionResult = FeedbackCycleExecutionResult.CANCELED + break + } + if (!testCasesCompilationResult.allTestCasesCompilable && !isLastIteration(requestsCount)) { log.info { "Non-compilable test suite: \n${testsPresenter.representTestSuite(generatedTestSuite!!)}" } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/Pipeline.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/Pipeline.kt index c44d71cf7..e4c9d208b 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/Pipeline.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/Pipeline.kt @@ -113,6 +113,9 @@ class Pipeline( super.onFinished() testGenerationController.finished() + // Process stopped checking + if (testGenerationController.errorMonitor.hasErrorOccurred()) return + updateEditor(uiContext!!.testGenerationOutput.fileUrl) if (editor != null) { diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/ToolUtils.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/ToolUtils.kt index 99e1fc4fa..ebf41a043 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/ToolUtils.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/ToolUtils.kt @@ -109,11 +109,12 @@ object ToolUtils { * @return true if the process has been stopped, false otherwise */ fun isProcessStopped(errorMonitor: ErrorMonitor, indicator: CustomProgressIndicator): Boolean { - return errorMonitor.hasErrorOccurred() || isProcessCanceled(indicator) + return errorMonitor.hasErrorOccurred() || isProcessCanceled(errorMonitor, indicator) } - fun isProcessCanceled(indicator: CustomProgressIndicator): Boolean { + fun isProcessCanceled(errorMonitor: ErrorMonitor, indicator: CustomProgressIndicator): Boolean { if (indicator.isCanceled()) { + errorMonitor.notifyErrorOccurrence() indicator.stop() return true } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/openai/OpenAIRequestManager.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/openai/OpenAIRequestManager.kt index 1d9d6a9a4..b71138a98 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/openai/OpenAIRequestManager.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/openai/OpenAIRequestManager.kt @@ -58,7 +58,7 @@ class OpenAIRequestManager(project: Project) : IJRequestManager(project) { // check response when (val responseCode = connection.responseCode) { HttpURLConnection.HTTP_OK -> { - assembleLlmResponse(request, testsAssembler, indicator) + assembleLlmResponse(request, testsAssembler, indicator, errorMonitor) } HttpURLConnection.HTTP_INTERNAL_ERROR -> { @@ -115,9 +115,10 @@ class OpenAIRequestManager(project: Project) : IJRequestManager(project) { httpRequest: HttpRequests.Request, testsAssembler: TestsAssembler, indicator: CustomProgressIndicator, + errorMonitor: ErrorMonitor, ) { while (true) { - if (ToolUtils.isProcessCanceled(indicator)) return + if (ToolUtils.isProcessCanceled(errorMonitor, indicator)) return var text = httpRequest.reader.readLine()