Skip to content

Commit

Permalink
JVM IR: Support sealed inline class methods
Browse files Browse the repository at this point in the history
 #KT-27576: Fixed
  • Loading branch information
ilmirus committed Dec 18, 2022
1 parent dd99e9f commit ab3e9a2
Show file tree
Hide file tree
Showing 10 changed files with 1,060 additions and 134 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,11 @@ internal class BridgeLowering(val context: JvmBackendContext) : FileLoweringPass
// Generate common bridges
val generated = mutableMapOf<Method, Bridge>()

for (override in irFunction.allOverridden()) {
// Do not generate additional bridges for functions, inherited from sealed inline classes in addition to functions from interfaces
val trulyOverridden = irFunction.allOverridden().filterNot {
it.origin == JvmLoweredDeclarationOrigin.STATIC_INLINE_CLASS_REPLACEMENT
}
for (override in trulyOverridden) {
if (override.isFakeOverride) continue

val signature = override.jvmMethod
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ internal class JvmMultiFieldValueClassLowering(
copyAttributes(source)
}

override fun createBridgeBody(source: IrSimpleFunction, target: IrSimpleFunction, original: IrFunction, inverted: Boolean) {
override fun createBridgeBody(source: IrSimpleFunction, target: IrSimpleFunction, returnBoxedSealedInlineClass: Boolean) {
allScopes.push(createScope(source))
source.body = context.createJvmIrBuilder(source.symbol).run {
val sourceExplicitParameters = source.explicitParameters
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ internal abstract class JvmValueClassAbstractLowering(
// the latter is.
//
// This is a potential problem for bridge generation, where we have to ensure that the overridden
// symbols are always up to date. Right now they might not be since we lower each file independently
// symbols are always up-to-date. Right now they might not be since we lower each file independently
// and since deserialized declarations are not mangled at all.
if (function is IrSimpleFunction) {
if (function is IrSimpleFunction && (function.parent as? IrClass)?.isChildOfSealedInlineClass() != true) {
function.overriddenSymbols = replacements.replaceOverriddenSymbols(function)
}
return null
Expand All @@ -81,9 +81,9 @@ internal abstract class JvmValueClassAbstractLowering(
}

private fun transformFlattenedConstructor(function: IrConstructor, replacement: IrConstructor): List<IrDeclaration> {
replacement.valueParameters.forEach {
visitParameter(it)
it.defaultValue?.patchDeclarationParents(replacement)
for (parameter in replacement.valueParameters) {
visitParameter(parameter)
parameter.defaultValue?.patchDeclarationParents(replacement)
}
allScopes.push(createScope(function))
replacement.body = function.body?.transform(this, null)?.patchDeclarationParents(replacement)
Expand All @@ -108,10 +108,10 @@ internal abstract class JvmValueClassAbstractLowering(
return declaration
}

private fun transformSimpleFunctionFlat(function: IrSimpleFunction, replacement: IrSimpleFunction): List<IrDeclaration> {
replacement.valueParameters.forEach {
visitParameter(it)
it.defaultValue?.patchDeclarationParents(replacement)
protected open fun transformSimpleFunctionFlat(function: IrSimpleFunction, replacement: IrSimpleFunction): List<IrDeclaration> {
for (parameter in replacement.valueParameters) {
visitParameter(parameter)
parameter.defaultValue?.patchDeclarationParents(replacement)
}
allScopes.push(createScope(replacement))
replacement.body = function.body?.transform(this, null)?.patchDeclarationParents(replacement)
Expand Down Expand Up @@ -217,9 +217,9 @@ internal abstract class JvmValueClassAbstractLowering(
// Replace the function body with a wrapper
if (bridgeFunction.isFakeOverride && bridgeFunction.parentAsClass.isSpecificLoweringLogicApplicable()) {
// Fake overrides redirect from the replacement to the original function, which is in turn replaced during interfacePhase.
createBridgeBody(replacement, bridgeFunction, function, true)
createBridgeBody(replacement, bridgeFunction)
} else {
createBridgeBody(bridgeFunction, replacement, function, false)
createBridgeBody(bridgeFunction, replacement)
}
return bridgeFunction
}
Expand All @@ -240,8 +240,7 @@ internal abstract class JvmValueClassAbstractLowering(
// visibility rules for bridge methods.
abstract fun createBridgeDeclaration(source: IrSimpleFunction, replacement: IrSimpleFunction, mangledName: Name): IrSimpleFunction

protected abstract fun createBridgeBody(source: IrSimpleFunction, target: IrSimpleFunction, original: IrFunction, inverted: Boolean)

protected abstract fun createBridgeBody(source: IrSimpleFunction, target: IrSimpleFunction, returnBoxedSealedInlineClass: Boolean = false)

// Functions for common lowering dispatching
private inner class NeedsToVisit : IrElementVisitor<Boolean, Nothing?> {
Expand Down Expand Up @@ -307,78 +306,5 @@ internal abstract class JvmValueClassAbstractLowering(
internal fun needsToVisitReturn(expression: IrReturn): Boolean = expression.accept(NeedsToVisit(), null)
internal abstract fun visitClassNewDeclarationsWhenParallel(declaration: IrDeclaration)

// forbid other overrides without modifying dispatcher file JvmValueClassLoweringDispatcher.kt

final override fun visitModuleFragment(declaration: IrModuleFragment): IrModuleFragment = super.visitModuleFragment(declaration)
final override fun visitPackageFragment(declaration: IrPackageFragment): IrPackageFragment = super.visitPackageFragment(declaration)
final override fun visitExternalPackageFragment(declaration: IrExternalPackageFragment): IrExternalPackageFragment =
super.visitExternalPackageFragment(declaration)

final override fun visitDeclaration(declaration: IrDeclarationBase): IrStatement = super.visitDeclaration(declaration)
final override fun visitSimpleFunction(declaration: IrSimpleFunction) = super.visitSimpleFunction(declaration)
final override fun visitConstructor(declaration: IrConstructor) = super.visitConstructor(declaration)
final override fun visitLocalDelegatedProperty(declaration: IrLocalDelegatedProperty) = super.visitLocalDelegatedProperty(declaration)
final override fun visitEnumEntry(declaration: IrEnumEntry) = super.visitEnumEntry(declaration)
final override fun visitTypeParameter(declaration: IrTypeParameter) = super.visitTypeParameter(declaration)
final override fun visitTypeAlias(declaration: IrTypeAlias) = super.visitTypeAlias(declaration)
final override fun visitBody(body: IrBody): IrBody = super.visitBody(body)
final override fun visitExpressionBody(body: IrExpressionBody) = super.visitExpressionBody(body)
final override fun visitSyntheticBody(body: IrSyntheticBody) = super.visitSyntheticBody(body)
final override fun visitSuspendableExpression(expression: IrSuspendableExpression) = super.visitSuspendableExpression(expression)
final override fun visitSuspensionPoint(expression: IrSuspensionPoint) = super.visitSuspensionPoint(expression)
final override fun visitExpression(expression: IrExpression): IrExpression = super.visitExpression(expression)
final override fun visitConst(expression: IrConst<*>) = super.visitConst(expression)
final override fun visitConstantValue(expression: IrConstantValue): IrConstantValue = super.visitConstantValue(expression)
final override fun visitConstantObject(expression: IrConstantObject) = super.visitConstantObject(expression)
final override fun visitConstantPrimitive(expression: IrConstantPrimitive) = super.visitConstantPrimitive(expression)
final override fun visitConstantArray(expression: IrConstantArray) = super.visitConstantArray(expression)
final override fun visitVararg(expression: IrVararg) = super.visitVararg(expression)
final override fun visitSpreadElement(spread: IrSpreadElement): IrSpreadElement = super.visitSpreadElement(spread)
final override fun visitBlock(expression: IrBlock) = super.visitBlock(expression)
final override fun visitComposite(expression: IrComposite) = super.visitComposite(expression)
final override fun visitDeclarationReference(expression: IrDeclarationReference) = super.visitDeclarationReference(expression)
final override fun visitSingletonReference(expression: IrGetSingletonValue) = super.visitSingletonReference(expression)
final override fun visitGetObjectValue(expression: IrGetObjectValue) = super.visitGetObjectValue(expression)
final override fun visitGetEnumValue(expression: IrGetEnumValue) = super.visitGetEnumValue(expression)
final override fun visitValueAccess(expression: IrValueAccessExpression) = super.visitValueAccess(expression)
final override fun visitFieldAccess(expression: IrFieldAccessExpression) = super.visitFieldAccess(expression)
final override fun visitMemberAccess(expression: IrMemberAccessExpression<*>) = super.visitMemberAccess(expression)
final override fun visitConstructorCall(expression: IrConstructorCall) = super.visitConstructorCall(expression)
final override fun visitDelegatingConstructorCall(expression: IrDelegatingConstructorCall) =
super.visitDelegatingConstructorCall(expression)

final override fun visitEnumConstructorCall(expression: IrEnumConstructorCall) = super.visitEnumConstructorCall(expression)
final override fun visitGetClass(expression: IrGetClass) = super.visitGetClass(expression)
final override fun visitCallableReference(expression: IrCallableReference<*>) = super.visitCallableReference(expression)
final override fun visitPropertyReference(expression: IrPropertyReference) = super.visitPropertyReference(expression)
final override fun visitLocalDelegatedPropertyReference(expression: IrLocalDelegatedPropertyReference) =
super.visitLocalDelegatedPropertyReference(expression)

final override fun visitRawFunctionReference(expression: IrRawFunctionReference) = super.visitRawFunctionReference(expression)
final override fun visitFunctionExpression(expression: IrFunctionExpression) = super.visitFunctionExpression(expression)
final override fun visitClassReference(expression: IrClassReference) = super.visitClassReference(expression)
final override fun visitInstanceInitializerCall(expression: IrInstanceInitializerCall) = super.visitInstanceInitializerCall(expression)
final override fun visitWhen(expression: IrWhen) = super.visitWhen(expression)
final override fun visitBranch(branch: IrBranch): IrBranch = super.visitBranch(branch)
final override fun visitElseBranch(branch: IrElseBranch): IrElseBranch = super.visitElseBranch(branch)
final override fun visitLoop(loop: IrLoop) = super.visitLoop(loop)
final override fun visitWhileLoop(loop: IrWhileLoop) = super.visitWhileLoop(loop)
final override fun visitDoWhileLoop(loop: IrDoWhileLoop) = super.visitDoWhileLoop(loop)
final override fun visitTry(aTry: IrTry) = super.visitTry(aTry)
final override fun visitCatch(aCatch: IrCatch): IrCatch = super.visitCatch(aCatch)
final override fun visitBreakContinue(jump: IrBreakContinue) = super.visitBreakContinue(jump)
final override fun visitBreak(jump: IrBreak) = super.visitBreak(jump)
final override fun visitContinue(jump: IrContinue) = super.visitContinue(jump)
final override fun visitThrow(expression: IrThrow) = super.visitThrow(expression)
final override fun visitDynamicExpression(expression: IrDynamicExpression) = super.visitDynamicExpression(expression)
final override fun visitDynamicOperatorExpression(expression: IrDynamicOperatorExpression) =
super.visitDynamicOperatorExpression(expression)

final override fun visitDynamicMemberExpression(expression: IrDynamicMemberExpression) = super.visitDynamicMemberExpression(expression)
final override fun visitErrorDeclaration(declaration: IrErrorDeclaration) = super.visitErrorDeclaration(declaration)
final override fun visitErrorExpression(expression: IrErrorExpression) = super.visitErrorExpression(expression)
final override fun visitErrorCallExpression(expression: IrErrorCallExpression) = super.visitErrorCallExpression(expression)


abstract val IrType.needsHandling: Boolean
}
Original file line number Diff line number Diff line change
Expand Up @@ -729,6 +729,15 @@ class JvmSymbols(
}
val illegalArgumentExceptionCtorString = illegalArgumentException.constructors.single()

val illegalStateException = createClass(FqName("java.lang.IllegalStateException")) { irClass ->
irClass.addConstructor {
name = Name.special("<init>")
}.apply {
addValueParameter("message", irBuiltIns.stringType)
}
}
val illegalStateExceptionCtorString = illegalStateException.constructors.single()

val classCastException = createClass(FqName("java.lang.ClassCastException")) { irClass ->
irClass.addConstructor {
name = Name.special("<init>")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import org.jetbrains.kotlin.ir.IrBuiltIns
import org.jetbrains.kotlin.ir.builders.declarations.addValueParameter
import org.jetbrains.kotlin.ir.builders.declarations.buildFun
import org.jetbrains.kotlin.ir.declarations.*
import org.jetbrains.kotlin.ir.symbols.impl.IrValueParameterSymbolImpl
import org.jetbrains.kotlin.ir.types.defaultType
import org.jetbrains.kotlin.ir.types.impl.IrSimpleTypeImpl
import org.jetbrains.kotlin.ir.types.impl.IrStarProjectionImpl
import org.jetbrains.kotlin.ir.util.*
Expand Down Expand Up @@ -175,6 +177,44 @@ class MemoizedInlineClassReplacements(
}
}

/**
* For method in children of sealed inline classes we generate method in the top.
*/
val getSealedInlineClassChildFunctionInTop: (Pair<IrClass, SimpleFunctionWithoutReceiver>) -> IrSimpleFunction =
storageManager.createMemoizedFunction { (top, method) ->
require(top.isSealedInline) {
"Expected method in sealed inline class child"
}
irFactory.buildFun {
name = Name.identifier(InlineClassAbi.functionNameBase(method.function))
origin = JvmLoweredDeclarationOrigin.GENERATED_SEALED_INLINE_CLASS_METHOD
returnType = method.function.returnType
}.apply {
parent = top
copyTypeParameters(method.function.typeParameters)

copyPropertyIfNeeded(method.function)

val substitutionMap = method.function.typeParameters.map { it.symbol }.zip(typeParameters.map { it.defaultType }).toMap()
// Replace dispatch parameter from child to top
dispatchReceiverParameter = factory.createValueParameter(
startOffset, endOffset, origin,
IrValueParameterSymbolImpl(),
name, -1,
top.defaultType.substitute(substitutionMap),
null, isCrossinline = false, isNoinline = false, isHidden = false, isAssignable = false
).also { parameter ->
parameter.parent = this
}
extensionReceiverParameter = method.function.extensionReceiverParameter?.copyTo(this)

val shift = valueParameters.size
valueParameters = method.function.valueParameters.map {
it.copyTo(this, index = it.index + shift, type = it.type.substitute(substitutionMap))
}
}
}

private val specializedEqualsCache = storageManager.createCacheWithNotNullValues<IrClass, IrSimpleFunction>()
fun getSpecializedEqualsMethod(irClass: IrClass, irBuiltIns: IrBuiltIns): IrSimpleFunction {
require(irClass.isInlineOrSealedInline)
Expand Down Expand Up @@ -256,7 +296,11 @@ class MemoizedInlineClassReplacements(
}
}
valueParameters = newValueParameters
context.remapMultiFieldValueClassStructure(function, this, parametersMappingOrNull = null)
context.multiFieldValueClassReplacements.run {
bindingNewFunctionToParameterTemplateStructure[function]?.also {
bindingNewFunctionToParameterTemplateStructure[this@buildReplacement] = it
}
}
}

private fun buildReplacement(
Expand Down Expand Up @@ -289,6 +333,12 @@ class MemoizedInlineClassReplacements(
// The [updateFrom] call will set the modality to FINAL for constructors, while the JVM backend would use OPEN here.
modality = Modality.OPEN
}
if (function is IrSimpleFunction && function.modality == Modality.ABSTRACT &&
function.parentAsClass.isSealedInline &&
replacementOrigin == JvmLoweredDeclarationOrigin.STATIC_INLINE_CLASS_REPLACEMENT
) {
modality = Modality.OPEN
}
origin = when {
function.origin == IrDeclarationOrigin.GENERATED_SINGLE_FIELD_VALUE_CLASS_MEMBER ->
JvmLoweredDeclarationOrigin.INLINE_CLASS_GENERATED_IMPL_METHOD
Expand All @@ -300,7 +350,49 @@ class MemoizedInlineClassReplacements(
replacementOrigin
}
name = InlineClassAbi.mangledNameFor(function, mangleReturnTypes, useOldManglingScheme)
}.apply {
if (function is IrSimpleFunction) {
copyPropertyIfNeeded(function)
}

body()
}

override fun getReplacementForRegularClassConstructor(constructor: IrConstructor): IrConstructor? = null

private fun IrSimpleFunction.copyPropertyIfNeeded(function: IrSimpleFunction) {
val propertySymbol = function.correspondingPropertySymbol
if (propertySymbol != null) {
val property = commonBuildProperty(propertySymbol)
when (function.withoutReceiver()) {
propertySymbol.owner.getter?.withoutReceiver() -> property.getter = this
propertySymbol.owner.setter?.withoutReceiver() -> property.setter = this
else -> error("Orphaned property getter/setter: ${function.render()}")
}
}
}

class SimpleFunctionWithoutReceiver(
val function: IrSimpleFunction
) {
override fun equals(other: Any?): Boolean {
if (other === this) return true
if (other !is SimpleFunctionWithoutReceiver) return false
return function.name == other.function.name &&
function.typeParameters == other.function.typeParameters &&
function.returnType == other.function.returnType &&
function.extensionReceiverParameter == other.function.extensionReceiverParameter &&
function.valueParameters == other.function.valueParameters
}

override fun hashCode(): Int {
return function.name.hashCode() xor
function.typeParameters.hashCode() xor
function.returnType.hashCode() xor
function.extensionReceiverParameter.hashCode() xor
function.valueParameters.hashCode()
}
}
}

fun IrSimpleFunction.withoutReceiver() = MemoizedInlineClassReplacements.SimpleFunctionWithoutReceiver(this)
Loading

0 comments on commit ab3e9a2

Please sign in to comment.