@@ -2,22 +2,32 @@ package org.utbot.instrumentation.instrumentation.execution
22
33import org.utbot.common.JarUtils
44import com.jetbrains.rd.util.getLogger
5+ import com.jetbrains.rd.util.info
6+ import org.utbot.framework.plugin.api.BeanDefinitionData
7+ import org.utbot.framework.plugin.api.ClassId
8+ import org.utbot.framework.plugin.api.SpringRepositoryId
9+ import org.utbot.framework.plugin.api.util.jClass
510import org.utbot.instrumentation.instrumentation.ArgumentList
611import org.utbot.instrumentation.instrumentation.Instrumentation
712import org.utbot.instrumentation.instrumentation.execution.mock.SpringInstrumentationContext
813import org.utbot.instrumentation.process.HandlerClassesLoader
914import org.utbot.spring.api.context.ContextWrapper
1015import org.utbot.spring.api.repositoryWrapper.RepositoryInteraction
1116import java.security.ProtectionDomain
17+ import java.util.IdentityHashMap
1218
1319/* *
1420 * UtExecutionInstrumentation wrapper that is aware of Spring config and initialises Spring context
1521 */
1622class SpringUtExecutionInstrumentation (
1723 private val delegateInstrumentation : UtExecutionInstrumentation ,
18- private val springConfig : String
24+ private val springConfig : String ,
25+ private val beanDefinitions : List <BeanDefinitionData >,
1926) : Instrumentation<UtConcreteExecutionResult> by delegateInstrumentation {
2027 private lateinit var instrumentationContext: SpringInstrumentationContext
28+
29+ private val relatedBeansCache = mutableMapOf<Class <* >, Set <String >>()
30+
2131 private val springContext: ContextWrapper get() = instrumentationContext.springContext
2232
2333 companion object {
@@ -26,11 +36,16 @@ class SpringUtExecutionInstrumentation(
2636 }
2737
2838 override fun init (pathsToUserClasses : Set <String >) {
29- HandlerClassesLoader .addUrls(listOf (JarUtils .extractJarFileFromResources(
30- jarFileName = SPRING_COMMONS_JAR_FILENAME ,
31- jarResourcePath = " lib/$SPRING_COMMONS_JAR_FILENAME " ,
32- targetDirectoryName = " spring-commons"
33- ).path))
39+ HandlerClassesLoader .addUrls(
40+ listOf (
41+ JarUtils .extractJarFileFromResources(
42+ jarFileName = SPRING_COMMONS_JAR_FILENAME ,
43+ jarResourcePath = " lib/$SPRING_COMMONS_JAR_FILENAME " ,
44+ targetDirectoryName = " spring-commons"
45+ ).path
46+ )
47+ )
48+
3449 instrumentationContext = SpringInstrumentationContext (springConfig)
3550 delegateInstrumentation.instrumentationContext = instrumentationContext
3651 delegateInstrumentation.init (pathsToUserClasses)
@@ -43,39 +58,48 @@ class SpringUtExecutionInstrumentation(
4358 parameters : Any?
4459 ): UtConcreteExecutionResult {
4560 RepositoryInteraction .recordedInteractions.clear()
46- // TODO properly detect which beans need to be reset, right now "orderRepository" and "orderService" are hardcoded
47- val beanNamesToReset = listOf (" orderRepository" , " orderService" )
4861
49- beanNamesToReset.forEach { beanNameToReset ->
50- val beanDefToReset = springContext.getBeanDefinition(beanNameToReset)
51- springContext.removeBeanDefinition(beanNameToReset)
52- springContext.registerBeanDefinition(beanNameToReset, beanDefToReset)
53- }
62+ val beanNamesToReset: Set <String > = getRelevantBeanNames(clazz)
63+ val repositoryDefinitions = springContext.resolveRepositories(beanNamesToReset)
5464
65+ beanNamesToReset.forEach { beanName -> springContext.resetBean(beanName) }
5566 val jdbcTemplate = getBean(" jdbcTemplate" )
56- // TODO properly detect which repositories need to be cleared, right now "orders" is hardcoded
57- val sql = " TRUNCATE TABLE orders"
58- jdbcTemplate::class .java
59- .getMethod(" execute" , sql::class .java)
60- .invoke(jdbcTemplate, sql)
61- val sql2 = " ALTER TABLE orders ALTER COLUMN id RESTART WITH 1"
62- jdbcTemplate::class .java
63- .getMethod(" execute" , sql::class .java)
64- .invoke(jdbcTemplate, sql2)
67+
68+ for (repositoryDefinition in repositoryDefinitions) {
69+ val truncateTableCommand = " TRUNCATE TABLE ${repositoryDefinition.tableName} "
70+ jdbcTemplate::class .java
71+ .getMethod(" execute" , truncateTableCommand::class .java)
72+ .invoke(jdbcTemplate, truncateTableCommand)
73+
74+ val restartIdCommand = " ALTER TABLE ${repositoryDefinition.tableName} ALTER COLUMN id RESTART WITH 1"
75+ jdbcTemplate::class .java
76+ .getMethod(" execute" , restartIdCommand::class .java)
77+ .invoke(jdbcTemplate, restartIdCommand)
78+ }
6579
6680 return delegateInstrumentation.invoke(clazz, methodSignature, arguments, parameters)
6781 }
6882
83+ private fun getRelevantBeanNames (clazz : Class <* >): Set <String > = relatedBeansCache.getOrPut(clazz) {
84+ beanDefinitions
85+ .filter { it.beanTypeFqn == clazz.name }
86+ .flatMap { springContext.getDependenciesForBean(it.beanName) }
87+ .toSet()
88+ .also { logger.info { " Detected relevant beans for class ${clazz.name} : $it " } }
89+ }
90+
6991 fun getBean (beanName : String ): Any = springContext.getBean(beanName)
7092
71- fun saveToRepository (repository : Any , entity : Any ) {
72- // ignore repository interactions done during repository fill up
73- val savedRecordedRepositoryResponses = RepositoryInteraction .recordedInteractions.toList()
74- repository::class .java
75- .getMethod(" save" , Any ::class .java)
76- .invoke(repository, entity)
77- RepositoryInteraction .recordedInteractions.clear()
78- RepositoryInteraction .recordedInteractions.addAll(savedRecordedRepositoryResponses)
93+ fun getRepositoryDescriptions (classId : ClassId ): Set <SpringRepositoryId > {
94+ val relevantBeanNames = getRelevantBeanNames(classId.jClass)
95+ val repositoryDescriptions = springContext.resolveRepositories(relevantBeanNames.toSet())
96+ return repositoryDescriptions.map { repositoryDescription ->
97+ SpringRepositoryId (
98+ repositoryDescription.beanName,
99+ ClassId (repositoryDescription.repositoryName),
100+ ClassId (repositoryDescription.entityName),
101+ )
102+ }.toSet()
79103 }
80104
81105 override fun transform (
@@ -85,14 +109,15 @@ class SpringUtExecutionInstrumentation(
85109 protectionDomain : ProtectionDomain ,
86110 classfileBuffer : ByteArray
87111 ): ByteArray? =
88- // TODO automatically detect which libraries we don't want to transform (by total transformation time)
89- // transforming Spring takes too long
112+ // TODO: automatically detect which libraries we don't want to transform (by total transformation time)
90113 if (listOf (
91114 " org/springframework" ,
92115 " com/fasterxml" ,
93116 " org/hibernate" ,
94117 " org/apache" ,
95- " org/h2"
118+ " org/h2" ,
119+ " javax/" ,
120+ " ch/qos" ,
96121 ).any { className.startsWith(it) }
97122 ) {
98123 null
0 commit comments