Source code for src.original.DK_OGC_AmsterdamUMC.utils.data_processing.processors.FlattenImageData
import numpy as np
import torch
import torchio
from torchio.transforms import Transform
[docs]
class FlattenImageData(Transform):
def __init__(self, **kwargs):
super().__init__(**kwargs)
[docs]
def apply_transform(self, subject):
"""
flattens image data of signals image of subject
Args:
signals: signals array to normalize
xvals: xval array
Returns:
normalized_signals: normalized signals array
"""
images_dict = subject.get_images_dict(include=self.include, exclude=self.exclude)
for image_key, image in images_dict.items():
flattened_array = self.flatten_image_data(image.numpy())
subject.add_image(torchio.Image(tensor=torch.Tensor(np.reshape(flattened_array,
(flattened_array.shape[0],
flattened_array.shape[1], 1, 1)))),
image_key)
return subject
[docs]
@staticmethod
def flatten_image_data(signals):
"""
Flattens 4D array into 2D array
Args:
signals: signals array to normalize
Returns:
normalized_signals: normalized signals array
"""
bvals, x, y, z = signals.shape
signals_array = np.reshape(signals, (bvals, x * y * z))
return signals_array