-
Notifications
You must be signed in to change notification settings - Fork 7
Open
Description
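I encountered an error while training a BIDCell model on a Stereo-seq dataset. Training runs for a few steps of the first epoch (the last logged step is 7). It then fails inside `DataProcessing.__getitem__` when the expression patch and the nuclei patch are concatenated. The expression patch has size 32 along dimension 0, but the corresponding nuclei patch has size 0. This suggests the patch coordinates may fall outside the loaded nuclei image, whose shape is (11092, 10089). Full log and traceback below: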
In [46]: model.train()
2024-08-23 15:18:44,725 INFO Initialising model
Number of genes: 515
2024-08-23 15:18:45,184 INFO Preparing data
Loaded nuclei
(11092, 10089)
8563 patches available
2024-08-23 15:18:48,752 INFO Total number of training examples: 8563
2024-08-23 15:18:48,754 INFO Begin training
Epoch = 1 lr = 1e-05
Epoch[1/1], Step[7], Loss:1351.2664
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[46], line 1
----> 1 model.train()
File ~/anaconda3/envs/bidcell/lib/python3.10/site-packages/bidcell/BIDCellModel.py:114, in BIDCellModel.train(self)
111 def train(self) -> None:
112 """Train the model.
113 """
--> 114 train(self.config)
File ~/anaconda3/envs/bidcell/lib/python3.10/site-packages/bidcell/model/train.py:170, in train(config)
167 cur_lr = optimizer.param_groups[0]["lr"]
168 print("\nEpoch =", (epoch + 1), " lr =", cur_lr)
--> 170 for step_epoch, (
171 batch_x313,
172 batch_n,
173 batch_sa,
174 batch_pos,
175 batch_neg,
176 coords_h1,
177 coords_w1,
178 nucl_aug,
179 expr_aug_sum,
180 ) in enumerate(train_loader):
181 # Permute channels axis to batch axis
182 # torch.Size([1, patch_size, patch_size, 313, n_cells]) to [n_cells, 313, patch_size, patch_size]
183 batch_x313 = batch_x313[0, :, :, :, :].permute(3, 2, 0, 1)
184 batch_sa = batch_sa.permute(3, 0, 1, 2)
File ~/anaconda3/envs/bidcell/lib/python3.10/site-packages/torch/utils/data/dataloader.py:628, in _BaseDataLoaderIter.__next__(self)
625 if self._sampler_iter is None:
626 # TODO(https://github.com/pytorch/pytorch/issues/76750)
627 self._reset() # type: ignore[call-arg]
--> 628 data = self._next_data()
629 self._num_yielded += 1
630 if self._dataset_kind == _DatasetKind.Iterable and \
631 self._IterableDataset_len_called is not None and \
632 self._num_yielded > self._IterableDataset_len_called:
File ~/anaconda3/envs/bidcell/lib/python3.10/site-packages/torch/utils/data/dataloader.py:671, in _SingleProcessDataLoaderIter._next_data(self)
669 def _next_data(self):
670 index = self._next_index() # may raise StopIteration
--> 671 data = self._dataset_fetcher.fetch(index) # may raise StopIteration
672 if self._pin_memory:
673 data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
File ~/anaconda3/envs/bidcell/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py:58, in _MapDatasetFetcher.fetch(self, possibly_batched_index)
56 data = self.dataset.__getitems__(possibly_batched_index)
57 else:
---> 58 data = [self.dataset[idx] for idx in possibly_batched_index]
59 else:
60 data = self.dataset[possibly_batched_index]
File ~/anaconda3/envs/bidcell/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py:58, in <listcomp>(.0)
56 data = self.dataset.__getitems__(possibly_batched_index)
57 else:
---> 58 data = [self.dataset[idx] for idx in possibly_batched_index]
59 else:
60 data = self.dataset[possibly_batched_index]
File ~/anaconda3/envs/bidcell/lib/python3.10/site-packages/bidcell/model/dataio/dataset_input.py:220, in DataProcessing.__getitem__(self, index)
217 assert expr.shape[0] == self.patch_size, print(expr.shape[0])
218 assert expr.shape[1] == self.patch_size, print(expr.shape[1])
--> 220 img = np.concatenate((expr, np.expand_dims(nucl, -1)), -1)
222 if self.isTraining:
223 img, _, _ = self.augment_data(img)
ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 0, the array at index 0 has size 32 and the array at index 1 has size 0
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels