diff --git a/src/de/inetsoftware/jwebassembly/module/FunctionManager.java b/src/de/inetsoftware/jwebassembly/module/FunctionManager.java index 6823a2b..4c74e34 100644 --- a/src/de/inetsoftware/jwebassembly/module/FunctionManager.java +++ b/src/de/inetsoftware/jwebassembly/module/FunctionManager.java @@ -170,6 +170,10 @@ class FunctionManager { FunctionName markAsNeeded( @Nonnull FunctionName name, boolean needThisParameter ) { FunctionState state = getOrCreate( name ); if( state.state == State.None ) { + if( "sun/misc/Unsafe".equals( name.className ) ) { + // Unsafe method call will be replaces by the UnsafeManager + return name; + } if( isFinish ) { throw new WasmException( "Prepare was already finish: " + name.signatureName, -1 ); } diff --git a/src/de/inetsoftware/jwebassembly/module/JavaMethodWasmCodeBuilder.java b/src/de/inetsoftware/jwebassembly/module/JavaMethodWasmCodeBuilder.java index b4580ff..677e9dc 100644 --- a/src/de/inetsoftware/jwebassembly/module/JavaMethodWasmCodeBuilder.java +++ b/src/de/inetsoftware/jwebassembly/module/JavaMethodWasmCodeBuilder.java @@ -50,6 +50,7 @@ class JavaMethodWasmCodeBuilder extends WasmCodeBuilder { private BranchManager branchManager; + private UnsafeManager unsafeManager; /** * Create an instance * @@ -67,6 +68,7 @@ class JavaMethodWasmCodeBuilder extends WasmCodeBuilder { void init( WasmOptions options, ClassFileLoader classFileLoader ) { super.init( options, classFileLoader ); this.branchManager = new BranchManager( options, getInstructions(), getLocalVariables() ); + this.unsafeManager = new UnsafeManager( options.functions ); } /** @@ -750,6 +752,7 @@ class JavaMethodWasmCodeBuilder extends WasmCodeBuilder { } wide = false; } + unsafeManager.replaceUnsafe( getInstructions() ); branchManager.calculate(); branchManager.handle( byteCode ); // add branch operations if( returnType != null && !isEndsWithReturn() ) { diff --git a/src/de/inetsoftware/jwebassembly/module/UnsafeManager.java b/src/de/inetsoftware/jwebassembly/module/UnsafeManager.java new file mode 100644 index 0000000..8735fd9 --- /dev/null +++ b/src/de/inetsoftware/jwebassembly/module/UnsafeManager.java @@ -0,0 +1,326 @@ +/* + * Copyright 2022 Volker Berlin (i-net software) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package de.inetsoftware.jwebassembly.module; + +import java.util.HashMap; +import java.util.List; + +import javax.annotation.Nonnull; + +import de.inetsoftware.jwebassembly.WasmException; +import de.inetsoftware.jwebassembly.module.StackInspector.StackValue; +import de.inetsoftware.jwebassembly.module.WasmInstruction.Type; +import de.inetsoftware.jwebassembly.wasm.AnyType; +import de.inetsoftware.jwebassembly.wasm.VariableOperator; + +/** + * Replace Unsafe operations with simpler WASM operations which does not need reflections. + * + * In Java a typical Unsafe code look like: + * + *
+ * 
+ * private static final Unsafe UNSAFE = Unsafe.getUnsafe();
+ * private static final long FIELD_OFFSET = UNSAFE.objectFieldOffset(Foobar.class.getDeclaredField("field"));
+ * 
+ * ...
+ * 
+ * UNSAFE.compareAndSwapInt(this, FIELD_OFFSET, expect, update);
+ * 
+ * 
+ * + * Because WASM does not support reflection the native code of UNSAFE can't simple replaced. That this manager convert + * this to the follow pseudo code in WASM: + * + *
+ * 
+ * Foobar..compareAndSwapInt(this, FIELD_OFFSET, expect, update);
+ * 
+ * ...
+ * 
+ * boolean .compareAndSwapInt(Object obj, long fieldOffset, int expect, int update ) {
+ *     if( obj.field == expect ) {
+ *         obj.field = update;
+ *         return true;
+ *     }
+ *     return false;
+ * }
+ * 
+ * 
+ * + * @author Volker Berlin + */ +class UnsafeManager { + + @Nonnull + private final FunctionManager functions; + + private final HashMap unsafes = new HashMap<>(); + + /** + * Create an instance of the manager + * + * @param functions + * The function manager to register the synthetic functions. + */ + UnsafeManager( @Nonnull FunctionManager functions ) { + this.functions = functions; + } + + /** + * Replace any Unsafe API call with direct field access. + * + * @param instructions + * the instruction list of a function/method + */ + void replaceUnsafe( List instructions ) { + // search for Unsafe function calls + for( int i = 0; i < instructions.size(); i++ ) { + WasmInstruction instr = instructions.get( i ); + switch( instr.getType() ) { + case CallVirtual: + case Call: + WasmCallInstruction callInst = (WasmCallInstruction)instr; + if( "sun/misc/Unsafe".equals( callInst.getFunctionName().className ) ) { + patch( instructions, i, callInst ); + } + break; + default: + } + } + } + + /** + * Patch in the instruction list an Unsafe method call. It does not change the count of instructions. + * + * @param instructions + * the instruction list of a function/method + * @param idx + * the index in the instructions + * @param callInst + * the method call to Unsafe + */ + private void patch( List instructions, int idx, WasmCallInstruction callInst ) { + FunctionName name = callInst.getFunctionName(); + switch( name.signatureName ) { + case "sun/misc/Unsafe.getUnsafe()Lsun/misc/Unsafe;": + patch_getUnsafe( instructions, idx ); + break; + case "sun/misc/Unsafe.objectFieldOffset(Ljava/lang/reflect/Field;)J": + patch_objectFieldOffset( instructions, idx, callInst ); + break; + case "sun/misc/Unsafe.getAndAddInt(Ljava/lang/Object;JI)I": + case "sun/misc/Unsafe.getAndSetInt(Ljava/lang/Object;JI)I": + case "sun/misc/Unsafe.putOrderedInt(Ljava/lang/Object;JI)V": + patchFunction( instructions, idx, callInst, name, 2 ); + break; + case "sun/misc/Unsafe.compareAndSwapInt(Ljava/lang/Object;JII)Z": + patchFunction( instructions, idx, callInst, name, 3 ); + break; + default: + throw new WasmException( "Unsupported Unsafe method: " + name.signatureName, -1 ); + } + } + + /** + * Replace a call to Unsafe.getUnsafe() with a NOP operation. + * + * @param instructions + * the instruction list of a function/method + * @param idx + * the index in the instructions + */ + private void patch_getUnsafe( List instructions, int idx ) { + WasmInstruction instr = instructions.get( idx + 1 ); + + FunctionName fieldName = ((WasmGlobalInstruction)instr).getFieldName(); + unsafes.putIfAbsent( fieldName, new UnsafeState() ); + + nop( instructions, idx, idx + 2 ); + } + + /** + * Patch a method call to Unsafe.objectFieldOffset() and find the parameter for other patch operations. + * + * @param instructions + * the instruction list + * @param idx + * the index in the instructions + * @param callInst + * the method call to Unsafe + */ + private void patch_objectFieldOffset( List instructions, int idx, WasmCallInstruction callInst ) { + // find the field on which the offset is assign: long FIELD_OFFSET = UNSAFE.objectFieldOffset(... + WasmInstruction instr = instructions.get( idx + 1 ); + FunctionName fieldNameWithOffset = ((WasmGlobalInstruction)instr).getFieldName(); + UnsafeState state = unsafes.get( fieldNameWithOffset ); + if( state == null ) { + unsafes.put( fieldNameWithOffset, state = new UnsafeState() ); + } + + // objectFieldOffset() has 2 parameters THIS(Unsafe) and a Field + int from = StackInspector.findInstructionThatPushValue( instructions.subList( 0, idx ), 2, callInst.getCodePosition() ).idx; + + StackValue stackValue = StackInspector.findInstructionThatPushValue( instructions.subList( 0, idx ), 1, callInst.getCodePosition() ); + instr = stackValue.instr; + WasmCallInstruction fieldInst = (WasmCallInstruction)instr; + + FunctionName fieldFuncName = fieldInst.getFunctionName(); + switch( fieldFuncName.signatureName ) { + case "java/lang/Class.getDeclaredField(Ljava/lang/String;)Ljava/lang/reflect/Field;": + stackValue = StackInspector.findInstructionThatPushValue( instructions.subList( 0, stackValue.idx ), 1, fieldInst.getCodePosition() ); + state.fieldName = ((WasmConstStringInstruction)stackValue.instr).getValue(); + + // find the class value on which getDeclaredField is called + stackValue = StackInspector.findInstructionThatPushValue( instructions.subList( 0, stackValue.idx ), 1, fieldInst.getCodePosition() ); + instr = stackValue.instr; + switch( instr.getType() ) { + case Local: + int slot = ((WasmLocalInstruction)instr).getSlot(); + for( int i = stackValue.idx - 1; i >= 0; i-- ) { + instr = instructions.get( i ); + if( instr.getType() == Type.Local ) { + WasmLocalInstruction loadInstr = (WasmLocalInstruction)instr; + if( loadInstr.getSlot() == slot && loadInstr.getOperator() == VariableOperator.set ) { + stackValue = StackInspector.findInstructionThatPushValue( instructions.subList( 0, i ), 1, fieldInst.getCodePosition() ); + instr = stackValue.instr; + } + + } + } + break; + } + state.typeName = ((WasmConstClassInstruction)instr).getValue(); + break; + + default: + throw new WasmException( "Unsupported Unsafe method to get target field: " + fieldFuncName.signatureName, -1 ); + } + + nop( instructions, from, idx + 2 ); + } + + /** + * Patch an unsafe function that access a field + * + * @param instructions + * the instruction list + * @param idx + * the index in the instructions + * @param callInst + * the method call to Unsafe + * @param name + * the calling function + * @param fieldNameParam + * the function parameter with the field offset. This must be a long (Java signature "J"). The THIS + * parameter has the index 0. + */ + private void patchFunction( List instructions, int idx, final WasmCallInstruction callInst, FunctionName name, int fieldNameParam ) { + StackValue stackValue = StackInspector.findInstructionThatPushValue( instructions.subList( 0, idx ), fieldNameParam, callInst.getCodePosition() ); + FunctionName fieldNameWithOffset = ((WasmGlobalInstruction)stackValue.instr).getFieldName(); + WatCodeSyntheticFunctionName func = + new WatCodeSyntheticFunctionName( fieldNameWithOffset.className, '.' + name.methodName, name.signature, "", (AnyType[])null ) { + @Override + protected String getCode() { + UnsafeState state = unsafes.get( fieldNameWithOffset ); + if( state == null ) { + // we are in the scan phase. The static code was not scanned yet. + return ""; + } + switch( name.methodName ) { + case "compareAndSwapInt": + return "local.get 0" // THIS + + " struct.get " + state.typeName + ' ' + state.fieldName // + + " local.get 2" // expected + + " i32.eq" // + + " if" // + + " local.get 0" // THIS + + " local.get 3" // update + + " struct.set " + state.typeName + ' ' + state.fieldName // + + " i32.const 1" // + + " return" // + + " end" // + + " i32.const 1" // + + " return"; + + case "getAndAddInt": + return "local.get 0" // THIS + + " local.get 0" // THIS + + " struct.get " + state.typeName + ' ' + state.fieldName // + + " local.tee 3" // temp + + " local.get 2" // delta + + " i32.add" // + + " struct.set " + state.typeName + ' ' + state.fieldName // + + " local.get 3" // temp + + " return"; + + case "getAndSetInt": + return "local.get 0" // THIS + + " struct.get " + state.typeName + ' ' + state.fieldName // + + " local.get 0" // THIS + + " local.get 2" // newValue + + " struct.set " + state.typeName + ' ' + state.fieldName // + + " return"; + + case "putOrderedInt": + return "local.get 0" // THIS + + " local.get 2" // x + + " struct.set " + state.typeName + ' ' + state.fieldName; + } + + throw new RuntimeException( name.signatureName ); + } + }; + functions.markAsNeeded( func, false ); + WasmCallInstruction call = new WasmCallInstruction( func, callInst.getCodePosition(), callInst.getLineNumber(), callInst.getTypeManager(), false ); + instructions.set( idx, call ); + + // a virtual method call has also a DUP of this because we need for virtual method dispatch the parameter 2 times. + for( int i = idx; i >= 0; i-- ) { + WasmInstruction instr = instructions.get( i ); + if( instr.getType() == Type.DupThis && ((DupThis)instr).getValue() == callInst ) { + nop( instructions, i, i + 1 ); + break; + } + } + } + + /** + * Replace the instructions with NOP operations + * + * @param instructions + * the instruction list + * @param from + * starting index + * @param to + * end index + */ + private void nop( List instructions, int from, int to ) { + for( int i = from; i < to; i++ ) { + WasmInstruction instr = instructions.get( i ); + instructions.set( i, new WasmNopInstruction( instr.getCodePosition(), instr.getLineNumber() ) ); + } + } + + /** + * Hold the state from declaring of Unsafe address + */ + static class UnsafeState { + String fieldName; + + String typeName; + } +} diff --git a/test/de/inetsoftware/jwebassembly/runtime/UnsafeTest.java b/test/de/inetsoftware/jwebassembly/runtime/UnsafeTest.java new file mode 100644 index 0000000..7b1aee6 --- /dev/null +++ b/test/de/inetsoftware/jwebassembly/runtime/UnsafeTest.java @@ -0,0 +1,95 @@ +/* + * Copyright 2022 Volker Berlin (i-net software) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package de.inetsoftware.jwebassembly.runtime; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.concurrent.atomic.AtomicInteger; + +import org.junit.ClassRule; +import org.junit.runners.Parameterized.Parameters; + +import de.inetsoftware.jwebassembly.ScriptEngine; +import de.inetsoftware.jwebassembly.WasmRule; +import de.inetsoftware.jwebassembly.api.annotation.Export; + +/** + * @author Volker Berlin + */ +public class UnsafeTest extends AbstractBaseTest { + + @ClassRule + public static WasmRule rule = new WasmRule( TestClass.class ); + + public UnsafeTest( ScriptEngine script, String method, Object[] params ) { + super( rule, script, method, params ); + } + + @Parameters( name = "{0}-{1}" ) + public static Collection data() { + ArrayList list = new ArrayList<>(); + for( ScriptEngine script : ScriptEngine.testEngines() ) { + addParam( list, script, "compareAndSwapInt" ); + addParam( list, script, "getAndAddInt" ); + addParam( list, script, "getAndSetInt" ); + addParam( list, script, "lazySet" ); + } + rule.setTestParameters( list ); + return list; + } + + static class TestClass { + + @Export + static int compareAndSwapInt() { + AtomicInteger obj = new AtomicInteger(); + if( obj.compareAndSet( 0, 25 ) ) { + return obj.get(); + } else { + return 42; + } + } + + @Export + static int getAndAddInt() { + AtomicInteger obj = new AtomicInteger(); + obj.set( 13 ); + if( obj.getAndAdd( 25 ) == 13 ) { + return obj.get(); + } else { + return 42; + } + } + + @Export + static int getAndSetInt() { + AtomicInteger obj = new AtomicInteger(); + obj.set( 13 ); + if( obj.getAndSet( 25 ) == 13 ) { + return obj.get(); + } else { + return 42; + } + } + + @Export + static int lazySet() { + AtomicInteger obj = new AtomicInteger(); + obj.lazySet( 42 ); + return obj.get(); + } + } +}