fix Unsafe calls with local objectFieldOffset

This commit is contained in:
Volker Berlin 2023-03-02 17:23:19 +01:00
parent 57f0c9ceb9
commit dbb49e99d1
No known key found for this signature in database
GPG Key ID: 988423EF815BE4CB

View File

@ -93,8 +93,11 @@ class UnsafeManager {
@Nonnull @Nonnull
private final ClassFileLoader classFileLoader; private final ClassFileLoader classFileLoader;
private final HashMap<FunctionName, UnsafeState> unsafes = new HashMap<>(); @Nonnull
private final HashMap<FunctionName, UnsafeState> unsafes = new HashMap<>();
@Nonnull
private final HashMap<Integer, UnsafeState> localStates = new HashMap<>();
/** /**
* Create an instance of the manager * Create an instance of the manager
@ -120,6 +123,7 @@ class UnsafeManager {
*/ */
void replaceUnsafe( @Nonnull List<WasmInstruction> instructions ) throws IOException { void replaceUnsafe( @Nonnull List<WasmInstruction> instructions ) throws IOException {
// search for Unsafe function calls // search for Unsafe function calls
localStates.clear();
for( int i = 0; i < instructions.size(); i++ ) { for( int i = 0; i < instructions.size(); i++ ) {
WasmInstruction instr = instructions.get( i ); WasmInstruction instr = instructions.get( i );
switch( instr.getType() ) { switch( instr.getType() ) {
@ -154,7 +158,7 @@ class UnsafeManager {
* @throws IOException * @throws IOException
* If any I/O error occur * If any I/O error occur
*/ */
private void patch( List<WasmInstruction> instructions, int idx, WasmCallInstruction callInst ) throws IOException { private void patch( @Nonnull List<WasmInstruction> instructions, int idx, @Nonnull WasmCallInstruction callInst ) throws IOException {
FunctionName name = callInst.getFunctionName(); FunctionName name = callInst.getFunctionName();
switch( name.signatureName ) { switch( name.signatureName ) {
case "sun/misc/Unsafe.getUnsafe()Lsun/misc/Unsafe;": case "sun/misc/Unsafe.getUnsafe()Lsun/misc/Unsafe;":
@ -184,6 +188,7 @@ class UnsafeManager {
case "sun/misc/Unsafe.getLong(Ljava/lang/Object;J)J": case "sun/misc/Unsafe.getLong(Ljava/lang/Object;J)J":
case "jdk/internal/misc/Unsafe.getInt(Ljava/lang/Object;J)I": case "jdk/internal/misc/Unsafe.getInt(Ljava/lang/Object;J)I":
case "jdk/internal/misc/Unsafe.getLong(Ljava/lang/Object;J)J": case "jdk/internal/misc/Unsafe.getLong(Ljava/lang/Object;J)J":
case "jdk/internal/misc/Unsafe.getObject(Ljava/lang/Object;J)Ljava/lang/Object;":
patchFieldFunction( instructions, idx, callInst, name, 1 ); patchFieldFunction( instructions, idx, callInst, name, 1 );
break; break;
case "sun/misc/Unsafe.getAndAddInt(Ljava/lang/Object;JI)I": case "sun/misc/Unsafe.getAndAddInt(Ljava/lang/Object;JI)I":
@ -223,10 +228,12 @@ class UnsafeManager {
case "java/util/concurrent/atomic/AtomicReferenceFieldUpdater.compareAndSet(Ljava/lang/Object;Ljava/lang/Object;Ljava/lang/Object;)Z": case "java/util/concurrent/atomic/AtomicReferenceFieldUpdater.compareAndSet(Ljava/lang/Object;Ljava/lang/Object;Ljava/lang/Object;)Z":
patchFieldFunction( instructions, idx, callInst, name, 4 ); patchFieldFunction( instructions, idx, callInst, name, 4 );
break; break;
case "jdk/internal/misc/Unsafe.getLongUnaligned(Ljava/lang/Object;J)J":
case "jdk/internal/misc/Unsafe.getIntUnaligned(Ljava/lang/Object;J)I":
case "jdk/internal/misc/Unsafe.getCharUnaligned(Ljava/lang/Object;JZ)C": case "jdk/internal/misc/Unsafe.getCharUnaligned(Ljava/lang/Object;JZ)C":
case "jdk/internal/misc/Unsafe.getShortUnaligned(Ljava/lang/Object;JZ)S":
case "jdk/internal/misc/Unsafe.getIntUnaligned(Ljava/lang/Object;J)I":
case "jdk/internal/misc/Unsafe.getIntUnaligned(Ljava/lang/Object;JZ)I": case "jdk/internal/misc/Unsafe.getIntUnaligned(Ljava/lang/Object;JZ)I":
case "jdk/internal/misc/Unsafe.getLongUnaligned(Ljava/lang/Object;J)J":
case "jdk/internal/misc/Unsafe.getLongUnaligned(Ljava/lang/Object;JZ)J":
patch_getLongUnaligned( instructions, idx, callInst, name ); patch_getLongUnaligned( instructions, idx, callInst, name );
break; break;
case "jdk/internal/misc/Unsafe.isBigEndian()Z": case "jdk/internal/misc/Unsafe.isBigEndian()Z":
@ -260,7 +267,7 @@ class UnsafeManager {
* @param idx * @param idx
* the index in the instructions * the index in the instructions
*/ */
private void patch_getUnsafe( List<WasmInstruction> instructions, int idx ) { private void patch_getUnsafe( @Nonnull List<WasmInstruction> instructions, int idx ) {
WasmInstruction instr = instructions.get( idx + 1 ); WasmInstruction instr = instructions.get( idx + 1 );
int to = idx + (instr.getType() == Type.Global ? 2 : 1); int to = idx + (instr.getType() == Type.Global ? 2 : 1);
@ -278,7 +285,7 @@ class UnsafeManager {
* @return the state * @return the state
*/ */
@Nonnull @Nonnull
private UnsafeState findUnsafeState( List<WasmInstruction> instructions, int idx ) { private UnsafeState findUnsafeState( @Nonnull List<WasmInstruction> instructions, int idx ) {
// find the field on which the offset is assign: long FIELD_OFFSET = UNSAFE.objectFieldOffset(... // find the field on which the offset is assign: long FIELD_OFFSET = UNSAFE.objectFieldOffset(...
WasmInstruction instr; WasmInstruction instr;
idx++; idx++;
@ -299,6 +306,11 @@ class UnsafeManager {
} }
} }
continue INSTR; continue INSTR;
case Local:
// occur with jdk.internal.misc.InnocuousThread
UnsafeState state = new UnsafeState();
localStates.put( ((WasmLocalInstruction)instr).getIndex(), state );
return state;
default: default:
throw new WasmException( "Unsupported assign operation for Unsafe field offset: " + instr.getType(), -1 ); throw new WasmException( "Unsupported assign operation for Unsafe field offset: " + instr.getType(), -1 );
} }
@ -498,35 +510,47 @@ class UnsafeManager {
Set<FunctionName> fieldNames; Set<FunctionName> fieldNames;
FunctionName fieldNameWithOffset = null; FunctionName fieldNameWithOffset = null;
UnsafeState state = null;
if( instr.getType() == Type.Global ) { if( instr.getType() == Type.Global ) {
fieldNameWithOffset = ((WasmGlobalInstruction)instr).getFieldName(); fieldNameWithOffset = ((WasmGlobalInstruction)instr).getFieldName();
fieldNames = Collections.singleton( fieldNameWithOffset ); fieldNames = Collections.singleton( fieldNameWithOffset );
} else { } else {
// java.util.concurrent.ConcurrentHashMap.tabAt() calculate a value with the field
fieldNames = new HashSet<>(); fieldNames = new HashSet<>();
int pos2 = stackValue.idx; if( instr.getType() == Type.Local ) {
stackValue = StackInspector.findInstructionThatPushValue( instructions.subList( 0, idx ), fieldNameParam + 1, callInst.getCodePosition() ); // occur with jdk.internal.misc.InnocuousThread
int i = stackValue.idx; state = localStates.get( ((WasmLocalInstruction)instr).getIndex() );
for( ; i < pos2; i++ ) { if( state != null ) {
instr = instructions.get( i ); fieldNameWithOffset = new FunctionName( state.typeName, state.fieldName, "" );
if( instr.getType() != Type.Global ) { }
continue; }
if( fieldNameWithOffset == null ) {
// java.util.concurrent.ConcurrentHashMap.tabAt() calculate a value with the field
int pos2 = stackValue.idx;
stackValue = StackInspector.findInstructionThatPushValue( instructions.subList( 0, idx ), fieldNameParam + 1, callInst.getCodePosition() );
int i = stackValue.idx;
for( ; i < pos2; i++ ) {
instr = instructions.get( i );
if( instr.getType() != Type.Global ) {
continue;
}
fieldNameWithOffset = ((WasmGlobalInstruction)instr).getFieldName();
fieldNames.add( fieldNameWithOffset );
} }
fieldNameWithOffset = ((WasmGlobalInstruction)instr).getFieldName();
fieldNames.add( fieldNameWithOffset );
} }
} }
UnsafeState state_ = state;
WatCodeSyntheticFunctionName func = WatCodeSyntheticFunctionName func =
new WatCodeSyntheticFunctionName( fieldNameWithOffset.className, '.' + fieldNameWithOffset.methodName + '.' + name.methodName, name.signature, "", (AnyType[])null ) { new WatCodeSyntheticFunctionName( fieldNameWithOffset.className, '.' + fieldNameWithOffset.methodName + '.' + name.methodName, name.signature, "", (AnyType[])null ) {
@Override @Override
protected String getCode() { protected String getCode() {
UnsafeState state = null; UnsafeState state = state_;
for(FunctionName fieldNameWithOffset : fieldNames ) { for(FunctionName fieldNameWithOffset : fieldNames ) {
state = unsafes.get( fieldNameWithOffset );
if( state != null ) { if( state != null ) {
break; break;
} }
state = unsafes.get( fieldNameWithOffset );
} }
if( state == null ) { if( state == null ) {
if( functions.isFinish() ) { if( functions.isFinish() ) {
@ -632,16 +656,20 @@ class UnsafeManager {
case "getInt": case "getInt":
case "getLong": case "getLong":
return "local.get 1" // THIS case "getObject":
+ " struct.get " + state.typeName + ' ' + state.fieldName //
+ " return";
case "getObjectVolatile": case "getObjectVolatile":
return "local.get 1" // array if( state.fieldName != null ) {
+ " local.get 2" // the array index // field access
+ " i32.wrap_i64" // long -> int return "local.get 1" // THIS
+ " array.get " + state.typeName + " struct.get " + state.typeName + ' ' + state.fieldName //
+ " return"; + " return";
} else {
// array access
return "local.get 1" // array
+ " local.get 2" // the array index
+ " i32.wrap_i64" // long -> int
+ " array.get " + state.typeName + " return";
}
} }
throw new RuntimeException( name.signatureName ); throw new RuntimeException( name.signatureName );