@@ -25,6 +25,7 @@ import com.intellij.task.ProjectTaskManager
2525import com.intellij.task.impl.ModuleBuildTaskImpl
2626import com.intellij.task.impl.ModuleFilesBuildTaskImpl
2727import com.intellij.task.impl.ProjectTaskList
28+ import com.intellij.util.PathsList
2829import com.intellij.util.concurrency.AppExecutorUtil
2930import com.intellij.util.containers.ContainerUtil
3031import com.intellij.util.containers.nullize
@@ -189,8 +190,8 @@ object UtTestsDialogProcessor {
189190 .all()
190191 .thenAsync { compile(project, filesToCompile, springConfigClass) }
191192
192- compilationPromise.onSuccess {
193- if (it .hasErrors() || it .isAborted)
193+ compilationPromise.onSuccess { task ->
194+ if (task .hasErrors() || task .isAborted)
194195 return @onSuccess
195196
196197 (object : Task .Backgroundable (project, " Generate tests" ) {
@@ -215,7 +216,15 @@ object UtTestsDialogProcessor {
215216 updateIndicator(indicator, ProgressRange .SOLVING , " Generate tests: read classes" , 0.0 )
216217
217218 val buildPaths = ReadAction
218- .nonBlocking<BuildPaths ?> { findPaths(model.srcClasses, springConfigClass) }
219+ .nonBlocking<BuildPaths ?> {
220+ findPaths(listOf (findSrcModule(model.srcClasses)) + when (model.projectType) {
221+ Spring -> listOfNotNull(
222+ model.testModule, // needed so we can use `TestContextManager` from `spring-test`
223+ springConfigClass?.let { it.module ? : error(" Module for Spring configuration class not found" ) }
224+ )
225+ else -> emptyList()
226+ })
227+ }
219228 .executeSynchronously()
220229 ? : return
221230
@@ -546,22 +555,18 @@ object UtTestsDialogProcessor {
546555 }
547556 }
548557
549- private fun findPaths (srcClasses : Set <PsiClass >, springConfigPsiClass : PsiClass ? ): BuildPaths ? {
550- val srcModule = findSrcModule(srcClasses)
551- val springConfigModule = springConfigPsiClass?.let { it.module ? : error(" Module for spring configuration class not found" ) }
552-
553- val buildDirs = CompilerPaths .getOutputPaths(setOfNotNull(
554- srcModule, springConfigModule
555- ).toTypedArray())
558+ private fun findPaths (modules : List <Module >): BuildPaths ? {
559+ val buildDirs = CompilerPaths .getOutputPaths(modules.distinct().toTypedArray())
556560 .toList()
557561 .filter { Paths .get(it).exists() }
558562 .nullize() ? : return null
559563
560- val pathsList = OrderEnumerator .orderEntries(srcModule).recursively().pathsList
564+ val pathsList = PathsList ()
561565
562- springConfigModule?.takeIf { it != srcModule }?.let { module ->
563- pathsList.addAll(OrderEnumerator .orderEntries(module).recursively().pathsList.pathList)
564- }
566+ modules
567+ .distinct()
568+ .map { module -> OrderEnumerator .orderEntries(module).recursively().pathsList }
569+ .forEach { pathsList.addAll(it.pathList) }
565570
566571 val (classpath, classpathList) = if (IntelliJApiHelper .isAndroidStudio()) {
567572 // Filter out manifests from classpath.
0 commit comments