From eeda816d0e7d173f0463e84fc859a8ca5dec6aa7 Mon Sep 17 00:00:00 2001 From: TbirdDuncan Date: Fri, 13 Feb 2026 17:11:33 -0600 Subject: [PATCH] adds virtual thread support for activity execution context > updates CurrentActivityExecutionContext.java to support that mission > adds test for CurrentActivityExecutionContextTest > removed unused code in temporal-sdk/src/test/java/io/temporal/internal/nexus/WorkflowRunTokenTest.java --- .../CurrentActivityExecutionContext.java | 86 +++++++++++-- .../internal/nexus/WorkflowRunTokenTest.java | 2 - .../CurrentActivityExecutionContextTest.java | 121 ++++++++++++++++++ 3 files changed, 195 insertions(+), 14 deletions(-) create mode 100644 temporal-sdk/src/virtualThreadTests/java/io/temporal/internal/activity/CurrentActivityExecutionContextTest.java diff --git a/temporal-sdk/src/main/java/io/temporal/internal/activity/CurrentActivityExecutionContext.java b/temporal-sdk/src/main/java/io/temporal/internal/activity/CurrentActivityExecutionContext.java index 7be9fcee63..219d8aba5b 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/activity/CurrentActivityExecutionContext.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/activity/CurrentActivityExecutionContext.java @@ -1,22 +1,56 @@ package io.temporal.internal.activity; import io.temporal.activity.ActivityExecutionContext; +import java.util.ArrayDeque; +import java.util.Collections; +import java.util.Deque; +import java.util.Map; +import java.util.WeakHashMap; /** - * Thread local store of the context object passed to an activity implementation. Avoid using this - * class directly. + * Thread-local / virtual-thread-aware store of the context object passed to an activity + * implementation. Avoid using this class directly. * - * @author fateev + *

Uses a per-thread stack so nested sets/unsets are handled correctly. Platform threads use + * ThreadLocal; virtual threads use a WeakHashMap keyed by Thread to avoid leaking memory when + * virtual threads die. + * + * @author fateev (adapted) */ -final class CurrentActivityExecutionContext { +public final class CurrentActivityExecutionContext { + + private static final ThreadLocal> PLATFORM_STACK = + ThreadLocal.withInitial(ArrayDeque::new); - private static final ThreadLocal CURRENT = new ThreadLocal<>(); + private static final Map> VIRTUAL_STACKS = + Collections.synchronizedMap(new WeakHashMap<>()); + + private static Deque getStackForCurrentThread() { + Thread t = Thread.currentThread(); + if (isVirtualThread(t)) { + Deque d = + VIRTUAL_STACKS.computeIfAbsent(t, k -> new ArrayDeque<>()); + return d; + } else { + return PLATFORM_STACK.get(); + } + } + + private static boolean isVirtualThread(Thread t) { + try { + t.getClass().getMethod("isVirtual", boolean.class); + return true; + } catch (NoSuchMethodException e) { + return false; + } + } /** * This is used by activity implementation to get access to the current ActivityExecutionContext */ public static ActivityExecutionContext get() { - ActivityExecutionContext result = CURRENT.get(); + Deque stack = getStackForCurrentThread(); + ActivityExecutionContext result = stack.peek(); if (result == null) { throw new IllegalStateException( "ActivityExecutionContext can be used only inside of activity " @@ -26,21 +60,49 @@ public static ActivityExecutionContext get() { } public static boolean isSet() { - return CURRENT.get() != null; + Deque stack = getStackForCurrentThread(); + return stack.peek() != null; } + /** + * Pushes the provided context for the current thread. Null context is rejected. We allow nested + * sets (push semantics) to support nested interceptors / wrappers. + */ public static void set(ActivityExecutionContext context) { if (context == null) { throw new IllegalArgumentException("null context"); } - if (CURRENT.get() != null) { - throw new IllegalStateException("current already set"); - } - CURRENT.set(context); + Deque stack = getStackForCurrentThread(); + stack.push(context); } + /** + * Pops the current context for the thread. If the stack becomes empty, clear the storage for the + * thread to allow GC (remove ThreadLocal or remove map entry for virtual threads). + */ public static void unset() { - CURRENT.set(null); + Thread t = Thread.currentThread(); + if (isVirtualThread(t)) { + synchronized (VIRTUAL_STACKS) { + Deque stack = VIRTUAL_STACKS.get(t); + if (stack == null || stack.isEmpty()) { + return; + } + stack.pop(); + if (stack.isEmpty()) { + VIRTUAL_STACKS.remove(t); + } + } + } else { + Deque stack = PLATFORM_STACK.get(); + if (stack == null || stack.isEmpty()) { + return; + } + stack.pop(); + if (stack.isEmpty()) { + PLATFORM_STACK.remove(); + } + } } private CurrentActivityExecutionContext() {} diff --git a/temporal-sdk/src/test/java/io/temporal/internal/nexus/WorkflowRunTokenTest.java b/temporal-sdk/src/test/java/io/temporal/internal/nexus/WorkflowRunTokenTest.java index fbf14d217a..776fe790cf 100644 --- a/temporal-sdk/src/test/java/io/temporal/internal/nexus/WorkflowRunTokenTest.java +++ b/temporal-sdk/src/test/java/io/temporal/internal/nexus/WorkflowRunTokenTest.java @@ -11,8 +11,6 @@ public class WorkflowRunTokenTest { private static final ObjectWriter ow = new ObjectMapper().registerModule(new Jdk8Module()).writer(); - private static final ObjectReader or = - new ObjectMapper().registerModule(new Jdk8Module()).reader(); private static final Base64.Encoder encoder = Base64.getUrlEncoder().withoutPadding(); @Test diff --git a/temporal-sdk/src/virtualThreadTests/java/io/temporal/internal/activity/CurrentActivityExecutionContextTest.java b/temporal-sdk/src/virtualThreadTests/java/io/temporal/internal/activity/CurrentActivityExecutionContextTest.java new file mode 100644 index 0000000000..407b80a8a1 --- /dev/null +++ b/temporal-sdk/src/virtualThreadTests/java/io/temporal/internal/activity/CurrentActivityExecutionContextTest.java @@ -0,0 +1,121 @@ +package io.temporal.internal.activity; + +import static org.junit.Assert.*; + +import io.temporal.activity.ActivityExecutionContext; +import java.lang.reflect.InvocationHandler; +import java.lang.reflect.Proxy; +import java.util.concurrent.atomic.AtomicReference; +import org.junit.Assume; +import org.junit.Test; + +public class CurrentActivityExecutionContextTest { + + private static ActivityExecutionContext proxyContext() { + InvocationHandler handler = (proxy, method, args) -> null; + return (ActivityExecutionContext) + Proxy.newProxyInstance( + ActivityExecutionContext.class.getClassLoader(), + new Class[] {ActivityExecutionContext.class}, + handler); + } + + @Test + public void platformThreadNestedSetUnsetBehavior() { + ActivityExecutionContext ctx1 = proxyContext(); + ActivityExecutionContext ctx2 = proxyContext(); + + assertFalse(CurrentActivityExecutionContext.isSet()); + assertThrows(IllegalStateException.class, CurrentActivityExecutionContext::get); + + CurrentActivityExecutionContext.set(ctx1); + assertTrue(CurrentActivityExecutionContext.isSet()); + assertSame("should return ctx1", ctx1, CurrentActivityExecutionContext.get()); + + CurrentActivityExecutionContext.set(ctx2); + assertTrue(CurrentActivityExecutionContext.isSet()); + assertSame("should return ctx2 (top of stack)", ctx2, CurrentActivityExecutionContext.get()); + + CurrentActivityExecutionContext.unset(); + assertTrue(CurrentActivityExecutionContext.isSet()); + assertSame("after popping, should return ctx1", ctx1, CurrentActivityExecutionContext.get()); + + CurrentActivityExecutionContext.unset(); + assertFalse(CurrentActivityExecutionContext.isSet()); + assertThrows( + "get() should throw after final unset", + IllegalStateException.class, + CurrentActivityExecutionContext::get); + } + + @Test + public void virtualThreadNestedSetUnsetBehavior_ifSupported() throws Exception { + boolean supportsVirtual; + try { + Thread.class.getMethod("startVirtualThread", Runnable.class); + supportsVirtual = true; + } catch (NoSuchMethodException e) { + supportsVirtual = false; + } + + Assume.assumeTrue("Virtual threads not supported in this JVM; skipping", supportsVirtual); + + AtomicReference failure = new AtomicReference<>(null); + AtomicReference seenAfterFirstSet = new AtomicReference<>(null); + AtomicReference seenAfterSecondSet = new AtomicReference<>(null); + AtomicReference seenIsSetAfterFinalUnset = new AtomicReference<>(null); + + Thread vt = + Thread.startVirtualThread( + () -> { + try { + ActivityExecutionContext vctx1 = proxyContext(); + ActivityExecutionContext vctx2 = proxyContext(); + + assertFalse(CurrentActivityExecutionContext.isSet()); + try { + CurrentActivityExecutionContext.get(); + fail("get() should have thrown when no context is set"); + } catch (IllegalStateException expected) { + } + + CurrentActivityExecutionContext.set(vctx1); + seenAfterFirstSet.set(CurrentActivityExecutionContext.get()); + + CurrentActivityExecutionContext.set(vctx2); + seenAfterSecondSet.set(CurrentActivityExecutionContext.get()); + + CurrentActivityExecutionContext.unset(); + ActivityExecutionContext afterPop = CurrentActivityExecutionContext.get(); + if (afterPop != vctx1) { + throw new AssertionError("after pop expected vctx1 but got " + afterPop); + } + + CurrentActivityExecutionContext.unset(); + seenIsSetAfterFinalUnset.set(CurrentActivityExecutionContext.isSet()); + try { + CurrentActivityExecutionContext.get(); + throw new AssertionError("get() should have thrown after final unset"); + } catch (IllegalStateException expected) { + } + } catch (Throwable t) { + failure.set(t); + } + }); + + vt.join(); + + if (failure.get() != null) { + Throwable t = failure.get(); + if (t instanceof AssertionError) { + throw (AssertionError) t; + } else { + throw new RuntimeException(t); + } + } + + assertNotNull("virtual thread did not record first set", seenAfterFirstSet.get()); + assertNotNull("virtual thread did not record second (nested) set", seenAfterSecondSet.get()); + assertFalse("expected context to be unset at the end", seenIsSetAfterFinalUnset.get()); + } +}