Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions pina/_src/condition/condition_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pina._src.condition.condition_interface import ConditionInterface
from pina._src.core.graph import LabelBatch
from pina._src.core.label_tensor import LabelTensor
from pina._src.data.dummy_dataloader import DummyDataloader


class ConditionBase(ConditionInterface):
Expand Down Expand Up @@ -85,7 +86,8 @@ def automatic_batching_collate_fn(cls, batch):
if not batch:
return {}
instance_class = batch[0].__class__
return instance_class.create_batch(batch)
batch = instance_class.create_batch(batch)
return batch

@staticmethod
def collate_fn(batch, condition):
Expand All @@ -103,7 +105,11 @@ def collate_fn(batch, condition):
return data

def create_dataloader(
self, dataset, batch_size, shuffle, automatic_batching
self,
dataset,
batch_size,
automatic_batching,
**kwargs,
):
"""
Create a DataLoader for the condition.
Expand All @@ -114,14 +120,14 @@ def create_dataloader(
:rtype: torch.utils.data.DataLoader
"""
if batch_size == len(dataset):
pass # will be updated in the near future
return DummyDataloader(dataset)
return DataLoader(
dataset=dataset,
batch_size=batch_size,
shuffle=shuffle,
collate_fn=(
partial(self.collate_fn, condition=self)
if not automatic_batching
else self.automatic_batching_collate_fn
),
batch_size=batch_size,
**kwargs,
)
33 changes: 16 additions & 17 deletions pina/_src/core/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(
test_size=0.0,
val_size=0.0,
compile=None,
repeat=None,
batching_mode="common_batch_size",
automatic_batching=None,
num_workers=None,
pin_memory=None,
Expand All @@ -61,9 +61,9 @@ def __init__(
:param bool compile: If ``True``, the model is compiled before training.
Default is ``False``. For Windows users, it is always disabled. Not
supported for python version greater or equal than 3.14.
:param bool repeat: Whether to repeat the dataset data in each
condition during training. For further details, see the
:class:`~pina.data.data_module.PinaDataModule` class. Default is
:param str batching_mode: The batching mode to use. Options are
``"common_batch_size"``, ``"proportional"``, and
``"separate_conditions"``. Default is ``"common_batch_size"``.
:param bool automatic_batching: If ``True``, automatic PyTorch batching
is performed, otherwise the items are retrieved from the dataset
Expand All @@ -87,7 +87,7 @@ def __init__(
train_size=train_size,
test_size=test_size,
val_size=val_size,
repeat=repeat,
batching_mode=batching_mode,
automatic_batching=automatic_batching,
compile=compile,
)
Expand Down Expand Up @@ -127,8 +127,6 @@ def __init__(
UserWarning,
)

repeat = repeat if repeat is not None else False

automatic_batching = (
automatic_batching if automatic_batching is not None else False
)
Expand All @@ -144,7 +142,7 @@ def __init__(
test_size=test_size,
val_size=val_size,
batch_size=batch_size,
repeat=repeat,
batching_mode=batching_mode,
automatic_batching=automatic_batching,
pin_memory=pin_memory,
num_workers=num_workers,
Expand Down Expand Up @@ -182,7 +180,7 @@ def _create_datamodule(
test_size,
val_size,
batch_size,
repeat,
batching_mode,
automatic_batching,
pin_memory,
num_workers,
Expand All @@ -201,8 +199,9 @@ def _create_datamodule(
:param float val_size: The percentage of elements to include in the
validation dataset.
:param int batch_size: The number of samples per batch to load.
:param bool repeat: Whether to repeat the dataset data in each
condition during training.
:param str batching_mode: The batching mode to use. Options are
``"common_batch_size"``, ``"proportional"``, and
``"separate_conditions"``.
:param bool automatic_batching: Whether to perform automatic batching
with PyTorch.
:param bool pin_memory: Whether to use pinned memory for faster data
Expand Down Expand Up @@ -232,7 +231,7 @@ def _create_datamodule(
test_size=test_size,
val_size=val_size,
batch_size=batch_size,
repeat=repeat,
batching_mode=batching_mode,
automatic_batching=automatic_batching,
num_workers=num_workers,
pin_memory=pin_memory,
Expand Down Expand Up @@ -284,7 +283,7 @@ def _check_input_consistency(
train_size,
test_size,
val_size,
repeat,
batching_mode,
automatic_batching,
compile,
):
Expand All @@ -298,8 +297,9 @@ def _check_input_consistency(
test dataset.
:param float val_size: The percentage of elements to include in the
validation dataset.
:param bool repeat: Whether to repeat the dataset data in each
condition during training.
:param str batching_mode: The batching mode to use. Options are
``"common_batch_size"``, ``"proportional"``, and
``"separate_conditions"``.
:param bool automatic_batching: Whether to perform automatic batching
with PyTorch.
:param bool compile: If ``True``, the model is compiled before training.
Expand All @@ -309,8 +309,7 @@ def _check_input_consistency(
check_consistency(train_size, float)
check_consistency(test_size, float)
check_consistency(val_size, float)
if repeat is not None:
check_consistency(repeat, bool)
check_consistency(batching_mode, str)
if automatic_batching is not None:
check_consistency(automatic_batching, bool)
if compile is not None:
Expand Down
58 changes: 58 additions & 0 deletions pina/_src/data/aggregator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
"""
Aggregator for multiple dataloaders.
"""


class _Aggregator:
"""
The class :class:`_Aggregator` is responsible for aggregating multiple
dataloaders into a single iterable object. It supports different batching
modes to accommodate various training requirements.
"""

def __init__(self, dataloaders, batching_mode):
"""
Initialization of the :class:`_Aggregator` class.

:param dataloaders: A dictionary mapping condition names to their
respective dataloaders.
:type dataloaders: dict[str, DataLoader]
:param batching_mode: The batching mode to use. Options are
``"common_batch_size"``, ``"proportional"``, and
``"separate_conditions"``.
:type batching_mode: str
"""
self.dataloaders = dataloaders
self.batching_mode = batching_mode

def __len__(self):
"""
Return the length of the aggregated dataloader.

:return: The length of the aggregated dataloader.
:rtype: int
"""
return max(len(dl) for dl in self.dataloaders.values())

def __iter__(self):
"""
Return an iterator over the aggregated dataloader.

:return: An iterator over the aggregated dataloader.
:rtype: iterator
"""
if self.batching_mode == "separate_conditions":
for name, dl in self.dataloaders.items():
for batch in dl:
yield {name: batch}
return
iterators = {name: iter(dl) for name, dl in self.dataloaders.items()}
for _ in range(len(self)):
batch = {}
for name, it in iterators.items():
try:
batch[name] = next(it)
except StopIteration:
iterators[name] = iter(self.dataloaders[name])
batch[name] = next(iterators[name])
yield batch
178 changes: 178 additions & 0 deletions pina/_src/data/creator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
"""
Module defining the Creator class, responsible for creating dataloaders
for multiple conditions with various batching strategies.
"""

import torch
from torch.utils.data import RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler


class _Creator:
"""
The class :class:`_Creator` is responsible for creating dataloaders for
multiple conditions based on specified batching strategies. It supports
different batching modes to accommodate various training requirements.
"""

def __init__(
self,
batching_mode,
batch_size,
shuffle,
automatic_batching,
num_workers,
pin_memory,
conditions,
):
"""
Initialization of the :class:`_Creator` class.

:param batching_mode: The batching mode to use. Options are
``"common_batch_size"``, ``"proportional"``, and
``"separate_conditions"``.
:type batching_mode: str
:param batch_size: The batch size to use for dataloaders. If
``batching_mode`` is ``"proportional"``, this represents the total
batch size across all conditions.
:type batch_size: int | None
:param shuffle: Whether to shuffle the data in the dataloaders.
:type shuffle: bool
:param automatic_batching: Whether to use automatic batching in the
dataloaders.
:type automatic_batching: bool
:param num_workers: The number of worker processes to use for data
loading.
:type num_workers: int
:param pin_memory: Whether to pin memory in the dataloaders.
:type pin_memory: bool
:param conditions: A dictionary mapping condition names to their
respective condition objects.
:type conditions: dict[str, Condition]
"""
self.batching_mode = batching_mode
self.batch_size = batch_size
self.shuffle = shuffle
self.automatic_batching = automatic_batching
self.num_workers = num_workers
self.pin_memory = pin_memory
self.conditions = conditions

def _define_sampler(self, dataset, shuffle):
if torch.distributed.is_initialized():
return DistributedSampler(dataset, shuffle=shuffle)
if shuffle:
return RandomSampler(dataset)
return SequentialSampler(dataset)

def _compute_batch_sizes(self, datasets):
"""
Compute batch sizes for each condition based on the specified
batching mode.

:param datasets: A dictionary mapping condition names to their
respective datasets.
:type datasets: dict[str, Dataset]
:return: A dictionary mapping condition names to their computed batch
sizes.
:rtype: dict[str, int]
"""
batch_sizes = {}
if self.batching_mode == "common_batch_size":
for name in datasets.keys():
if self.batch_size is None:
batch_sizes[name] = len(datasets[name])
else:
batch_sizes[name] = min(
self.batch_size, len(datasets[name])
)
return batch_sizes
if self.batching_mode == "proportional":
return self._compute_proportional_batch_sizes(datasets)
if self.batching_mode == "separate_conditions":
for name in datasets.keys():
condition = self.conditions[name]
if self.batch_size is None:
batch_sizes[name] = len(datasets[name])
else:
batch_sizes[name] = min(
self.batch_size, len(datasets[name])
)
return batch_sizes
raise ValueError(f"Unknown batching mode: {self.batching_mode}")

def _compute_proportional_batch_sizes(self, datasets):
"""
Compute batch sizes for each condition proportionally based on the
size of their datasets.
:param datasets: A dictionary mapping condition names to their
respective datasets.
:type datasets: dict[str, Dataset]
:return: A dictionary mapping condition names to their computed batch
sizes.
:rtype: dict[str, int]
"""
# Compute number of elements per dataset
elements_per_dataset = {
dataset_name: len(dataset)
for dataset_name, dataset in datasets.items()
}
# Compute the total number of elements
total_elements = sum(el for el in elements_per_dataset.values())
# Compute the portion of each dataset
portion_per_dataset = {
name: el / total_elements
for name, el in elements_per_dataset.items()
}
# Compute batch size per dataset. Ensure at least 1 element per
# dataset.
batch_size_per_dataset = {
name: max(1, int(portion * self.batch_size))
for name, portion in portion_per_dataset.items()
}
# Adjust batch sizes to match the specified total batch size
tot_el_per_batch = sum(el for el in batch_size_per_dataset.values())
if self.batch_size > tot_el_per_batch:
difference = self.batch_size - tot_el_per_batch
while difference > 0:
for k, v in batch_size_per_dataset.items():
if difference == 0:
break
if v > 1:
batch_size_per_dataset[k] += 1
difference -= 1
if self.batch_size < tot_el_per_batch:
difference = tot_el_per_batch - self.batch_size
while difference > 0:
for k, v in batch_size_per_dataset.items():
if difference == 0:
break
if v > 1:
batch_size_per_dataset[k] -= 1
difference -= 1
return batch_size_per_dataset

def __call__(self, datasets):
"""
Create dataloaders for each condition based on the specified batching
mode.
:param datasets: A dictionary mapping condition names to their
respective datasets.
:type datasets: dict[str, Dataset]
:return: A dictionary mapping condition names to their created
dataloaders.
:rtype: dict[str, DataLoader]
"""
# Compute batch sizes per condition based on batching_mode
batch_sizes = self._compute_batch_sizes(datasets)
dataloaders = {}
for name, dataset in datasets.items():
dataloaders[name] = self.conditions[name].create_dataloader(
dataset=dataset,
batch_size=batch_sizes[name],
automatic_batching=self.automatic_batching,
sampler=self._define_sampler(dataset, self.shuffle),
num_workers=self.num_workers,
pin_memory=self.pin_memory,
)
return dataloaders
Loading