Research data available for everyone.

boolean_ngram_dataset.py 1.1KB

1234567891011121314151617181920212223242526272829303132
  1. from torch.utils.data import Dataset
  2. import os
  3. from random import shuffle
  4. import numpy as np
  5. import torch
  6. import scipy.sparse
  7. class BooleanNGramDataset(Dataset):
  8. def __init__(self, csv_filepath: str):
  9. self.all_files = []
  10. with open(csv_filepath, "r") as f:
  11. lines = f.readlines()
  12. for line in lines:
  13. filepath, label = line.strip().split(",")
  14. self.all_files.append((filepath, int(label)))
  15. shuffle(self.all_files)
  16. def __len__(self):
  17. return len(self.all_files)
  18. def __getitem__(self, index):
  19. to_load, y = self.all_files[index]
  20. # Step 1: Load the .npz file into a sparse matrix
  21. sparse_matrix = scipy.sparse.load_npz(to_load)
  22. # Step 2: Convert the sparse matrix to a dense matrix (e.g., using toarray())
  23. dense_matrix = sparse_matrix.toarray() # You can also use .todense() if needed
  24. # Step 3: Convert the dense matrix to a PyTorch tensor
  25. x = torch.tensor(dense_matrix, dtype=torch.float)
  26. x = x.squeeze()
  27. return x, torch.tensor(y)

Powered by TurnKey Linux.