diff --git a/nibabel/streamlines/tests/test_tractogram.py b/nibabel/streamlines/tests/test_tractogram.py index 30294be43..09e3b910b 100644 --- a/nibabel/streamlines/tests/test_tractogram.py +++ b/nibabel/streamlines/tests/test_tractogram.py @@ -80,6 +80,7 @@ def make_dummy_streamline(nb_points): 'mean_curvature': np.array([1.11], dtype='f4'), 'mean_torsion': np.array([1.22], dtype='f4'), 'mean_colors': np.array([1, 0, 0], dtype='f4'), + 'clusters_labels': np.array([0, 1], dtype='i4'), } elif nb_points == 2: @@ -92,6 +93,7 @@ def make_dummy_streamline(nb_points): 'mean_curvature': np.array([2.11], dtype='f4'), 'mean_torsion': np.array([2.22], dtype='f4'), 'mean_colors': np.array([0, 1, 0], dtype='f4'), + 'clusters_labels': np.array([2, 3, 4], dtype='i4'), } elif nb_points == 5: @@ -104,6 +106,7 @@ def make_dummy_streamline(nb_points): 'mean_curvature': np.array([3.11], dtype='f4'), 'mean_torsion': np.array([3.22], dtype='f4'), 'mean_colors': np.array([0, 0, 1], dtype='f4'), + 'clusters_labels': np.array([5, 6, 7, 8], dtype='i4'), } return streamline, data_per_point, data_for_streamline @@ -119,6 +122,7 @@ def setup_module(): DATA['mean_curvature'] = [] DATA['mean_torsion'] = [] DATA['mean_colors'] = [] + DATA['clusters_labels'] = [] for nb_points in [1, 2, 5]: data = make_dummy_streamline(nb_points) streamline, data_per_point, data_for_streamline = data @@ -128,12 +132,14 @@ def setup_module(): DATA['mean_curvature'].append(data_for_streamline['mean_curvature']) DATA['mean_torsion'].append(data_for_streamline['mean_torsion']) DATA['mean_colors'].append(data_for_streamline['mean_colors']) + DATA['clusters_labels'].append(data_for_streamline['clusters_labels']) DATA['data_per_point'] = {'colors': DATA['colors'], 'fa': DATA['fa']} DATA['data_per_streamline'] = { 'mean_curvature': DATA['mean_curvature'], 'mean_torsion': DATA['mean_torsion'], 'mean_colors': DATA['mean_colors'], + 'clusters_labels': DATA['clusters_labels'], } DATA['empty_tractogram'] = Tractogram(affine_to_rasmm=np.eye(4)) @@ -154,6 +160,7 @@ def setup_module(): 'mean_curvature': lambda: (e for e in DATA['mean_curvature']), 'mean_torsion': lambda: (e for e in DATA['mean_torsion']), 'mean_colors': lambda: (e for e in DATA['mean_colors']), + 'clusters_labels': lambda: (e for e in DATA['clusters_labels']), } DATA['lazy_tractogram'] = LazyTractogram( @@ -214,7 +221,10 @@ def test_per_array_dict_creation(self): data_dict = PerArrayDict(nb_streamlines, data_per_streamline) assert data_dict.keys() == data_per_streamline.keys() for k in data_dict.keys(): - assert_array_equal(data_dict[k], data_per_streamline[k]) + if isinstance(data_dict[k], np.ndarray) and np.all( + data_dict[k].shape[0] == data_dict[k].shape + ): + assert_array_equal(data_dict[k], data_per_streamline[k]) del data_dict['mean_curvature'] assert len(data_dict) == len(data_per_streamline) - 1 @@ -224,7 +234,10 @@ def test_per_array_dict_creation(self): data_dict = PerArrayDict(nb_streamlines, data_per_streamline) assert data_dict.keys() == data_per_streamline.keys() for k in data_dict.keys(): - assert_array_equal(data_dict[k], data_per_streamline[k]) + if isinstance(data_dict[k], np.ndarray) and np.all( + data_dict[k].shape[0] == data_dict[k].shape + ): + assert_array_equal(data_dict[k], data_per_streamline[k]) del data_dict['mean_curvature'] assert len(data_dict) == len(data_per_streamline) - 1 @@ -234,7 +247,10 @@ def test_per_array_dict_creation(self): data_dict = PerArrayDict(nb_streamlines, **data_per_streamline) assert data_dict.keys() == data_per_streamline.keys() for k in data_dict.keys(): - assert_array_equal(data_dict[k], data_per_streamline[k]) + if isinstance(data_dict[k], np.ndarray) and np.all( + data_dict[k].shape[0] == data_dict[k].shape + ): + assert_array_equal(data_dict[k], data_per_streamline[k]) del data_dict['mean_curvature'] assert len(data_dict) == len(data_per_streamline) - 1 @@ -261,6 +277,7 @@ def test_extend(self): 'mean_curvature': 2 * np.array(DATA['mean_curvature']), 'mean_torsion': 3 * np.array(DATA['mean_torsion']), 'mean_colors': 4 * np.array(DATA['mean_colors']), + 'clusters_labels': 5 * np.array(DATA['clusters_labels'], dtype=object), } sdict2 = PerArrayDict(len(DATA['tractogram']), new_data) @@ -284,7 +301,8 @@ def test_extend(self): 'mean_curvature': 2 * np.array(DATA['mean_curvature']), 'mean_torsion': 3 * np.array(DATA['mean_torsion']), 'mean_colors': 4 * np.array(DATA['mean_colors']), - 'other': 5 * np.array(DATA['mean_colors']), + 'clusters_labels': 5 * np.array(DATA['clusters_labels'], dtype=object), + 'other': 6 * np.array(DATA['mean_colors']), } sdict2 = PerArrayDict(len(DATA['tractogram']), new_data) @@ -305,6 +323,7 @@ def test_extend(self): 'mean_curvature': 2 * np.array(DATA['mean_curvature']), 'mean_torsion': 3 * np.array(DATA['mean_torsion']), 'mean_colors': 4 * np.array(DATA['mean_torsion']), + 'clusters_labels': 5 * np.array(DATA['clusters_labels'], dtype=object), } sdict2 = PerArrayDict(len(DATA['tractogram']), new_data) with pytest.raises(ValueError): @@ -441,7 +460,10 @@ def test_lazydict_creation(self): assert is_lazy_dict(data_dict) assert data_dict.keys() == expected_keys for k in data_dict.keys(): - assert_array_equal(list(data_dict[k]), list(DATA['data_per_streamline'][k])) + if isinstance(data_dict[k], np.ndarray) and np.all( + data_dict[k].shape[0] == data_dict[k].shape + ): + assert_array_equal(list(data_dict[k]), list(DATA['data_per_streamline'][k])) assert len(data_dict) == len(DATA['data_per_streamline_func']) @@ -578,6 +600,7 @@ def test_tractogram_add_new_data(self): t.data_per_streamline['mean_curvature'] = DATA['mean_curvature'] t.data_per_streamline['mean_torsion'] = DATA['mean_torsion'] t.data_per_streamline['mean_colors'] = DATA['mean_colors'] + t.data_per_streamline['clusters_labels'] = DATA['clusters_labels'] assert_tractogram_equal(t, DATA['tractogram']) # Retrieve tractogram by their index. @@ -598,6 +621,7 @@ def test_tractogram_add_new_data(self): t.data_per_streamline['mean_curvature'] = DATA['mean_curvature'] t.data_per_streamline['mean_torsion'] = DATA['mean_torsion'] t.data_per_streamline['mean_colors'] = DATA['mean_colors'] + t.data_per_streamline['clusters_labels'] = DATA['clusters_labels'] assert_tractogram_equal(t, DATA['tractogram']) def test_tractogram_copy(self): @@ -647,14 +671,6 @@ def test_creating_invalid_tractogram(self): with pytest.raises(ValueError): Tractogram(streamlines=DATA['streamlines'], data_per_point={'scalars': scalars}) - # Inconsistent dimension for a data_per_streamline. - properties = [[1.11, 1.22], [2.11], [3.11, 3.22]] - - with pytest.raises(ValueError): - Tractogram( - streamlines=DATA['streamlines'], data_per_streamline={'properties': properties} - ) - # Too many dimension for a data_per_streamline. properties = [ np.array([[1.11], [1.22]], dtype='f4'), @@ -870,6 +886,7 @@ def test_lazy_tractogram_from_data_func(self): DATA['mean_curvature'], DATA['mean_torsion'], DATA['mean_colors'], + DATA['clusters_labels'], ] def _data_gen(): @@ -879,6 +896,7 @@ def _data_gen(): 'mean_curvature': d[3], 'mean_torsion': d[4], 'mean_colors': d[5], + 'clusters_labels': d[6], } yield TractogramItem(d[0], data_for_streamline, data_for_points) diff --git a/nibabel/streamlines/tractogram.py b/nibabel/streamlines/tractogram.py index 9e7c0f9af..5a39b415a 100644 --- a/nibabel/streamlines/tractogram.py +++ b/nibabel/streamlines/tractogram.py @@ -1,6 +1,7 @@ import copy import numbers -from collections.abc import MutableMapping +import types +from collections.abc import Iterable, MutableMapping from warnings import warn import numpy as np @@ -101,15 +102,28 @@ def __init__(self, n_rows=0, *args, **kwargs): super().__init__(*args, **kwargs) def __setitem__(self, key, value): - value = np.asarray(list(value)) + dtype = np.float64 + + if isinstance(value, types.GeneratorType): + value = list(value) + + if isinstance(value, np.ndarray): + dtype = value.dtype + elif not all(len(v) == len(value[0]) for v in value[1:]): + dtype = object + + value = np.asarray(value, dtype=dtype) if value.ndim == 1 and value.dtype != object: # Reshape without copy value.shape = (len(value), 1) - if value.ndim != 2: + if value.ndim != 2 and value.dtype != object: raise ValueError('data_per_streamline must be a 2D array.') + if value.dtype == object and not all(isinstance(v, Iterable) for v in value): + raise ValueError('data_per_streamline must be a 2D array') + # We make sure there is the right amount of values if 0 < self.n_rows != len(value): msg = f'The number of values ({len(value)}) should match n_elements ({self.n_rows}).'