fix Unsafe calls with local objectFieldOffset

This commit is contained in:
Volker Berlin 2023-03-02 17:23:19 +01:00
parent 57f0c9ceb9
commit dbb49e99d1
No known key found for this signature in database
GPG Key ID: 988423EF815BE4CB

View File

@ -93,8 +93,11 @@ class UnsafeManager {
@Nonnull
private final ClassFileLoader classFileLoader;
private final HashMap<FunctionName, UnsafeState> unsafes = new HashMap<>();
@Nonnull
private final HashMap<FunctionName, UnsafeState> unsafes = new HashMap<>();
@Nonnull
private final HashMap<Integer, UnsafeState> localStates = new HashMap<>();
/**
* Create an instance of the manager
@ -120,6 +123,7 @@ class UnsafeManager {
*/
void replaceUnsafe( @Nonnull List<WasmInstruction> instructions ) throws IOException {
// search for Unsafe function calls
localStates.clear();
for( int i = 0; i < instructions.size(); i++ ) {
WasmInstruction instr = instructions.get( i );
switch( instr.getType() ) {
@ -154,7 +158,7 @@ class UnsafeManager {
* @throws IOException
* If any I/O error occur
*/
private void patch( List<WasmInstruction> instructions, int idx, WasmCallInstruction callInst ) throws IOException {
private void patch( @Nonnull List<WasmInstruction> instructions, int idx, @Nonnull WasmCallInstruction callInst ) throws IOException {
FunctionName name = callInst.getFunctionName();
switch( name.signatureName ) {
case "sun/misc/Unsafe.getUnsafe()Lsun/misc/Unsafe;":
@ -184,6 +188,7 @@ class UnsafeManager {
case "sun/misc/Unsafe.getLong(Ljava/lang/Object;J)J":
case "jdk/internal/misc/Unsafe.getInt(Ljava/lang/Object;J)I":
case "jdk/internal/misc/Unsafe.getLong(Ljava/lang/Object;J)J":
case "jdk/internal/misc/Unsafe.getObject(Ljava/lang/Object;J)Ljava/lang/Object;":
patchFieldFunction( instructions, idx, callInst, name, 1 );
break;
case "sun/misc/Unsafe.getAndAddInt(Ljava/lang/Object;JI)I":
@ -223,10 +228,12 @@ class UnsafeManager {
case "java/util/concurrent/atomic/AtomicReferenceFieldUpdater.compareAndSet(Ljava/lang/Object;Ljava/lang/Object;Ljava/lang/Object;)Z":
patchFieldFunction( instructions, idx, callInst, name, 4 );
break;
case "jdk/internal/misc/Unsafe.getLongUnaligned(Ljava/lang/Object;J)J":
case "jdk/internal/misc/Unsafe.getIntUnaligned(Ljava/lang/Object;J)I":
case "jdk/internal/misc/Unsafe.getCharUnaligned(Ljava/lang/Object;JZ)C":
case "jdk/internal/misc/Unsafe.getShortUnaligned(Ljava/lang/Object;JZ)S":
case "jdk/internal/misc/Unsafe.getIntUnaligned(Ljava/lang/Object;J)I":
case "jdk/internal/misc/Unsafe.getIntUnaligned(Ljava/lang/Object;JZ)I":
case "jdk/internal/misc/Unsafe.getLongUnaligned(Ljava/lang/Object;J)J":
case "jdk/internal/misc/Unsafe.getLongUnaligned(Ljava/lang/Object;JZ)J":
patch_getLongUnaligned( instructions, idx, callInst, name );
break;
case "jdk/internal/misc/Unsafe.isBigEndian()Z":
@ -260,7 +267,7 @@ class UnsafeManager {
* @param idx
* the index in the instructions
*/
private void patch_getUnsafe( List<WasmInstruction> instructions, int idx ) {
private void patch_getUnsafe( @Nonnull List<WasmInstruction> instructions, int idx ) {
WasmInstruction instr = instructions.get( idx + 1 );
int to = idx + (instr.getType() == Type.Global ? 2 : 1);
@ -278,7 +285,7 @@ class UnsafeManager {
* @return the state
*/
@Nonnull
private UnsafeState findUnsafeState( List<WasmInstruction> instructions, int idx ) {
private UnsafeState findUnsafeState( @Nonnull List<WasmInstruction> instructions, int idx ) {
// find the field on which the offset is assign: long FIELD_OFFSET = UNSAFE.objectFieldOffset(...
WasmInstruction instr;
idx++;
@ -299,6 +306,11 @@ class UnsafeManager {
}
}
continue INSTR;
case Local:
// occur with jdk.internal.misc.InnocuousThread
UnsafeState state = new UnsafeState();
localStates.put( ((WasmLocalInstruction)instr).getIndex(), state );
return state;
default:
throw new WasmException( "Unsupported assign operation for Unsafe field offset: " + instr.getType(), -1 );
}
@ -498,35 +510,47 @@ class UnsafeManager {
Set<FunctionName> fieldNames;
FunctionName fieldNameWithOffset = null;
UnsafeState state = null;
if( instr.getType() == Type.Global ) {
fieldNameWithOffset = ((WasmGlobalInstruction)instr).getFieldName();
fieldNames = Collections.singleton( fieldNameWithOffset );
} else {
// java.util.concurrent.ConcurrentHashMap.tabAt() calculate a value with the field
fieldNames = new HashSet<>();
int pos2 = stackValue.idx;
stackValue = StackInspector.findInstructionThatPushValue( instructions.subList( 0, idx ), fieldNameParam + 1, callInst.getCodePosition() );
int i = stackValue.idx;
for( ; i < pos2; i++ ) {
instr = instructions.get( i );
if( instr.getType() != Type.Global ) {
continue;
if( instr.getType() == Type.Local ) {
// occur with jdk.internal.misc.InnocuousThread
state = localStates.get( ((WasmLocalInstruction)instr).getIndex() );
if( state != null ) {
fieldNameWithOffset = new FunctionName( state.typeName, state.fieldName, "" );
}
}
if( fieldNameWithOffset == null ) {
// java.util.concurrent.ConcurrentHashMap.tabAt() calculate a value with the field
int pos2 = stackValue.idx;
stackValue = StackInspector.findInstructionThatPushValue( instructions.subList( 0, idx ), fieldNameParam + 1, callInst.getCodePosition() );
int i = stackValue.idx;
for( ; i < pos2; i++ ) {
instr = instructions.get( i );
if( instr.getType() != Type.Global ) {
continue;
}
fieldNameWithOffset = ((WasmGlobalInstruction)instr).getFieldName();
fieldNames.add( fieldNameWithOffset );
}
fieldNameWithOffset = ((WasmGlobalInstruction)instr).getFieldName();
fieldNames.add( fieldNameWithOffset );
}
}
UnsafeState state_ = state;
WatCodeSyntheticFunctionName func =
new WatCodeSyntheticFunctionName( fieldNameWithOffset.className, '.' + fieldNameWithOffset.methodName + '.' + name.methodName, name.signature, "", (AnyType[])null ) {
@Override
protected String getCode() {
UnsafeState state = null;
UnsafeState state = state_;
for(FunctionName fieldNameWithOffset : fieldNames ) {
state = unsafes.get( fieldNameWithOffset );
if( state != null ) {
break;
}
state = unsafes.get( fieldNameWithOffset );
}
if( state == null ) {
if( functions.isFinish() ) {
@ -632,16 +656,20 @@ class UnsafeManager {
case "getInt":
case "getLong":
return "local.get 1" // THIS
+ " struct.get " + state.typeName + ' ' + state.fieldName //
+ " return";
case "getObject":
case "getObjectVolatile":
return "local.get 1" // array
+ " local.get 2" // the array index
+ " i32.wrap_i64" // long -> int
+ " array.get " + state.typeName
+ " return";
if( state.fieldName != null ) {
// field access
return "local.get 1" // THIS
+ " struct.get " + state.typeName + ' ' + state.fieldName //
+ " return";
} else {
// array access
return "local.get 1" // array
+ " local.get 2" // the array index
+ " i32.wrap_i64" // long -> int
+ " array.get " + state.typeName + " return";
}
}
throw new RuntimeException( name.signatureName );