Research data available for everyone.

ngram_dataset.py 878B

123456789101112131415161718192021222324252627282930
  1. from torch.utils.data import Dataset
  2. import os
  3. from random import shuffle
  4. import numpy as np
  5. import torch
  6. import scipy.sparse
  7. class NGramDataset(Dataset):
  8. def __init__(self, csv_filepath: str):
  9. self.all_files = []
  10. with open(csv_filepath, "r") as f:
  11. lines = f.readlines()
  12. for line in lines:
  13. filepath, label = line.strip().split(",")
  14. self.all_files.append((filepath, int(label)))
  15. shuffle(self.all_files)
  16. def __len__(self):
  17. return len(self.all_files)
  18. def __getitem__(self, index):
  19. to_load, y = self.all_files[index]
  20. # Step 1: Load the .npz file
  21. matrix = np.load(to_load)["arr_0"]
  22. # Step 2: Convert the dense matrix to a PyTorch tensor
  23. x = torch.tensor(matrix, dtype=torch.float)
  24. x = x.squeeze()
  25. return x, torch.tensor(y)

Powered by TurnKey Linux.