Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,19 @@

import org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStorageDataIdentifier;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;

/**
* A registry that maintains local or remote resources that correspond to a certain set of data in
* the Tiered Storage.
*/
public class TieredStorageResourceRegistry {

private final Map<TieredStorageDataIdentifier, List<TieredStorageResource>>
registeredResources = new HashMap<>();
private final ConcurrentHashMap<
TieredStorageDataIdentifier, CopyOnWriteArrayList<TieredStorageResource>>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am curious: why do we need the CopyOnWriteArrayList? Is the introduction of ConcurrentHashMap not enough to solve this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good question!

I thought the same when I initially applied the fix (i.e., I only swapped out the outer map for its thread-safe counterpart); however, I realized that the tests that were added would still fail.

The ConcurrentHashMap handles the thread-safety for the map operations but not for the internal values within the map. This makes it possible to have multiple separate threads acting upon the non thread-safe list, which can lead to some inconsistency:

registeredResources
      .computeIfAbsent(owner, (ignore) -> new ArrayList<>())
      // Concurrent callers could be working with the same thread-safe map, but
      // the underlying list is not thread-safe
      .add(tieredStorageResource);

Without the extra thread-safety on the list, many of the existing tests can fail with ConcurrentModificationException, NullPointerException, and lost entries (which testConcurrentRegisterWithSameIdentifier specifically checks for). Making the swap to the CopyOnWriteArrayList (or some other thread-safe collection like Collections.synchronizedList()) makes the behavior consistent.

registeredResources = new ConcurrentHashMap<>();

/**
* Register a new resource for the given owner.
Expand All @@ -43,7 +43,7 @@ public class TieredStorageResourceRegistry {
public void registerResource(
TieredStorageDataIdentifier owner, TieredStorageResource tieredStorageResource) {
registeredResources
.computeIfAbsent(owner, (ignore) -> new ArrayList<>())
.computeIfAbsent(owner, (ignore) -> new CopyOnWriteArrayList<>())
.add(tieredStorageResource);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage;

import org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStorageDataIdentifier;

import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.RepeatedTest;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

import static org.assertj.core.api.Assertions.assertThat;

/**
 * Concurrency tests for {@link TieredStorageResourceRegistry}.
 *
 * <p>Every test submits {@link #NUM_THREADS} workers to a fixed thread pool. The workers are
 * released together by a shared {@link CyclicBarrier} so that their registry calls overlap as much
 * as possible. Any {@link Throwable} raised inside a worker is collected and asserted on after all
 * workers have finished.
 */
class TieredStorageResourceRegistryTest {

    private static final int NUM_THREADS = 10;
    private static final int NUM_OPERATIONS_PER_THREAD = 100;

    /** Upper bound for all workers of a single test to finish. */
    private static final long COMPLETION_TIMEOUT_SECONDS = 30;

    /** Upper bound for the executor to terminate gracefully after a test. */
    private static final long SHUTDOWN_TIMEOUT_SECONDS = 10;

    private TieredStorageResourceRegistry registry;
    private ExecutorService executor;
    private CyclicBarrier barrier;
    private CountDownLatch completionLatch;
    private List<Throwable> exceptions;

    @BeforeEach
    void setUp() {
        registry = new TieredStorageResourceRegistry();
        executor = Executors.newFixedThreadPool(NUM_THREADS);
        barrier = new CyclicBarrier(NUM_THREADS);
        completionLatch = new CountDownLatch(NUM_THREADS);
        // Workers append concurrently, so the list itself must be thread-safe.
        exceptions = Collections.synchronizedList(new ArrayList<>());
    }

    @AfterEach
    void tearDown() throws Exception {
        executor.shutdown();
        if (!executor.awaitTermination(SHUTDOWN_TIMEOUT_SECONDS, TimeUnit.SECONDS)) {
            // Interrupt stuck workers so they do not leak into the following repetitions.
            executor.shutdownNow();
        }
    }

    /**
     * Registers resources for one shared owner from all threads and checks that none of the
     * registrations is lost.
     */
    @RepeatedTest(10)
    void testConcurrentRegisterResource() throws Exception {
        AtomicInteger releaseCount = new AtomicInteger(0);
        TestingDataIdentifier sharedOwner = new TestingDataIdentifier(0);

        runConcurrentTask(
                threadId -> {
                    for (int i = 0; i < NUM_OPERATIONS_PER_THREAD; i++) {
                        registry.registerResource(
                                sharedOwner, () -> releaseCount.incrementAndGet());
                    }
                });

        assertNoExceptions("Concurrent registerResource() calls");

        // Clearing the owner is expected to release every resource registered for it; the
        // number of releases therefore reveals how many registrations actually survived.
        registry.clearResourceFor(sharedOwner);

        // Verify that all registered resources were successfully released
        assertThat(releaseCount.get())
                .as("All registered resources should be released.")
                .isEqualTo(NUM_THREADS * NUM_OPERATIONS_PER_THREAD);
    }

