Research data available for everyone.

evaluate_malware_detector.py 4.2KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. import argparse
  2. import torch
  3. import sys
  4. sys.path.append("../")
  5. from boolean_classifier.datasets.boolean_ngram_dataset import BooleanNGramDataset
  6. from boolean_classifier.datasets.ngram_dataset import NGramDataset
  7. from boolean_classifier.architectures.ffnn import FFNN
  8. from torch.utils.data import DataLoader
  9. import multiprocessing
  10. import json
  11. import os
  12. import torch.nn
  13. from torch.optim.lr_scheduler import _LRScheduler
  14. from torch.utils.data import DataLoader
  15. from tqdm import tqdm
  16. from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
  17. import joblib
  18. def load_configuration(configuration_filepath: str) -> dict:
  19. with open(configuration_filepath, "r") as configuration_file:
  20. configuration = json.load(configuration_file)
  21. return configuration
  22. def evaluate(model: torch.nn.Module, dataloader: DataLoader) -> tuple[list, list]:
  23. y_trues = []
  24. y_preds = []
  25. device = next(model.parameters()).device
  26. model = model.eval()
  27. with torch.no_grad():
  28. for x, y in tqdm(dataloader):
  29. if feature_selector is not None:
  30. x = torch.Tensor(feature_selector.transform(x))
  31. x, y = x.to(device), y.to(device)
  32. outputs = model.predict(x)
  33. y_pred = outputs.argmax(dim=1)
  34. y_trues.extend(y.cpu())
  35. y_preds.extend(y_pred.cpu())
  36. return y_trues, y_preds
  37. def save_results(y_trues: list, y_preds: list, output_filepath: str):
  38. acc = accuracy_score(y_trues, y_preds)
  39. precision = precision_score(y_trues, y_preds)
  40. recall = recall_score(y_trues, y_preds)
  41. f1 = f1_score(y_trues, y_preds)
  42. cm = confusion_matrix(y_trues, y_preds)
  43. with open(output_filepath, "w") as output_file:
  44. output_file.write("Accuracy: {}\n".format(acc))
  45. output_file.write("Precision: {}\n".format(precision))
  46. output_file.write("Recall: {}\n".format(recall))
  47. output_file.write("F1: {}\n".format(f1))
  48. output_file.write("Confusion Matrix: {}\n".format(cm))
  49. if __name__ == "__main__":
  50. parser = argparse.ArgumentParser(description='Evaluate malware detector')
  51. parser.add_argument("evaluation_file",
  52. type=str,
  53. help="Evaluation file containing the hashes and labels of the benign and malicious samples"
  54. )
  55. parser.add_argument("dataset_type",
  56. type=str,
  57. help="Type of dataset: {Boolean, EMBER}"
  58. )
  59. parser.add_argument("configuration_file",
  60. type=str,
  61. help="Configuration file containing the hyperparameters of the model"
  62. )
  63. parser.add_argument("output_file",
  64. type=str,
  65. help="File to where to store the results",
  66. )
  67. parser.add_argument("--batch_size",
  68. type=int,
  69. help="Batch size for training",
  70. default=32
  71. )
  72. args = parser.parse_args()
  73. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  74. print("Device: ", device)
  75. num_workers = max(multiprocessing.cpu_count() - 4, multiprocessing.cpu_count() // 2 + 1)
  76. configuration = load_configuration(args.configuration_file)
  77. if args.dataset_type == "BooleanBigrams":
  78. dataset = BooleanNGramDataset(args.evaluation_file)
  79. elif args.dataset_type == "Bigrams":
  80. dataset = NGramDataset(args.evaluation_file)
  81. else:
  82. raise NotImplementedError("Only Boolean dataset is currently supported")
  83. dataloader = DataLoader(
  84. dataset,
  85. batch_size=args.batch_size,
  86. num_workers=num_workers,
  87. )
  88. model = FFNN(configuration)
  89. model = model.to(device)
  90. model.load_state_dict(torch.load(os.path.join(configuration["model_path"], "model.pth"), weights_only=True))
  91. model.eval()
  92. if configuration["feature_selector"] is not None:
  93. feature_selector = joblib.load(configuration["feature_selector"])
  94. else:
  95. feature_selector = None
  96. y_trues, y_preds = evaluate(model, dataloader)
  97. save_results(y_trues, y_preds, args.output_file)

Powered by TurnKey Linux.