Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,22 +1,56 @@
package io.temporal.internal.activity;

import io.temporal.activity.ActivityExecutionContext;
import java.util.ArrayDeque;
import java.util.Collections;
import java.util.Deque;
import java.util.Map;
import java.util.WeakHashMap;

/**
* Thread local store of the context object passed to an activity implementation. Avoid using this
* class directly.
* Thread-local / virtual-thread-aware store of the context object passed to an activity
* implementation. Avoid using this class directly.
*
* @author fateev
* <p>Uses a per-thread stack so nested sets/unsets are handled correctly. Platform threads use
* ThreadLocal; virtual threads use a WeakHashMap keyed by Thread to avoid leaking memory when
* virtual threads die.
*
* @author fateev (adapted)
*/
final class CurrentActivityExecutionContext {
public final class CurrentActivityExecutionContext {

private static final ThreadLocal<Deque<ActivityExecutionContext>> PLATFORM_STACK =
ThreadLocal.withInitial(ArrayDeque::new);

private static final ThreadLocal<ActivityExecutionContext> CURRENT = new ThreadLocal<>();
private static final Map<Thread, Deque<ActivityExecutionContext>> VIRTUAL_STACKS =
Collections.synchronizedMap(new WeakHashMap<>());

private static Deque<ActivityExecutionContext> getStackForCurrentThread() {
Thread t = Thread.currentThread();
if (isVirtualThread(t)) {
Deque<ActivityExecutionContext> d =
VIRTUAL_STACKS.computeIfAbsent(t, k -> new ArrayDeque<>());
return d;
} else {
return PLATFORM_STACK.get();
}
}

private static boolean isVirtualThread(Thread t) {
try {
t.getClass().getMethod("isVirtual", boolean.class);
return true;
} catch (NoSuchMethodException e) {
return false;
}
}

/**
* This is used by activity implementation to get access to the current ActivityExecutionContext
*/
public static ActivityExecutionContext get() {
ActivityExecutionContext result = CURRENT.get();
Deque<ActivityExecutionContext> stack = getStackForCurrentThread();
ActivityExecutionContext result = stack.peek();
if (result == null) {
throw new IllegalStateException(
"ActivityExecutionContext can be used only inside of activity "
Expand All @@ -26,21 +60,49 @@ public static ActivityExecutionContext get() {
}

public static boolean isSet() {
return CURRENT.get() != null;
Deque<ActivityExecutionContext> stack = getStackForCurrentThread();
return stack.peek() != null;
}

/**
* Pushes the provided context for the current thread. Null context is rejected. We allow nested
* sets (push semantics) to support nested interceptors / wrappers.
*/
public static void set(ActivityExecutionContext context) {
if (context == null) {
throw new IllegalArgumentException("null context");
}
if (CURRENT.get() != null) {
throw new IllegalStateException("current already set");
}
CURRENT.set(context);
Deque<ActivityExecutionContext> stack = getStackForCurrentThread();
stack.push(context);
}

/**
* Pops the current context for the thread. If the stack becomes empty, clear the storage for the
* thread to allow GC (remove ThreadLocal or remove map entry for virtual threads).
*/
public static void unset() {
CURRENT.set(null);
Thread t = Thread.currentThread();
if (isVirtualThread(t)) {
synchronized (VIRTUAL_STACKS) {
Deque<ActivityExecutionContext> stack = VIRTUAL_STACKS.get(t);
if (stack == null || stack.isEmpty()) {
return;
}
stack.pop();
if (stack.isEmpty()) {
VIRTUAL_STACKS.remove(t);
}
}
} else {
Deque<ActivityExecutionContext> stack = PLATFORM_STACK.get();
if (stack == null || stack.isEmpty()) {
return;
}
stack.pop();
if (stack.isEmpty()) {
PLATFORM_STACK.remove();
}
}
}

private CurrentActivityExecutionContext() {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
public class WorkflowRunTokenTest {
private static final ObjectWriter ow =
new ObjectMapper().registerModule(new Jdk8Module()).writer();
private static final ObjectReader or =
new ObjectMapper().registerModule(new Jdk8Module()).reader();
private static final Base64.Encoder encoder = Base64.getUrlEncoder().withoutPadding();

@Test
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
package io.temporal.internal.activity;

import static org.junit.Assert.*;

import io.temporal.activity.ActivityExecutionContext;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Proxy;
import java.util.concurrent.atomic.AtomicReference;
import org.junit.Assume;
import org.junit.Test;

public class CurrentActivityExecutionContextTest {

private static ActivityExecutionContext proxyContext() {
InvocationHandler handler = (proxy, method, args) -> null;
return (ActivityExecutionContext)
Proxy.newProxyInstance(
ActivityExecutionContext.class.getClassLoader(),
new Class[] {ActivityExecutionContext.class},
handler);
}

@Test
public void platformThreadNestedSetUnsetBehavior() {
ActivityExecutionContext ctx1 = proxyContext();
ActivityExecutionContext ctx2 = proxyContext();

assertFalse(CurrentActivityExecutionContext.isSet());
assertThrows(IllegalStateException.class, CurrentActivityExecutionContext::get);

CurrentActivityExecutionContext.set(ctx1);
assertTrue(CurrentActivityExecutionContext.isSet());
assertSame("should return ctx1", ctx1, CurrentActivityExecutionContext.get());

CurrentActivityExecutionContext.set(ctx2);
assertTrue(CurrentActivityExecutionContext.isSet());
assertSame("should return ctx2 (top of stack)", ctx2, CurrentActivityExecutionContext.get());

CurrentActivityExecutionContext.unset();
assertTrue(CurrentActivityExecutionContext.isSet());
assertSame("after popping, should return ctx1", ctx1, CurrentActivityExecutionContext.get());

CurrentActivityExecutionContext.unset();
assertFalse(CurrentActivityExecutionContext.isSet());
assertThrows(
"get() should throw after final unset",
IllegalStateException.class,
CurrentActivityExecutionContext::get);
}

@Test
public void virtualThreadNestedSetUnsetBehavior_ifSupported() throws Exception {
boolean supportsVirtual;
try {
Thread.class.getMethod("startVirtualThread", Runnable.class);
supportsVirtual = true;
} catch (NoSuchMethodException e) {
supportsVirtual = false;
}

Assume.assumeTrue("Virtual threads not supported in this JVM; skipping", supportsVirtual);

AtomicReference<Throwable> failure = new AtomicReference<>(null);
AtomicReference<ActivityExecutionContext> seenAfterFirstSet = new AtomicReference<>(null);
AtomicReference<ActivityExecutionContext> seenAfterSecondSet = new AtomicReference<>(null);
AtomicReference<Boolean> seenIsSetAfterFinalUnset = new AtomicReference<>(null);

Thread vt =
Thread.startVirtualThread(
() -> {
try {
ActivityExecutionContext vctx1 = proxyContext();
ActivityExecutionContext vctx2 = proxyContext();

assertFalse(CurrentActivityExecutionContext.isSet());
try {
CurrentActivityExecutionContext.get();
fail("get() should have thrown when no context is set");
} catch (IllegalStateException expected) {
}

CurrentActivityExecutionContext.set(vctx1);
seenAfterFirstSet.set(CurrentActivityExecutionContext.get());

CurrentActivityExecutionContext.set(vctx2);
seenAfterSecondSet.set(CurrentActivityExecutionContext.get());

CurrentActivityExecutionContext.unset();
ActivityExecutionContext afterPop = CurrentActivityExecutionContext.get();
if (afterPop != vctx1) {
throw new AssertionError("after pop expected vctx1 but got " + afterPop);
}

CurrentActivityExecutionContext.unset();
seenIsSetAfterFinalUnset.set(CurrentActivityExecutionContext.isSet());
try {
CurrentActivityExecutionContext.get();
throw new AssertionError("get() should have thrown after final unset");
} catch (IllegalStateException expected) {
}
} catch (Throwable t) {
failure.set(t);
}
});

vt.join();

if (failure.get() != null) {
Throwable t = failure.get();
if (t instanceof AssertionError) {
throw (AssertionError) t;
} else {
throw new RuntimeException(t);
}
}

assertNotNull("virtual thread did not record first set", seenAfterFirstSet.get());
assertNotNull("virtual thread did not record second (nested) set", seenAfterSecondSet.get());
assertFalse("expected context to be unset at the end", seenIsSetAfterFinalUnset.get());
}
}