blob: 6916a26cd48f40fa4abbd148092b951310655486 [file] [log] [blame]
# Copyright 2015 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""Unittests for the mdns.py module."""
from __future__ import print_function
import dpkt
import select
import socket
from chromite.lib import cros_test_lib
from chromite.lib import mdns
class BadService(object):
  """Marker entry that asks the mocked network for a malformed mDNS packet.

  The mock network helper treats any entry in its services list that is not
  an mdns.Service (such as an instance of this class) as a request to emit
  an unparseable mDNS response instead of a valid one.
  """
class mDnsTestCase(cros_test_lib.MockTestCase):
  """Base test case that mocks the network API to return mDNS services.

  setUp() patches socket.socket and select.select; subclasses then call
  _MockNetworkResponse() to script what the mocked socket receives.
  """

  def setUp(self):
    self.socket_class_mock = self.PatchObject(socket, 'socket')
    self.select_mock = self.PatchObject(select, 'select')

  def _BuildmDnsResponse(self, service):
    """Return a valid mDNS response for the specified service.

    Only the fields of |service| that are set (truthy) produce answer
    records, so tests can build deliberately incomplete responses.

    Args:
      service: mdns.Service namedtuple that contains the service settings.

    Returns:
      A dpkt.dns.DNS packet whose answer section holds an A record (ip), a
      PTR record (ptrname), an SRV record (hostname/port) and a TXT record
      (text) -- each only when the corresponding field is set.
    """
    answers = []
    if service.ip:
      answers.append(
          dpkt.dns.DNS.RR(type=dpkt.dns.DNS_A, ip=socket.inet_aton(service.ip)))
    if service.ptrname:
      answers.append(
          dpkt.dns.DNS.RR(type=dpkt.dns.DNS_PTR, ptrname=service.ptrname))
    if service.hostname:
      answers.append(
          dpkt.dns.DNS.RR(type=dpkt.dns.DNS_SRV, srvname=service.hostname,
                          priority=1, weight=1, port=service.port))
    if service.text:
      # dict.items() works on Python 2 and 3; iteritems() is Python 2 only.
      text = ['%s=%s' % (key, value) for key, value in service.text.items()]
      answers.append(dpkt.dns.DNS.RR(type=dpkt.dns.DNS_TXT, text=text))
    return dpkt.dns.DNS(op=dpkt.dns.DNS_QUERY,
                        rcode=dpkt.dns.DNS_RCODE_NOERR,
                        q=[], an=answers, ns=[])

  def _BuildBadmDnsResponse(self):
    """Return a raw payload that is not a well-formed mDNS packet."""
    # A bytes literal keeps the payload the same type as a packed valid
    # response on Python 3; on Python 2 b'junk' == 'junk'.
    return b'junk'

  def _MockNetworkResponse(self, services):
    """Mock the network response to include the specified mDNS services.

    Args:
      services: List of mDNS services that form the network response. Use
        |mdns.Service| to indicate a valid mDNS response. Use |BadService| to
        indicate an invalid mDNS response.
    """
    socket_mock = self.socket_class_mock.return_value
    select_side_effects = []
    recvfrom_side_effects = []
    for service in services:
      # Each queued response is preceded by select() reporting a readable fd.
      select_side_effects.append(([1], [], []))
      if isinstance(service, mdns.Service):
        mdns_response = self._BuildmDnsResponse(service)
      else:
        mdns_response = self._BuildBadmDnsResponse()
      # bytes() yields the packed wire format of a dpkt packet on Python 3
      # (str() gives its repr there); on Python 2 bytes is an alias of str.
      recvfrom_side_effects.append((bytes(mdns_response), ''))
    # A final empty select() result reports nothing readable; presumably this
    # ends FindServices()' receive loop (confirm against mdns.py).
    select_side_effects.append(([], [], []))
    self.select_mock.side_effect = select_side_effects
    socket_mock.recvfrom.side_effect = recvfrom_side_effects
class mDnsFindServicesTest(mDnsTestCase):
  """Exercises mdns.FindServices() against scripted network responses."""

  @staticmethod
  def _MakeService(index, name):
    """Build a complete mdns.Service numbered |index| with TXT name=|name|."""
    return mdns.Service('test%d.local' % index, '10.0.0.%d' % index, 1234,
                        'test%d.a.local' % index, {'name': name})

  def _TestmDnsResults(self, services, expected_results, should_add_func=None,
                       should_continue_func=None):
    """Script the network with |services| and check FindServices() output.

    Args:
      services: List of services the mocked network answers with.
      expected_results: List of services FindServices() must return.
      should_add_func: Forwarded as |should_add_func| to FindServices().
      should_continue_func: Forwarded as |should_continue_func| to
        FindServices().
    """
    self._MockNetworkResponse(services)
    found = mdns.FindServices(
        '127.0.0.1', 'a.local',
        should_add_func=should_add_func,
        should_continue_func=should_continue_func)
    self.assertEqual(found, expected_results)

  def testFindServices(self):
    """Every complete service on the network is returned."""
    services = [self._MakeService(1, 'test'), self._MakeService(2, 'test2')]
    self._TestmDnsResults(services, services)

  def testFindServicesIncompleteResponse(self):
    """A service lacking PTR and TXT data is skipped; others are kept."""
    first = self._MakeService(1, 'test')
    incomplete = mdns.Service('test2.local', '10.0.0.2', 1234, None, None)
    last = self._MakeService(3, 'test3')
    self._TestmDnsResults([first, incomplete, last], [first, last])

  def testFindOneService(self):
    """Only the service matching the filter is returned."""
    services = [self._MakeService(1, 'test'), self._MakeService(2, 'test2')]
    target = services[1]

    def _IsTarget(candidate):
      return candidate.hostname == target.hostname

    def _TargetNotSeenYet(candidate):
      return candidate.hostname != target.hostname

    self._TestmDnsResults(services, [target],
                          should_add_func=_IsTarget,
                          should_continue_func=_TargetNotSeenYet)

  def testFindSeveralServices(self):
    """Searching stops right after the stop condition is met.

    Every entry is accepted, so anything after the matching entry showing up
    in the result would mean the early exit did not happen.
    """
    services = [self._MakeService(1, 'test')]
    services.extend(self._MakeService(i, 'test%d' % i) for i in (2, 3, 4))
    stop_at = services[1]

    def _AcceptAll(_):
      return True

    def _BeforeStop(candidate):
      return candidate.hostname != stop_at.hostname

    self._TestmDnsResults(services, services[:2],
                          should_add_func=_AcceptAll,
                          should_continue_func=_BeforeStop)

  def testBadResponse(self):
    """A malformed packet is ignored while its neighbours are still found."""
    good_before = self._MakeService(1, 'test')
    good_after = self._MakeService(3, 'test3')
    self._TestmDnsResults([good_before, BadService(), good_after],
                          [good_before, good_after])