Skip to content

Commit

Permalink
upgrades threat_intel virustotal to version 3
Browse files Browse the repository at this point in the history
  • Loading branch information
ytonui committed Apr 22, 2020
1 parent 27b9934 commit f9ec163
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 10 deletions.
3 changes: 2 additions & 1 deletion tests/virustotal_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
import testify as T
from mock import patch
from mock import ANY

from threat_intel.virustotal import VirusTotalApi

Expand Down Expand Up @@ -30,7 +31,7 @@ def _test_api_call(self, call, endpoint, request, expected_query_params, api_res
request_mock.multi_get.return_value = api_response
result = call(request)
param_list = [self.vt.BASE_DOMAIN + endpoint.format(param) for param in expected_query_params]
request_mock.multi_get.assert_called_with(param_list)
request_mock.multi_get.assert_called_with(param_list, file_download=ANY)
T.assert_equal(result, expected_result)

def test_get_file_reports(self):
Expand Down
12 changes: 8 additions & 4 deletions threat_intel/util/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,21 +202,23 @@ def __init__(
),
)

def multi_get(self, urls, query_params=None, to_json=True):
def multi_get(self, urls, query_params=None, to_json=True, file_download=False):
"""Issue multiple GET requests.
Args:
urls - A string URL or list of string URLs
query_params - None, a dict, or a list of dicts representing the query params
to_json - A boolean, should the responses be returned as JSON blobs
file_download - A boolean, whether a file download is expected
Returns:
a list of dicts if to_json is set, or a list of requests.response objects otherwise.
Raises:
InvalidRequestError - Can not decide how many requests to issue.
"""
return self._multi_request(
MultiRequest._VERB_GET, urls, query_params,
data=None, to_json=to_json,
data=None, to_json=to_json, file_download=file_download,
)

def multi_post(self, urls, query_params=None, data=None, to_json=True, send_as_file=False):
Expand Down Expand Up @@ -410,7 +412,7 @@ def _convert_to_json(self, response):
))
return None

def _multi_request(self, verb, urls, query_params, data, to_json=True, send_as_file=False):
def _multi_request(self, verb, urls, query_params, data, to_json=True, send_as_file=False, file_download=False):
"""Issues multiple batches of simultaneous HTTP requests and waits for responses.
Args:
Expand Down Expand Up @@ -449,8 +451,10 @@ def _multi_request(self, verb, urls, query_params, data, to_json=True, send_as_f

responses = self._wait_for_response(prepared_requests)
for response in responses:
if response:
if response and not file_download:
all_responses.append(self._convert_to_json(response) if to_json else response)
elif file_download:
all_responses.append(self._handle_file_download(response))
else:
all_responses.append(None)

Expand Down
11 changes: 6 additions & 5 deletions threat_intel/virustotal.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def get_file_download(self, file_hash_list):
"""
api_name = 'virustotal-file-download'
api_endpoint = 'files/{}/download'
return self._extract_all_responses(file_hash_list, api_endpoint, api_name)
return self._extract_all_responses(file_hash_list, api_endpoint, api_name, file_download=True)

@MultiRequest.error_handling
def get_file_contacted_domains(self, file_hash_list):
Expand Down Expand Up @@ -293,17 +293,18 @@ def _bulk_cache_lookup(self, api_name, keys):

return ({}, keys)

def _request_reports(self, ids, endpoint_name):
def _request_reports(self, ids, endpoint_name, file_download=False):
"""Sends multiple requests for the resources to a particular endpoint.
Args:
ids: list of the hashes identifying the files.
endpoint_name: VirusTotal endpoint URL suffix.
file_download: boolean, whether a file download is expected
Returns:
A list of the responses.
"""
urls = ['{}{}'.format(self.BASE_DOMAIN, endpoint_name.format(id)) for id in ids]
return self._requests.multi_get(urls) if urls else []
return self._requests.multi_get(urls, file_download=file_download) if urls else []


def _extract_cache_id(self, response):
Expand All @@ -328,7 +329,7 @@ def _extract_cache_id(self, response):
cache_id = cache_id.split('_')[0]
return cache_id

def _extract_all_responses(self, resources, api_endpoint, api_name):
def _extract_all_responses(self, resources, api_endpoint, api_name, file_download=False):
""" Aux function to extract all the API endpoint responses.
Args:
Expand All @@ -339,7 +340,7 @@ def _extract_all_responses(self, resources, api_endpoint, api_name):
A dict with the hash as key and the VT report as value.
"""
all_responses, resources = self._bulk_cache_lookup(api_name, resources)
response_chunks = self._request_reports(resources, api_endpoint)
response_chunks = self._request_reports(resources, api_endpoint, file_download)
self._extract_response_chunks(all_responses, response_chunks, api_name)

return all_responses
Expand Down

0 comments on commit f9ec163

Please sign in to comment.