Research data available for everyone.

train_malware_detector.py 9.4KB

(lines 1–271)
  1. import argparse
  2. import copy
  3. import torch
  4. import sys
  5. sys.path.append("../")
  6. from boolean_classifier.datasets.boolean_ngram_dataset import BooleanNGramDataset
  7. from boolean_classifier.datasets.ngram_dataset import NGramDataset
  8. from boolean_classifier.architectures.ffnn import FFNN
  9. from torch.utils.data import DataLoader
  10. import multiprocessing
  11. import json
  12. import os
  13. import torch.nn
  14. from torch.optim.lr_scheduler import _LRScheduler
  15. from torch.utils.data import DataLoader
  16. from tqdm import tqdm
  17. import joblib
  18. class EarlyStoppingPyTorchTrainer:
  19. """Trainer for PyTorch models with early stopping."""
  20. def __init__(self, optimizer: torch.optim.Optimizer, epochs: int = 5,
  21. loss: torch.nn.Module = None, scheduler: _LRScheduler = None, feature_selector = None) -> None:
  22. """
  23. Create PyTorch trainer.
  24. Parameters
  25. ----------
  26. optimizer : torch.optim.Optimizer
  27. Optimizer to use for training the model.
  28. epochs : int, optional
  29. Number of epochs, by default 5.
  30. loss : torch.nn.Module, optional
  31. Loss to minimize, by default None.
  32. scheduler : _LRScheduler, optional
  33. Scheduler for the optimizer, by default None.
  34. """
  35. self._epochs = epochs
  36. self._optimizer = optimizer
  37. self._loss = loss if loss is not None else torch.nn.CrossEntropyLoss()
  38. self._scheduler = scheduler
  39. self.feature_selector = feature_selector
  40. self.training_losses = []
  41. self.training_accuracies = []
  42. self.validation_losses = []
  43. self.validation_accuracies = []
  44. def train(self, model: torch.nn.Module,
  45. train_loader: DataLoader,
  46. val_loader: DataLoader,
  47. patience: int) -> torch.nn.Module:
  48. """
  49. Train model with given loaders and early stopping.
  50. Parameters
  51. ----------
  52. model : torch.nn.Module
  53. Pytorch model to be trained.
  54. train_loader : DataLoader
  55. Train data loader.
  56. val_loader : DataLoader
  57. Validation data loader.
  58. patience : int
  59. Number of epochs to wait before early stopping.
  60. Returns
  61. -------
  62. torch.nn.Module
  63. Trained model.
  64. """
  65. best_loss = float("inf")
  66. best_model = None
  67. patience_counter = 0
  68. for _ in range(self._epochs):
  69. model = self.fit(model, train_loader)
  70. val_loss = self.validate(model, val_loader)
  71. if val_loss <= best_loss:
  72. best_loss = val_loss
  73. best_model = copy.deepcopy(model)
  74. patience_counter = 0
  75. else:
  76. patience_counter += 1
  77. if patience_counter >= patience:
  78. break
  79. return best_model
  80. def fit(self,
  81. model: torch.nn.Module,
  82. dataloader: DataLoader) -> torch.nn.Module:
  83. """
  84. Train model for one epoch with given loader.
  85. Parameters
  86. ----------
  87. model : torch.nn.Module
  88. Pytorch model to be trained.
  89. dataloader : DataLoader
  90. Train data loader.
  91. Returns
  92. -------
  93. torch.nn.Module
  94. Trained model.
  95. """
  96. device = next(model.parameters()).device
  97. model = model.train()
  98. model = model.to(device)
  99. running_loss = 0.0
  100. train_total = 0
  101. train_correct = 0
  102. for x, y in tqdm(dataloader):
  103. if self.feature_selector is not None:
  104. x = torch.Tensor(self.feature_selector.transform(x))
  105. x, y = x.to(device), y.to(device)
  106. self._optimizer.zero_grad()
  107. outputs = model(x)
  108. loss = self._loss(outputs, y)
  109. loss.backward()
  110. self._optimizer.step()
  111. running_loss += loss.item()
  112. y_preds = outputs.softmax(dim=1).argmax(dim=1)
  113. train_total += y.size(0)
  114. train_correct += (y_preds == y).sum().item()
  115. self.training_losses.append(running_loss / train_total)
  116. self.training_accuracies.append(train_correct / train_total)
  117. if self._scheduler is not None:
  118. self._scheduler.step()
  119. return model
  120. def validate(self,
  121. model: torch.nn.Module,
  122. dataloader: DataLoader) -> float:
  123. """
  124. Validate model with given loader.
  125. Parameters
  126. ----------
  127. model : torch.nn.Module
  128. Pytorch model to be balidated.
  129. dataloader : DataLoader
  130. Validation data loader.
  131. Returns
  132. -------
  133. float
  134. Validation loss of the model.
  135. """
  136. running_loss = 0
  137. val_total = 0
  138. val_correct = 0
  139. device = next(model.parameters()).device
  140. model = model.eval()
  141. model = model.to(device)
  142. with torch.no_grad():
  143. for x, y in tqdm(dataloader):
  144. if self.feature_selector is not None:
  145. x = torch.Tensor(self.feature_selector.transform(x))
  146. x, y = x.to(device), y.to(device)
  147. outputs = model(x)
  148. loss = self._loss(outputs, y)
  149. running_loss += loss.item()
  150. y_preds = outputs.softmax(dim=1).argmax(dim=1)
  151. val_total += y.size(0)
  152. val_correct += (y_preds == y).sum().item()
  153. self.validation_losses.append(running_loss / val_total)
  154. self.validation_accuracies.append(val_correct / val_total)
  155. return loss
  156. def save_results(trainer: EarlyStoppingPyTorchTrainer, configuration: dict):
  157. results = {
  158. "training_losses": trainer.training_losses,
  159. "training_accuracies": trainer.training_accuracies,
  160. "validation_losses": trainer.validation_losses,
  161. "validation_accuracies": trainer.validation_accuracies
  162. }
  163. with open(os.path.join(configuration["model_path"], "results.json"), "w") as output_file:
  164. json.dump(results, output_file)
  165. def load_configuration(configuration_filepath: str) -> dict:
  166. with open(configuration_filepath, "r") as configuration_file:
  167. configuration = json.load(configuration_file)
  168. return configuration
  169. if __name__ == "__main__":
  170. parser = argparse.ArgumentParser(description='Train malware detector')
  171. parser.add_argument("training_file",
  172. type=str,
  173. help="Training file containing the hashes and labels of the benign and malicious samples"
  174. )
  175. parser.add_argument("validation_file",
  176. type=str,
  177. help="Validation file containing the hashes and labels of the benign and malicious samples"
  178. )
  179. parser.add_argument("dataset_type",
  180. type=str,
  181. help="Type of dataset: {BooleanBigrams, Bigrams, EMBER}"
  182. )
  183. parser.add_argument("configuration_file",
  184. type=str,
  185. help="Configuration file containing the hyperparameters of the model"
  186. )
  187. parser.add_argument("--batch_size",
  188. type=int,
  189. help="Batch size for training",
  190. default=32
  191. )
  192. parser.add_argument("--num_epochs",
  193. type=int,
  194. help="Max epochs",
  195. default=50
  196. )
  197. parser.add_argument("--patience",
  198. type=int,
  199. help="Patience for early stopping",
  200. default=5
  201. )
  202. args = parser.parse_args()
  203. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  204. print("Device: ", device)
  205. num_workers = max(multiprocessing.cpu_count() - 4, multiprocessing.cpu_count() // 2 + 1)
  206. if args.dataset_type == "BooleanBigrams":
  207. training_dataset = BooleanNGramDataset(args.training_file)
  208. validation_dataset = BooleanNGramDataset(args.validation_file)
  209. elif args.dataset_type == "Bigrams":
  210. training_dataset = NGramDataset(args.training_file)
  211. validation_dataset = NGramDataset(args.validation_file)
  212. else:
  213. raise NotImplementedError("Only Boolean dataset is currently supported")
  214. training_dataloader = DataLoader(
  215. training_dataset,
  216. batch_size=args.batch_size,
  217. num_workers=num_workers,
  218. )
  219. validation_dataloader = DataLoader(
  220. validation_dataset,
  221. batch_size=args.batch_size,
  222. num_workers=num_workers,
  223. )
  224. configuration = load_configuration(args.configuration_file)
  225. model = FFNN(configuration)
  226. model = model.to(device)
  227. if configuration["feature_selector"] is not None:
  228. feature_selector = joblib.load(configuration["feature_selector"])
  229. else:
  230. feature_selector = None
  231. criterion = torch.nn.CrossEntropyLoss()
  232. optimizer = torch.optim.Adam(model.parameters())
  233. trainer = EarlyStoppingPyTorchTrainer(
  234. optimizer,
  235. epochs=args.num_epochs,
  236. loss=criterion,
  237. feature_selector=feature_selector
  238. )
  239. model = trainer.train(
  240. model,
  241. training_dataloader,
  242. validation_dataloader,
  243. args.patience
  244. )
  245. if not os.path.exists(configuration["model_path"]):
  246. os.makedirs(configuration["model_path"])
  247. torch.save(model.state_dict(), os.path.join(configuration["model_path"],"model.pth"))
  248. save_results(trainer, configuration)

Powered by TurnKey Linux.