    /**
     * Registers resources from all threads, each registration using its own distinct owner, and
     * checks that every call completes without an exception.
     */
    @RepeatedTest(10)
    void testConcurrentRegisterResourceWithDifferentOwners() throws Exception {
        AtomicInteger successfulRegistrations = new AtomicInteger(0);

        // Run multiple concurrent threads to simulate concurrent registration (with
        // different owners)
        runConcurrentTask(
                threadId -> {
                    for (int i = 0; i < NUM_OPERATIONS_PER_THREAD; i++) {
                        // The id is unique across all threads and iterations.
                        TestingDataIdentifier owner =
                                new TestingDataIdentifier(
                                        threadId * NUM_OPERATIONS_PER_THREAD + i);
                        registry.registerResource(owner, () -> {});
                        successfulRegistrations.incrementAndGet();
                    }
                });

        assertNoExceptions("Concurrent registerResource() calls");
        assertThat(successfulRegistrations.get())
                .isEqualTo(NUM_THREADS * NUM_OPERATIONS_PER_THREAD);
    }

    /**
     * Interleaves register and clear calls on a small, shared set of owners and checks that no
     * call throws.
     */
    @RepeatedTest(10)
    void testConcurrentRegisterAndClear() throws Exception {
        // Use few owners to maximize contention on the same keys across threads. The count is
        // odd on purpose: combined with the even/odd register/clear split below, every owner
        // then receives both register and clear calls.
        final int numOwners = 5;
        TestingDataIdentifier[] owners = new TestingDataIdentifier[numOwners];
        for (int i = 0; i < owners.length; i++) {
            owners[i] = new TestingDataIdentifier(i);
        }

        // Run multiple concurrent threads to simulate concurrent registration/clearing
        runConcurrentTask(
                threadId -> {
                    for (int i = 0; i < NUM_OPERATIONS_PER_THREAD; i++) {
                        // All threads compete for the same small set of owners
                        TestingDataIdentifier owner = owners[i % numOwners];

                        // Alternate between register and clear to maximize interleaving
                        if (i % 2 == 0) {
                            registry.registerResource(owner, () -> {});
                        } else {
                            registry.clearResourceFor(owner);
                        }
                    }
                });

        // Verify there were no exceptions during concurrent registration/clear operations
        assertNoExceptions("Concurrent registration/clearing calls");
    }

    /**
     * Runs {@code task} once on each of the {@link #NUM_THREADS} pool threads and waits for all of
     * them to finish.
     *
     * <p>Workers block on the barrier first so that they start the task at the same time.
     * Throwables raised by a worker (including barrier failures) are recorded in {@link
     * #exceptions} instead of being propagated.
     *
     * @param task the work to execute; receives the zero-based index of the worker
     * @throws InterruptedException if the calling thread is interrupted while waiting
     */
    private void runConcurrentTask(ThrowingIntConsumer task) throws Exception {
        for (int t = 0; t < NUM_THREADS; t++) {
            final int threadId = t;
            executor.submit(
                    () -> {
                        try {
                            barrier.await();
                            task.accept(threadId);
                        } catch (Throwable e) {
                            exceptions.add(e);
                        } finally {
                            completionLatch.countDown();
                        }
                    });
        }

        // await() returns false on timeout. Without this check a test would go on to its
        // assertions while workers are still running and could pass (or fail) spuriously.
        boolean allWorkersFinished =
                completionLatch.await(COMPLETION_TIMEOUT_SECONDS, TimeUnit.SECONDS);
        assertThat(allWorkersFinished)
                .as(
                        "All %s worker threads should finish within %s seconds.",
                        NUM_THREADS, COMPLETION_TIMEOUT_SECONDS)
                .isTrue();
    }

    /**
     * Asserts that no worker recorded a throwable.
     *
     * @param operationDescription short description of the exercised operation, used in the
     *     failure message
     */
    private void assertNoExceptions(String operationDescription) {
        assertThat(exceptions)
                .as("Expected no exceptions during %s. Found: %s", operationDescription, exceptions)
                .isEmpty();
    }

    /** An {@code int} consumer that is allowed to throw checked exceptions. */
    @FunctionalInterface
    private interface ThrowingIntConsumer {
        void accept(int value) throws Exception;
    }

    /** Simple implementation of TieredStorageDataIdentifier for testing. */
    private static class TestingDataIdentifier implements TieredStorageDataIdentifier {
        private final int id;

        TestingDataIdentifier(int id) {
            this.id = id;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            TestingDataIdentifier that = (TestingDataIdentifier) o;
            return id == that.id;
        }

        @Override
        public int hashCode() {
            return Integer.hashCode(id);
        }

        @Override
        public String toString() {
            return "TestingDataIdentifier{id=" + id + "}";
        }
    }
}