diff options
author | Arthur de Jong <arthur@arthurdejong.org> | 2013-08-17 12:32:07 +0200 |
---|---|---|
committer | Arthur de Jong <arthur@arthurdejong.org> | 2013-08-17 12:32:07 +0200 |
commit | 8a3f0f51b2406e6ee9537fdc96cadc0d3fa2194c (patch) | |
tree | 49608eb4f63bbe85ff8c730b90828430f4cf9567 | |
parent | 84d22e608b03c154d11e54ff34d7b87bf1d78cfa (diff) | |
parent | a066bcb17e1b99a42a5834d1ace6feba7c9b60b7 (diff) |
Improvements to pynslcd caching functionality
This fixes most of the existing caching functionality. Cache expiry,
negative hits and entries going away remain to be implemented.
-rw-r--r-- | configure.ac | 1 | ||||
-rw-r--r-- | pynslcd/alias.py | 41 | ||||
-rw-r--r-- | pynslcd/cache.py | 362 | ||||
-rw-r--r-- | pynslcd/ether.py | 9 | ||||
-rw-r--r-- | pynslcd/group.py | 45 | ||||
-rw-r--r-- | pynslcd/host.py | 69 | ||||
-rw-r--r-- | pynslcd/netgroup.py | 36 | ||||
-rw-r--r-- | pynslcd/network.py | 69 | ||||
-rw-r--r-- | pynslcd/passwd.py | 14 | ||||
-rw-r--r-- | pynslcd/protocol.py | 39 | ||||
-rw-r--r-- | pynslcd/rpc.py | 39 | ||||
-rw-r--r-- | pynslcd/service.py | 89 | ||||
-rw-r--r-- | pynslcd/shadow.py | 15 | ||||
-rw-r--r-- | tests/Makefile.am | 8 | ||||
-rwxr-xr-x | tests/test_pynslcd_cache.py | 459 |
15 files changed, 941 insertions, 354 deletions
diff --git a/configure.ac b/configure.ac index 9774e3a..a586173 100644 --- a/configure.ac +++ b/configure.ac @@ -61,6 +61,7 @@ AM_PROG_CC_C_O AC_USE_SYSTEM_EXTENSIONS AC_PROG_LN_S AM_PATH_PYTHON(2.5,, [:]) +AM_CONDITIONAL([HAVE_PYTHON], [test "$PYTHON" != ":"]) AM_PROG_AR # checks for tool to convert docbook to man diff --git a/pynslcd/alias.py b/pynslcd/alias.py index 46c4d6b..371ac2e 100644 --- a/pynslcd/alias.py +++ b/pynslcd/alias.py @@ -37,19 +37,40 @@ class Search(search.LDAPSearch): class Cache(cache.Cache): + tables = ('alias_cache', 'alias_member_cache') + + create_sql = ''' + CREATE TABLE IF NOT EXISTS `alias_cache` + ( `cn` TEXT PRIMARY KEY COLLATE NOCASE, + `mtime` TIMESTAMP NOT NULL ); + CREATE TABLE IF NOT EXISTS `alias_member_cache` + ( `alias` TEXT NOT NULL COLLATE NOCASE, + `rfc822MailMember` TEXT NOT NULL, + FOREIGN KEY(`alias`) REFERENCES `alias_cache`(`cn`) + ON DELETE CASCADE ON UPDATE CASCADE ); + CREATE INDEX IF NOT EXISTS `alias_member_idx` ON `alias_member_cache`(`alias`); + ''' + retrieve_sql = ''' SELECT `alias_cache`.`cn` AS `cn`, - `alias_1_cache`.`rfc822MailMember` AS `rfc822MailMember` + `alias_member_cache`.`rfc822MailMember` AS `rfc822MailMember`, + `alias_cache`.`mtime` AS `mtime` FROM `alias_cache` - LEFT JOIN `alias_1_cache` - ON `alias_1_cache`.`alias` = `alias_cache`.`cn` - ''' - - def retrieve(self, parameters): - query = cache.Query(self.retrieve_sql, parameters) - # return results, returning the members as a list - for row in cache.RowGrouper(query.execute(self.con), ('cn', ), ('rfc822MailMember', )): - yield row['cn'], row['rfc822MailMember'] + LEFT JOIN `alias_member_cache` + ON `alias_member_cache`.`alias` = `alias_cache`.`cn` + ''' + + retrieve_by = dict( + rfc822MailMember=''' + `cn` IN ( + SELECT `a`.`alias` + FROM `alias_member_cache` `a` + WHERE `a`.`rfc822MailMember` = ?) + ''', + ) + + group_by = (0, ) # cn + group_columns = (1, ) # rfc822MailMember class AliasRequest(common.Request): diff --git a/pynslcd/cache.py b/pynslcd/cache.py index 7089d41..3974fef 100644 --- a/pynslcd/cache.py +++ b/pynslcd/cache.py @@ -27,184 +27,53 @@ import sqlite3 # TODO: probably create a config table - - -# FIXME: store the cache in the right place and make it configurable -filename = '/tmp/cache.sqlite' -dirname = os.path.dirname(filename) -if not os.path.isdir(dirname): - os.mkdir(dirname) -con = sqlite3.connect(filename, - detect_types=sqlite3.PARSE_DECLTYPES, check_same_thread=False) -con.row_factory = sqlite3.Row - # FIXME: have some way to remove stale entries from the cache if all items from LDAP are queried (perhas use TTL from all request) -# set up the database -con.executescript(''' - - -- store temporary tables in memory - PRAGMA temp_store = MEMORY; - - -- disable sync() on database (corruption on disk failure) - PRAGMA synchronous = OFF; - - -- put journal in memory (corruption if crash during transaction) - PRAGMA journal_mode = MEMORY; - - -- tables for alias cache - CREATE TABLE IF NOT EXISTS `alias_cache` - ( `cn` TEXT PRIMARY KEY COLLATE NOCASE, - `mtime` TIMESTAMP NOT NULL ); - CREATE TABLE IF NOT EXISTS `alias_1_cache` - ( `alias` TEXT NOT NULL COLLATE NOCASE, - `rfc822MailMember` TEXT NOT NULL, - FOREIGN KEY(`alias`) REFERENCES `alias_cache`(`cn`) - ON DELETE CASCADE ON UPDATE CASCADE ); - CREATE INDEX IF NOT EXISTS `alias_1_idx` ON `alias_1_cache`(`alias`); - - -- table for ethernet cache - CREATE TABLE IF NOT EXISTS `ether_cache` - ( `cn` TEXT NOT NULL COLLATE NOCASE, - `macAddress` TEXT NOT NULL COLLATE NOCASE, - `mtime` TIMESTAMP NOT NULL, - UNIQUE (`cn`, `macAddress`) ); - - -- table for group cache - CREATE TABLE IF NOT EXISTS `group_cache` - ( `cn` TEXT PRIMARY KEY, - `userPassword` TEXT, - `gidNumber` INTEGER NOT NULL UNIQUE, - `mtime` TIMESTAMP NOT NULL ); - CREATE TABLE IF NOT EXISTS `group_3_cache` - ( `group` TEXT NOT NULL, - `memberUid` TEXT NOT NULL, - FOREIGN KEY(`group`) REFERENCES `group_cache`(`cn`) - ON DELETE CASCADE ON UPDATE CASCADE ); - CREATE INDEX IF NOT EXISTS `group_3_idx` ON `group_3_cache`(`group`); - - -- tables for host cache - CREATE TABLE IF NOT EXISTS `host_cache` - ( `cn` TEXT PRIMARY KEY COLLATE NOCASE, - `mtime` TIMESTAMP NOT NULL ); - CREATE TABLE IF NOT EXISTS `host_1_cache` - ( `host` TEXT NOT NULL COLLATE NOCASE, - `cn` TEXT NOT NULL COLLATE NOCASE, - FOREIGN KEY(`host`) REFERENCES `host_cache`(`cn`) - ON DELETE CASCADE ON UPDATE CASCADE ); - CREATE INDEX IF NOT EXISTS `host_1_idx` ON `host_1_cache`(`host`); - CREATE TABLE IF NOT EXISTS `host_2_cache` - ( `host` TEXT NOT NULL COLLATE NOCASE, - `ipHostNumber` TEXT NOT NULL, - FOREIGN KEY(`host`) REFERENCES `host_cache`(`cn`) - ON DELETE CASCADE ON UPDATE CASCADE ); - CREATE INDEX IF NOT EXISTS `host_2_idx` ON `host_2_cache`(`host`); - -- FIXME: this does not work as entries are never removed from the cache - CREATE TABLE IF NOT EXISTS `netgroup_cache` - ( `cn` TEXT NOT NULL, - `member` TEXT NOT NULL, - `mtime` TIMESTAMP NOT NULL, - UNIQUE (`cn`, `member`) ); +class regroup(object): - -- tables for network cache - CREATE TABLE IF NOT EXISTS `network_cache` - ( `cn` TEXT PRIMARY KEY COLLATE NOCASE, - `mtime` TIMESTAMP NOT NULL ); - CREATE TABLE IF NOT EXISTS `network_1_cache` - ( `network` TEXT NOT NULL COLLATE NOCASE, - `cn` TEXT NOT NULL COLLATE NOCASE, - FOREIGN KEY(`network`) REFERENCES `network_cache`(`cn`) - ON DELETE CASCADE ON UPDATE CASCADE ); - CREATE INDEX IF NOT EXISTS `network_1_idx` ON `network_1_cache`(`network`); - CREATE TABLE IF NOT EXISTS `network_2_cache` - ( `network` TEXT NOT NULL, - `ipNetworkNumber` TEXT NOT NULL, - FOREIGN KEY(`network`) REFERENCES `network_cache`(`cn`) - ON DELETE CASCADE ON UPDATE CASCADE ); - CREATE INDEX IF NOT EXISTS `network_2_idx` ON `network_2_cache`(`network`); + def __init__(self, results, group_by=None, group_column=None): + """Regroup the results in the group column by the key columns.""" + self.group_by = tuple(group_by) + self.group_column = group_column + self.it = iter(results) + self.tgtkey = self.currkey = self.currvalue = object() - -- table for passwd cache - CREATE TABLE IF NOT EXISTS `passwd_cache` - ( `uid` TEXT PRIMARY KEY, - `userPassword` TEXT, - `uidNumber` INTEGER NOT NULL UNIQUE, - `gidNumber` INTEGER NOT NULL, - `gecos` TEXT, - `homeDirectory` TEXT, - `loginShell` TEXT, - `mtime` TIMESTAMP NOT NULL ); - - -- table for protocol cache - CREATE TABLE IF NOT EXISTS `protocol_cache` - ( `cn` TEXT PRIMARY KEY, - `ipProtocolNumber` INTEGER NOT NULL, - `mtime` TIMESTAMP NOT NULL ); - CREATE TABLE IF NOT EXISTS `protocol_1_cache` - ( `protocol` TEXT NOT NULL, - `cn` TEXT NOT NULL, - FOREIGN KEY(`protocol`) REFERENCES `protocol_cache`(`cn`) - ON DELETE CASCADE ON UPDATE CASCADE ); - CREATE INDEX IF NOT EXISTS `protocol_1_idx` ON `protocol_1_cache`(`protocol`); - - -- table for rpc cache - CREATE TABLE IF NOT EXISTS `rpc_cache` - ( `cn` TEXT PRIMARY KEY, - `oncRpcNumber` INTEGER NOT NULL, - `mtime` TIMESTAMP NOT NULL ); - CREATE TABLE IF NOT EXISTS `rpc_1_cache` - ( `rpc` TEXT NOT NULL, - `cn` TEXT NOT NULL, - FOREIGN KEY(`rpc`) REFERENCES `rpc_cache`(`cn`) - ON DELETE CASCADE ON UPDATE CASCADE ); - CREATE INDEX IF NOT EXISTS `rpc_1_idx` ON `rpc_1_cache`(`rpc`); - - -- tables for service cache - CREATE TABLE IF NOT EXISTS `service_cache` - ( `cn` TEXT NOT NULL, - `ipServicePort` INTEGER NOT NULL, - `ipServiceProtocol` TEXT NOT NULL, - `mtime` TIMESTAMP NOT NULL, - UNIQUE (`ipServicePort`, `ipServiceProtocol`) ); - CREATE TABLE IF NOT EXISTS `service_1_cache` - ( `ipServicePort` INTEGER NOT NULL, - `ipServiceProtocol` TEXT NOT NULL, - `cn` TEXT NOT NULL, - FOREIGN KEY(`ipServicePort`) REFERENCES `service_cache`(`ipServicePort`) - ON DELETE CASCADE ON UPDATE CASCADE, - FOREIGN KEY(`ipServiceProtocol`) REFERENCES `service_cache`(`ipServiceProtocol`) - ON DELETE CASCADE ON UPDATE CASCADE ); - CREATE INDEX IF NOT EXISTS `service_1_idx1` ON `service_1_cache`(`ipServicePort`); - CREATE INDEX IF NOT EXISTS `service_1_idx2` ON `service_1_cache`(`ipServiceProtocol`); + def keyfunc(self, row): + return tuple(row[x] for x in self.group_by) - -- table for shadow cache - CREATE TABLE IF NOT EXISTS `shadow_cache` - ( `uid` TEXT PRIMARY KEY, - `userPassword` TEXT, - `shadowLastChange` INTEGER, - `shadowMin` INTEGER, - `shadowMax` INTEGER, - `shadowWarning` INTEGER, - `shadowInactive` INTEGER, - `shadowExpire` INTEGER, - `shadowFlag` INTEGER, - `mtime` TIMESTAMP NOT NULL ); + def __iter__(self): + return self - ''') + def next(self): + # find a start row + while self.currkey == self.tgtkey: + self.currvalue = next(self.it) # Exit on StopIteration + self.currkey = self.keyfunc(self.currvalue) + self.tgtkey = self.currkey + # turn the result row into a list of columns + row = list(self.currvalue) + # replace the group column + row[self.group_column] = list(self._grouper(self.tgtkey)) + return row + + def _grouper(self, tgtkey): + """Generate the group columns.""" + while self.currkey == tgtkey: + value = self.currvalue[self.group_column] + if value is not None: + yield value + self.currvalue = next(self.it) # Exit on StopIteration + self.currkey = self.keyfunc(self.currvalue) class Query(object): + """Helper class to build an SQL query for the cache.""" - def __init__(self, query, parameters=None): + def __init__(self, query): self.query = query self.wheres = [] self.parameters = [] - if parameters: - for k, v in parameters.items(): - self.add_where('`%s` = ?' % k, [v]) - - def add_query(self, query): - self.query += ' ' + query def add_where(self, where, parameters): self.wheres.append(where) @@ -214,100 +83,113 @@ class Query(object): query = self.query if self.wheres: query += ' WHERE ' + ' AND '.join(self.wheres) - c = con.cursor() - return c.execute(query, self.parameters) - - -class CnAliasedQuery(Query): - - sql = ''' - SELECT `%(table)s_cache`.*, - `%(table)s_1_cache`.`cn` AS `alias` - FROM `%(table)s_cache` - LEFT JOIN `%(table)s_1_cache` - ON `%(table)s_1_cache`.`%(table)s` = `%(table)s_cache`.`cn` - ''' - - cn_join = ''' - LEFT JOIN `%(table)s_1_cache` `cn_alias` - ON `cn_alias`.`%(table)s` = `%(table)s_cache`.`cn` - ''' - - def __init__(self, table, parameters): - args = dict(table=table) - super(CnAliasedQuery, self).__init__(self.sql % args) - for k, v in parameters.items(): - if k == 'cn': - self.add_query(self.cn_join % args) - self.add_where('(`%(table)s_cache`.`cn` = ? OR `cn_alias`.`cn` = ?)' % args, [v, v]) - else: - self.add_where('`%s` = ?' % k, [v]) - - -class RowGrouper(object): - """Pass in query results and group the results by a certain specified - list of columns.""" - - def __init__(self, results, groupby, columns): - self.groupby = groupby - self.columns = columns - self.results = itertools.groupby(results, key=self.keyfunc) - - def __iter__(self): - return self - - def keyfunc(self, row): - return tuple(row[x] for x in self.groupby) - - def next(self): - groupcols, rows = self.results.next() - tmp = dict((x, list()) for x in self.columns) - for row in rows: - for col in self.columns: - if row[col] is not None: - tmp[col].append(row[col]) - result = dict(row) - result.update(tmp) - return result + cursor = con.cursor() + return cursor.execute(query, self.parameters) class Cache(object): + """The description of the cache.""" + + retrieve_sql = None + retrieve_by = dict() + group_by = () + group_columns = () def __init__(self): - self.con = con - self.table = sys.modules[self.__module__].__name__ + self.con = _get_connection() + self.db = sys.modules[self.__module__].__name__ + if not hasattr(self, 'tables'): + self.tables = ['%s_cache' % self.db] + self.create() + + def create(self): + """Create the needed tables if neccesary.""" + self.con.executescript(self.create_sql) def store(self, *values): - """Store the values in the cache for the specified table.""" + """Store the values in the cache for the specified table. + The order of the values is the order returned by the Reques.convert() + function.""" + # split the values into simple (flat) values and one-to-many values simple_values = [] - multi_values = {} - for n, v in enumerate(values): + multi_values = [] + for v in values: if isinstance(v, (list, tuple, set)): - multi_values[n] = v + multi_values.append(v) else: simple_values.append(v) + # insert the simple values simple_values.append(datetime.datetime.now()) args = ', '.join(len(simple_values) * ('?', )) - con.execute(''' - INSERT OR REPLACE INTO %s_cache + self.con.execute(''' + INSERT OR REPLACE INTO %s VALUES (%s) - ''' % (self.table, args), simple_values) - for n, vlist in multi_values.items(): - con.execute(''' - DELETE FROM %s_%d_cache + ''' % (self.tables[0], args), simple_values) + # insert the one-to-many values + for n, vlist in enumerate(multi_values): + self.con.execute(''' + DELETE FROM %s WHERE `%s` = ? - ''' % (self.table, n, self.table), (values[0], )) - con.executemany(''' - INSERT INTO %s_%d_cache + ''' % (self.tables[n + 1], self.db), (values[0], )) + self.con.executemany(''' + INSERT INTO %s VALUES (?, ?) - ''' % (self.table, n), ((values[0], x) for x in vlist)) + ''' % (self.tables[n + 1]), ((values[0], x) for x in vlist)) def retrieve(self, parameters): - """Retrieve all items from the cache based on the parameters supplied.""" - query = Query(''' + """Retrieve all items from the cache based on the parameters + supplied.""" + query = Query(self.retrieve_sql or ''' SELECT * - FROM %s_cache - ''' % self.table, parameters) - return (list(x)[:-1] for x in query.execute(self.con)) + FROM %s + ''' % self.tables[0]) + if parameters: + for k, v in parameters.items(): + where = self.retrieve_by.get(k, '`%s`.`%s` = ?' % (self.tables[0], k)) + query.add_where(where, where.count('?') * [v]) + # group by + # FIXME: find a nice way to turn group_by and group_columns into names + results = query.execute(self.con) + group_by = list(self.group_by + self.group_columns) + for column in self.group_columns[::-1]: + group_by.pop() + results = regroup(results, group_by, column) + # strip the mtime from the results + return (list(x)[:-1] for x in results) + + def __enter__(self): + return self.con.__enter__(); + + def __exit__(self, *args): + return self.con.__exit__(*args); + + +# the connection to the sqlite database +_connection = None + + +# FIXME: make tread safe (is this needed the way the caches are initialised?) +def _get_connection(): + global _connection + if _connection is None: + filename = '/tmp/pynslcd_cache.sqlite' + dirname = os.path.dirname(filename) + if not os.path.isdir(dirname): + os.mkdir(dirname) + connection = sqlite3.connect( + filename, detect_types=sqlite3.PARSE_DECLTYPES, + check_same_thread=False) + connection.row_factory = sqlite3.Row + # initialise connection properties + connection.executescript(''' + -- store temporary tables in memory + PRAGMA temp_store = MEMORY; + -- disable sync() on database (corruption on disk failure) + PRAGMA synchronous = OFF; + -- put journal in memory (corruption if crash during transaction) + PRAGMA journal_mode = MEMORY; + ''') + _connection = connection + return _connection diff --git a/pynslcd/ether.py b/pynslcd/ether.py index d5d8c06..e5060ca 100644 --- a/pynslcd/ether.py +++ b/pynslcd/ether.py @@ -59,7 +59,14 @@ class Search(search.LDAPSearch): class Cache(cache.Cache): - pass + + create_sql = ''' + CREATE TABLE IF NOT EXISTS `ether_cache` + ( `cn` TEXT NOT NULL COLLATE NOCASE, + `macAddress` TEXT NOT NULL COLLATE NOCASE, + `mtime` TIMESTAMP NOT NULL, + UNIQUE (`cn`, `macAddress`) ); + ''' class EtherRequest(common.Request): diff --git a/pynslcd/group.py b/pynslcd/group.py index 2868d96..10e3423 100644 --- a/pynslcd/group.py +++ b/pynslcd/group.py @@ -75,20 +75,41 @@ class Search(search.LDAPSearch): class Cache(cache.Cache): + tables = ('group_cache', 'group_member_cache') + + create_sql = ''' + CREATE TABLE IF NOT EXISTS `group_cache` + ( `cn` TEXT PRIMARY KEY, + `userPassword` TEXT, + `gidNumber` INTEGER NOT NULL UNIQUE, + `mtime` TIMESTAMP NOT NULL ); + CREATE TABLE IF NOT EXISTS `group_member_cache` + ( `group` TEXT NOT NULL, + `memberUid` TEXT NOT NULL, + FOREIGN KEY(`group`) REFERENCES `group_cache`(`cn`) + ON DELETE CASCADE ON UPDATE CASCADE ); + CREATE INDEX IF NOT EXISTS `group_member_idx` ON `group_member_cache`(`group`); + ''' + retrieve_sql = ''' - SELECT `cn`, `userPassword`, `gidNumber`, `memberUid` + SELECT `group_cache`.`cn` AS `cn`, `userPassword`, `gidNumber`, + `memberUid`, `mtime` FROM `group_cache` - LEFT JOIN `group_3_cache` - ON `group_3_cache`.`group` = `group_cache`.`cn` - ''' - - def retrieve(self, parameters): - query = cache.Query(self.retrieve_sql, parameters) - # return results returning the members as a set - q = itertools.groupby(query.execute(self.con), - key=lambda x: (x['cn'], x['userPassword'], x['gidNumber'])) - for k, v in q: - yield k + (set(x['memberUid'] for x in v if x['memberUid'] is not None), ) + LEFT JOIN `group_member_cache` + ON `group_member_cache`.`group` = `group_cache`.`cn` + ''' + + retrieve_by = dict( + memberUid=''' + `cn` IN ( + SELECT `a`.`group` + FROM `group_member_cache` `a` + WHERE `a`.`memberUid` = ?) + ''', + ) + + group_by = (0, ) # cn + group_columns = (3, ) # memberUid class GroupRequest(common.Request): diff --git a/pynslcd/host.py b/pynslcd/host.py index ffd9588..04f5337 100644 --- a/pynslcd/host.py +++ b/pynslcd/host.py @@ -34,29 +34,58 @@ class Search(search.LDAPSearch): required = ('cn', ) -class HostQuery(cache.CnAliasedQuery): +class Cache(cache.Cache): - sql = ''' + tables = ('host_cache', 'host_alias_cache', 'host_address_cache') + + create_sql = ''' + CREATE TABLE IF NOT EXISTS `host_cache` + ( `cn` TEXT PRIMARY KEY COLLATE NOCASE, + `mtime` TIMESTAMP NOT NULL ); + CREATE TABLE IF NOT EXISTS `host_alias_cache` + ( `host` TEXT NOT NULL COLLATE NOCASE, + `cn` TEXT NOT NULL COLLATE NOCASE, + FOREIGN KEY(`host`) REFERENCES `host_cache`(`cn`) + ON DELETE CASCADE ON UPDATE CASCADE ); + CREATE INDEX IF NOT EXISTS `host_alias_idx` ON `host_alias_cache`(`host`); + CREATE TABLE IF NOT EXISTS `host_address_cache` + ( `host` TEXT NOT NULL COLLATE NOCASE, + `ipHostNumber` TEXT NOT NULL, + FOREIGN KEY(`host`) REFERENCES `host_cache`(`cn`) + ON DELETE CASCADE ON UPDATE CASCADE ); + CREATE INDEX IF NOT EXISTS `host_address_idx` ON `host_address_cache`(`host`); + ''' + + retrieve_sql = ''' SELECT `host_cache`.`cn` AS `cn`, - `host_1_cache`.`cn` AS `alias`, - `host_2_cache`.`ipHostNumber` AS `ipHostNumber` + `host_alias_cache`.`cn` AS `alias`, + `host_address_cache`.`ipHostNumber` AS `ipHostNumber`, + `host_cache`.`mtime` AS `mtime` FROM `host_cache` - LEFT JOIN `host_1_cache` - ON `host_1_cache`.`host` = `host_cache`.`cn` - LEFT JOIN `host_2_cache` - ON `host_2_cache`.`host` = `host_cache`.`cn` - ''' - - def __init__(self, parameters): - super(HostQuery, self).__init__('host', parameters) - - -class Cache(cache.Cache): - - def retrieve(self, parameters): - query = HostQuery(parameters) - for row in cache.RowGrouper(query.execute(self.con), ('cn', ), ('alias', 'ipHostNumber', )): - yield row['cn'], row['alias'], row['ipHostNumber'] + LEFT JOIN `host_alias_cache` + ON `host_alias_cache`.`host` = `host_cache`.`cn` + LEFT JOIN `host_address_cache` + ON `host_address_cache`.`host` = `host_cache`.`cn` + ''' + + retrieve_by = dict( + cn=''' + ( `host_cache`.`cn` = ? OR + `host_cache`.`cn` IN ( + SELECT `by_alias`.`host` + FROM `host_alias_cache` `by_alias` + WHERE `by_alias`.`cn` = ?)) + ''', + ipHostNumber=''' + `host_cache`.`cn` IN ( + SELECT `by_ipHostNumber`.`host` + FROM `host_address_cache` `by_ipHostNumber` + WHERE `by_ipHostNumber`.`ipHostNumber` = ?) + ''', + ) + + group_by = (0, ) # cn + group_columns = (1, 2) # alias, ipHostNumber class HostRequest(common.Request): diff --git a/pynslcd/netgroup.py b/pynslcd/netgroup.py index 20f8779..d86e38c 100644 --- a/pynslcd/netgroup.py +++ b/pynslcd/netgroup.py @@ -42,7 +42,41 @@ class Search(search.LDAPSearch): class Cache(cache.Cache): - pass + + tables = ('netgroup_cache', 'netgroup_triple_cache', 'netgroup_member_cache') + + create_sql = ''' + CREATE TABLE IF NOT EXISTS `netgroup_cache` + ( `cn` TEXT PRIMARY KEY COLLATE NOCASE, + `mtime` TIMESTAMP NOT NULL ); + CREATE TABLE IF NOT EXISTS `netgroup_triple_cache` + ( `netgroup` TEXT NOT NULL COLLATE NOCASE, + `nisNetgroupTriple` TEXT NOT NULL COLLATE NOCASE, + FOREIGN KEY(`netgroup`) REFERENCES `netgroup_cache`(`cn`) + ON DELETE CASCADE ON UPDATE CASCADE ); + CREATE INDEX IF NOT EXISTS `netgroup_triple_idx` ON `netgroup_triple_cache`(`netgroup`); + CREATE TABLE IF NOT EXISTS `netgroup_member_cache` + ( `netgroup` TEXT NOT NULL COLLATE NOCASE, + `memberNisNetgroup` TEXT NOT NULL, + FOREIGN KEY(`netgroup`) REFERENCES `netgroup_cache`(`cn`) + ON DELETE CASCADE ON UPDATE CASCADE ); + CREATE INDEX IF NOT EXISTS `netgroup_membe_idx` ON `netgroup_member_cache`(`netgroup`); + ''' + + retrieve_sql = ''' + SELECT `netgroup_cache`.`cn` AS `cn`, + `netgroup_triple_cache`.`nisNetgroupTriple` AS `nisNetgroupTriple`, + `netgroup_member_cache`.`memberNisNetgroup` AS `memberNisNetgroup`, + `netgroup_cache`.`mtime` AS `mtime` + FROM `netgroup_cache` + LEFT JOIN `netgroup_triple_cache` + ON `netgroup_triple_cache`.`netgroup` = `netgroup_cache`.`cn` + LEFT JOIN `netgroup_member_cache` + ON `netgroup_member_cache`.`netgroup` = `netgroup_cache`.`cn` + ''' + + group_by = (0, ) # cn + group_columns = (1, 2) # nisNetgroupTriple, memberNisNetgroup class NetgroupRequest(common.Request): diff --git a/pynslcd/network.py b/pynslcd/network.py index dc91d68..01bf6c2 100644 --- a/pynslcd/network.py +++ b/pynslcd/network.py @@ -35,29 +35,58 @@ class Search(search.LDAPSearch): required = ('cn', ) -class NetworkQuery(cache.CnAliasedQuery): +class Cache(cache.Cache): - sql = ''' + tables = ('network_cache', 'network_alias_cache', 'network_address_cache') + + create_sql = ''' + CREATE TABLE IF NOT EXISTS `network_cache` + ( `cn` TEXT PRIMARY KEY COLLATE NOCASE, + `mtime` TIMESTAMP NOT NULL ); + CREATE TABLE IF NOT EXISTS `network_alias_cache` + ( `network` TEXT NOT NULL COLLATE NOCASE, + `cn` TEXT NOT NULL COLLATE NOCASE, + FOREIGN KEY(`network`) REFERENCES `network_cache`(`cn`) + ON DELETE CASCADE ON UPDATE CASCADE ); + CREATE INDEX IF NOT EXISTS `network_alias_idx` ON `network_alias_cache`(`network`); + CREATE TABLE IF NOT EXISTS `network_address_cache` + ( `network` TEXT NOT NULL COLLATE NOCASE, + `ipNetworkNumber` TEXT NOT NULL, + FOREIGN KEY(`network`) REFERENCES `network_cache`(`cn`) + ON DELETE CASCADE ON UPDATE CASCADE ); + CREATE INDEX IF NOT EXISTS `network_address_idx` ON `network_address_cache`(`network`); + ''' + + retrieve_sql = ''' SELECT `network_cache`.`cn` AS `cn`, - `network_1_cache`.`cn` AS `alias`, - `network_2_cache`.`ipNetworkNumber` AS `ipNetworkNumber` + `network_alias_cache`.`cn` AS `alias`, + `network_address_cache`.`ipNetworkNumber` AS `ipNetworkNumber`, + `network_cache`.`mtime` AS `mtime` FROM `network_cache` - LEFT JOIN `network_1_cache` - ON `network_1_cache`.`network` = `network_cache`.`cn` - LEFT JOIN `network_2_cache` - ON `network_2_cache`.`network` = `network_cache`.`cn` - ''' - - def __init__(self, parameters): - super(NetworkQuery, self).__init__('network', parameters) - - -class Cache(cache.Cache): - - def retrieve(self, parameters): - query = NetworkQuery(parameters) - for row in cache.RowGrouper(query.execute(self.con), ('cn', ), ('alias', 'ipNetworkNumber', )): - yield row['cn'], row['alias'], row['ipNetworkNumber'] + LEFT JOIN `network_alias_cache` + ON `network_alias_cache`.`network` = `network_cache`.`cn` + LEFT JOIN `network_address_cache` + ON `network_address_cache`.`network` = `network_cache`.`cn` + ''' + + retrieve_by = dict( + cn=''' + ( `network_cache`.`cn` = ? OR + `network_cache`.`cn` IN ( + SELECT `by_alias`.`network` + FROM `network_alias_cache` `by_alias` + WHERE `by_alias`.`cn` = ?)) + ''', + ipNetworkNumber=''' + `network_cache`.`cn` IN ( + SELECT `by_ipNetworkNumber`.`network` + FROM `network_address_cache` `by_ipNetworkNumber` + WHERE `by_ipNetworkNumber`.`ipNetworkNumber` = ?) + ''', + ) + + group_by = (0, ) # cn + group_columns = (1, 2) # alias, ipNetworkNumber class NetworkRequest(common.Request): diff --git a/pynslcd/passwd.py b/pynslcd/passwd.py index 7504961..a8e407f 100644 --- a/pynslcd/passwd.py +++ b/pynslcd/passwd.py @@ -47,7 +47,18 @@ class Search(search.LDAPSearch): class Cache(cache.Cache): - pass + + create_sql = ''' + CREATE TABLE IF NOT EXISTS `passwd_cache` + ( `uid` TEXT PRIMARY KEY, + `userPassword` TEXT, + `uidNumber` INTEGER NOT NULL UNIQUE, + `gidNumber` INTEGER NOT NULL, + `gecos` TEXT, + `homeDirectory` TEXT, + `loginShell` TEXT, + `mtime` TIMESTAMP NOT NULL ); + ''' class PasswdRequest(common.Request): @@ -106,7 +117,6 @@ class PasswdByUidRequest(PasswdRequest): self.fp.write_int32(constants.NSLCD_RESULT_END) - class PasswdAllRequest(PasswdRequest): action = constants.NSLCD_ACTION_PASSWD_ALL diff --git a/pynslcd/protocol.py b/pynslcd/protocol.py index cafda9d..1472c04 100644 --- a/pynslcd/protocol.py +++ b/pynslcd/protocol.py @@ -37,10 +37,41 @@ class Search(search.LDAPSearch): class Cache(cache.Cache): - def retrieve(self, parameters): - query = cache.CnAliasedQuery('protocol', parameters) - for row in cache.RowGrouper(query.execute(self.con), ('cn', ), ('alias', )): - yield row['cn'], row['alias'], row['ipProtocolNumber'] + tables = ('protocol_cache', 'protocol_alias_cache') + + create_sql = ''' + CREATE TABLE IF NOT EXISTS `protocol_cache` + ( `cn` TEXT PRIMARY KEY, + `ipProtocolNumber` INTEGER NOT NULL, + `mtime` TIMESTAMP NOT NULL ); + CREATE TABLE IF NOT EXISTS `protocol_alias_cache` + ( `protocol` TEXT NOT NULL, + `cn` TEXT NOT NULL, + FOREIGN KEY(`protocol`) REFERENCES `protocol_cache`(`cn`) + ON DELETE CASCADE ON UPDATE CASCADE ); + CREATE INDEX IF NOT EXISTS `protocol_alias_idx` ON `protocol_alias_cache`(`protocol`); + ''' + + retrieve_sql = ''' + SELECT `protocol_cache`.`cn` AS `cn`, `protocol_alias_cache`.`cn` AS `alias`, + `ipProtocolNumber`, `mtime` + FROM `protocol_cache` + LEFT JOIN `protocol_alias_cache` + ON `protocol_alias_cache`.`protocol` = `protocol_cache`.`cn` + ''' + + retrieve_by = dict( + cn=''' + ( `protocol_cache`.`cn` = ? OR + `protocol_cache`.`cn` IN ( + SELECT `by_alias`.`protocol` + FROM `protocol_alias_cache` `by_alias` + WHERE `by_alias`.`cn` = ?)) + ''', + ) + + group_by = (0, ) # cn + group_columns = (1, ) # alias class ProtocolRequest(common.Request): diff --git a/pynslcd/rpc.py b/pynslcd/rpc.py index f20960e..2a241fd 100644 --- a/pynslcd/rpc.py +++ b/pynslcd/rpc.py @@ -37,10 +37,41 @@ class Search(search.LDAPSearch): class Cache(cache.Cache): - def retrieve(self, parameters): - query = cache.CnAliasedQuery('rpc', parameters) - for row in cache.RowGrouper(query.execute(self.con), ('cn', ), ('alias', )): - yield row['cn'], row['alias'], row['oncRpcNumber'] + tables = ('rpc_cache', 'rpc_alias_cache') + + create_sql = ''' + CREATE TABLE IF NOT EXISTS `rpc_cache` + ( `cn` TEXT PRIMARY KEY, + `oncRpcNumber` INTEGER NOT NULL, + `mtime` TIMESTAMP NOT NULL ); + CREATE TABLE IF NOT EXISTS `rpc_alias_cache` + ( `rpc` TEXT NOT NULL, + `cn` TEXT NOT NULL, + FOREIGN KEY(`rpc`) REFERENCES `rpc_cache`(`cn`) + ON DELETE CASCADE ON UPDATE CASCADE ); + CREATE INDEX IF NOT EXISTS `rpc_alias_idx` ON `rpc_alias_cache`(`rpc`); + ''' + + retrieve_sql = ''' + SELECT `rpc_cache`.`cn` AS `cn`, `rpc_alias_cache`.`cn` AS `alias`, + `oncRpcNumber`, `mtime` + FROM `rpc_cache` + LEFT JOIN `rpc_alias_cache` + ON `rpc_alias_cache`.`rpc` = `rpc_cache`.`cn` + ''' + + retrieve_by = dict( + cn=''' + ( `rpc_cache`.`cn` = ? OR + `rpc_cache`.`cn` IN ( + SELECT `by_alias`.`rpc` + FROM `rpc_alias_cache` `by_alias` + WHERE `by_alias`.`cn` = ?)) + ''', + ) + + group_by = (0, ) # cn + group_columns = (1, ) # alias class RpcRequest(common.Request): diff --git a/pynslcd/service.py b/pynslcd/service.py index 19f941d..c27f485 100644 --- a/pynslcd/service.py +++ b/pynslcd/service.py @@ -40,56 +40,73 @@ class Search(search.LDAPSearch): required = ('cn', 'ipServicePort', 'ipServiceProtocol') -class ServiceQuery(cache.CnAliasedQuery): +class Cache(cache.Cache): - sql = ''' - SELECT `service_cache`.*, - `service_1_cache`.`cn` AS `alias` + tables = ('service_cache', 'service_alias_cache') + + create_sql = ''' + CREATE TABLE IF NOT EXISTS `service_cache` + ( `cn` TEXT NOT NULL, + `ipServicePort` INTEGER NOT NULL, + `ipServiceProtocol` TEXT NOT NULL, + `mtime` TIMESTAMP NOT NULL, + UNIQUE (`ipServicePort`, `ipServiceProtocol`) ); + CREATE TABLE IF NOT EXISTS `service_alias_cache` + ( `ipServicePort` INTEGER NOT NULL, + `ipServiceProtocol` TEXT NOT NULL, + `cn` TEXT NOT NULL, + FOREIGN KEY(`ipServicePort`) REFERENCES `service_cache`(`ipServicePort`) + ON DELETE CASCADE ON UPDATE CASCADE, + FOREIGN KEY(`ipServiceProtocol`) REFERENCES `service_cache`(`ipServiceProtocol`) + ON DELETE CASCADE ON UPDATE CASCADE ); + CREATE INDEX IF NOT EXISTS `service_alias_idx1` ON `service_alias_cache`(`ipServicePort`); + CREATE INDEX IF NOT EXISTS `service_alias_idx2` ON `service_alias_cache`(`ipServiceProtocol`); + ''' + + retrieve_sql = ''' + SELECT `service_cache`.`cn` AS `cn`, + `service_alias_cache`.`cn` AS `alias`, + `service_cache`.`ipServicePort`, + `service_cache`.`ipServiceProtocol`, + `mtime` FROM `service_cache` - LEFT JOIN `service_1_cache` - ON `service_1_cache`.`ipServicePort` = `service_cache`.`ipServicePort` - AND `service_1_cache`.`ipServiceProtocol` = `service_cache`.`ipServiceProtocol` - ''' - - cn_join = ''' - LEFT JOIN `service_1_cache` `cn_alias` - ON `cn_alias`.`ipServicePort` = `service_cache`.`ipServicePort` - AND `cn_alias`.`ipServiceProtocol` = `service_cache`.`ipServiceProtocol` - ''' - - def __init__(self, parameters): - super(ServiceQuery, self).__init__('service', {}) - for k, v in parameters.items(): - if k == 'cn': - self.add_query(self.cn_join) - self.add_where('(`service_cache`.`cn` = ? OR `cn_alias`.`cn` = ?)', [v, v]) - else: - self.add_where('`service_cache`.`%s` = ?' % k, [v]) - - -class Cache(cache.Cache): + LEFT JOIN `service_alias_cache` + ON `service_alias_cache`.`ipServicePort` = `service_cache`.`ipServicePort` + AND `service_alias_cache`.`ipServiceProtocol` = `service_cache`.`ipServiceProtocol` + ''' + + retrieve_by = dict( + cn=''' + ( `service_cache`.`cn` = ? OR + 0 < ( + SELECT COUNT(*) + FROM `service_alias_cache` `by_alias` + WHERE `by_alias`.`cn` = ? + AND `by_alias`.`ipServicePort` = `service_cache`.`ipServicePort` + AND `by_alias`.`ipServiceProtocol` = `service_cache`.`ipServiceProtocol` + )) + ''', + ) + + group_by = (0, 2, 3) # cn, ipServicePort, ipServiceProtocol + group_columns = (1, ) # alias def store(self, name, aliases, port, protocol): self.con.execute(''' INSERT OR REPLACE INTO `service_cache` VALUES (?, ?, ?, ?) - ''', (name, port, protocol, datetime.datetime.now())) + ''', (name, port, protocol, datetime.datetime.now())) self.con.execute(''' - DELETE FROM `service_1_cache` + DELETE FROM `service_alias_cache` WHERE `ipServicePort` = ? AND `ipServiceProtocol` = ? - ''', (port, protocol)) + ''', (port, protocol)) self.con.executemany(''' - INSERT INTO `service_1_cache` + INSERT INTO `service_alias_cache` VALUES (?, ?, ?) - ''', ((port, protocol, alias) for alias in aliases)) - - def retrieve(self, parameters): - query = ServiceQuery(parameters) - for row in cache.RowGrouper(query.execute(self.con), ('cn', 'ipServicePort', 'ipServiceProtocol'), ('alias', )): - yield row['cn'], row['alias'], row['ipServicePort'], row['ipServiceProtocol'] + ''', ((port, protocol, alias) for alias in aliases)) class ServiceRequest(common.Request): diff --git a/pynslcd/shadow.py b/pynslcd/shadow.py index bedac50..5fd0aa9 100644 --- a/pynslcd/shadow.py +++ b/pynslcd/shadow.py @@ -44,7 +44,20 @@ class Search(search.LDAPSearch): class Cache(cache.Cache): - pass + + create_sql = ''' + CREATE TABLE IF NOT EXISTS `shadow_cache` + ( `uid` TEXT PRIMARY KEY, + `userPassword` TEXT, + `shadowLastChange` INTEGER, + `shadowMin` INTEGER, + `shadowMax` INTEGER, + `shadowWarning` INTEGER, + `shadowInactive` INTEGER, + `shadowExpire` INTEGER, + `shadowFlag` INTEGER, + `mtime` TIMESTAMP NOT NULL ); + ''' class ShadowRequest(common.Request): diff --git a/tests/Makefile.am b/tests/Makefile.am index 184b79e..3c6bfce 100644 --- a/tests/Makefile.am +++ b/tests/Makefile.am @@ -19,8 +19,10 @@ # 02110-1301 USA TESTS = test_dict test_set test_tio test_expr test_getpeercred test_cfg \ - test_myldap.sh test_common test_nsscmds.sh test_pamcmds.sh \ - test_pycompile.sh + test_myldap.sh test_common test_nsscmds.sh test_pamcmds.sh +if HAVE_PYTHON +TESTS += test_pycompile.sh test_pynslcd_cache.py +endif AM_TESTS_ENVIRONMENT = PYTHON='@PYTHON@'; export PYTHON; @@ -30,7 +32,7 @@ check_PROGRAMS = test_dict test_set test_tio test_expr test_getpeercred \ EXTRA_DIST = nslcd-test.conf test_myldap.sh test_nsscmds.sh test_pamcmds.sh \ test_pycompile.sh in_testenv.sh test_pamcmds.expect \ - usernames.txt + usernames.txt test_pynslcd_cache.py CLEANFILES = $(EXTRA_PROGRAMS) diff --git a/tests/test_pynslcd_cache.py b/tests/test_pynslcd_cache.py new file mode 100755 index 0000000..5c15b01 --- /dev/null +++ b/tests/test_pynslcd_cache.py @@ -0,0 +1,459 @@ +#!/usr/bin/env python + +# test_pynslcd_cache.py - tests for the pynslcd caching functionality +# +# Copyright (C) 2013 Arthur de Jong +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with this library; if not, write to the Free Software +# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA +# 02110-1301 USA + +import os +import os.path +import sys +import unittest + +# fix the Python path +sys.path.insert(1, os.path.abspath(os.path.join(sys.path[0], '..', 'pynslcd'))) +sys.path.insert(2, os.path.abspath(os.path.join('..', 'pynslcd'))) + + +# TODO: think about case-sesitivity of cache searches (have tests for that) + + +class TestAlias(unittest.TestCase): + + def setUp(self): + import alias + cache = alias.Cache() + cache.store('alias1', ['member1', 'member2']) + cache.store('alias2', ['member1', 'member3']) + cache.store('alias3', []) + self.cache = cache + + def test_by_name(self): + self.assertItemsEqual(self.cache.retrieve(dict(cn='alias1')), [ + ['alias1', ['member1', 'member2']], + ]) + + def test_by_member(self): + self.assertItemsEqual(self.cache.retrieve(dict(rfc822MailMember='member1')), [ + ['alias1', ['member1', 'member2']], + ['alias2', ['member1', 'member3']], + ]) + + def test_all(self): + self.assertItemsEqual(self.cache.retrieve({}), [ + ['alias1', ['member1', 'member2']], + ['alias2', ['member1', 'member3']], + ['alias3', []], + ]) + + +class TestEther(unittest.TestCase): + + def setUp(self): + import ether + cache = ether.Cache() + cache.store('name1', '0:18:8a:54:1a:11') + cache.store('name2', '0:18:8a:54:1a:22') + self.cache = cache + + def test_by_name(self): + self.assertItemsEqual(self.cache.retrieve(dict(cn='name1')), [ + ['name1', '0:18:8a:54:1a:11'], + ]) + self.assertItemsEqual(self.cache.retrieve(dict(cn='name2')), [ + ['name2', '0:18:8a:54:1a:22'], + ]) + + def test_by_ether(self): + # ideally we should also support alternate representations + self.assertItemsEqual(self.cache.retrieve(dict(macAddress='0:18:8a:54:1a:22')), [ + ['name2', '0:18:8a:54:1a:22'], + ]) + + def test_all(self): + self.assertItemsEqual(self.cache.retrieve({}), [ + ['name1', '0:18:8a:54:1a:11'], + ['name2', '0:18:8a:54:1a:22'], + ]) + + +class TestGroup(unittest.TestCase): + + def setUp(self): + import group + cache = group.Cache() + cache.store('group1', 'pass1', 10, ['user1', 'user2']) + cache.store('group2', 'pass2', 20, ['user1', 'user2', 'user3']) + cache.store('group3', 'pass3', 30, []) + cache.store('group4', 'pass4', 40, ['user2', ]) + self.cache = cache + + def test_by_name(self): + self.assertItemsEqual(self.cache.retrieve(dict(cn='group1')), [ + ['group1', 'pass1', 10, ['user1', 'user2']], + ]) + self.assertItemsEqual(self.cache.retrieve(dict(cn='group3')), [ + ['group3', 'pass3', 30, []], + ]) + + def test_by_gid(self): + self.assertItemsEqual(self.cache.retrieve(dict(gidNumber=10)), [ + ['group1', 'pass1', 10, ['user1', 'user2']], + ]) + self.assertItemsEqual(self.cache.retrieve(dict(gidNumber=40)), [ + ['group4', 'pass4', 40, ['user2']], + ]) + + def test_all(self): + self.assertItemsEqual(self.cache.retrieve({}), [ + ['group1', 'pass1', 10, ['user1', 'user2']], + ['group2', 'pass2', 20, ['user1', 'user2', 'user3']], + ['group3', 'pass3', 30, []], + ['group4', 'pass4', 40, ['user2']], + ]) + + def test_bymember(self): + self.assertItemsEqual(self.cache.retrieve(dict(memberUid='user1')), [ + ['group1', 'pass1', 10, ['user1', 'user2']], + ['group2', 'pass2', 20, ['user1', 'user2', 'user3']], + ]) + self.assertItemsEqual(self.cache.retrieve(dict(memberUid='user2')), [ + ['group1', 'pass1', 10, ['user1', 'user2']], + ['group2', 'pass2', 20, ['user1', 'user2', 'user3']], + ['group4', 'pass4', 40, ['user2']], + ]) + self.assertItemsEqual(self.cache.retrieve(dict(memberUid='user3')), [ + ['group2', 'pass2', 20, ['user1', 'user2', 'user3']], + ]) + + +class TestHost(unittest.TestCase): + + def setUp(self): + import host + cache = host.Cache() + cache.store('hostname1', [], ['127.0.0.1', ]) + cache.store('hostname2', ['alias1', 'alias2'], ['127.0.0.2', '127.0.0.3']) + self.cache = cache + + def test_by_name(self): + self.assertItemsEqual(self.cache.retrieve(dict(cn='hostname1')), [ + ['hostname1', [], ['127.0.0.1']], + ]) + self.assertItemsEqual(self.cache.retrieve(dict(cn='hostname2')), [ + ['hostname2', ['alias1', 'alias2'], ['127.0.0.2', '127.0.0.3']], + ]) + + def test_by_alias(self): + self.assertItemsEqual(self.cache.retrieve(dict(cn='alias1')), [ + ['hostname2', ['alias1', 'alias2'], ['127.0.0.2', '127.0.0.3']], + ]) + self.assertItemsEqual(self.cache.retrieve(dict(cn='alias2')), [ + ['hostname2', ['alias1', 'alias2'], ['127.0.0.2', '127.0.0.3']], + ]) + + def test_by_address(self): + self.assertItemsEqual(self.cache.retrieve(dict(ipHostNumber='127.0.0.3')), [ + ['hostname2', ['alias1', 'alias2'], ['127.0.0.2', '127.0.0.3']], + ]) + + +class TestNetgroup(unittest.TestCase): + + def setUp(self): + import netgroup + cache = netgroup.Cache() + cache.store('netgroup1', ['(host1, user1,)', '(host1, user2,)', '(host2, user1,)'], ['netgroup2', ]) + cache.store('netgroup2', ['(host3, user1,)', '(host3, user3,)'], []) + self.cache = cache + + def test_by_name(self): + self.assertItemsEqual(self.cache.retrieve(dict(cn='netgroup1')), [ + ['netgroup1', ['(host1, user1,)', '(host1, user2,)', '(host2, user1,)'], ['netgroup2', ]], + ]) + self.assertItemsEqual(self.cache.retrieve(dict(cn='netgroup2')), [ + ['netgroup2', ['(host3, user1,)', '(host3, user3,)'], []], + ]) + + +class TestNetwork(unittest.TestCase): + + def setUp(self): + import network + cache = network.Cache() + cache.store('networkname1', [], ['127.0.0.1', ]) + cache.store('networkname2', ['alias1', 'alias2'], ['127.0.0.2', '127.0.0.3']) + self.cache = cache + + def test_by_name(self): + self.assertItemsEqual(self.cache.retrieve(dict(cn='networkname1')), [ + ['networkname1', [], ['127.0.0.1']], + ]) + self.assertItemsEqual(self.cache.retrieve(dict(cn='networkname2')), [ + ['networkname2', ['alias1', 'alias2'], ['127.0.0.2', '127.0.0.3']], + ]) + + def test_by_alias(self): + self.assertItemsEqual(self.cache.retrieve(dict(cn='alias1')), [ + ['networkname2', ['alias1', 'alias2'], ['127.0.0.2', '127.0.0.3']], + ]) + self.assertItemsEqual(self.cache.retrieve(dict(cn='alias2')), [ + ['networkname2', ['alias1', 'alias2'], ['127.0.0.2', '127.0.0.3']], + ]) + + def test_by_address(self): + self.assertItemsEqual(self.cache.retrieve(dict(ipNetworkNumber='127.0.0.3')), [ + ['networkname2', ['alias1', 'alias2'], ['127.0.0.2', '127.0.0.3']], + ]) + + +class TestPasswd(unittest.TestCase): + + def setUp(self): + import passwd + cache = passwd.Cache() + cache.store('name', 'passwd', 100, 200, 'gecos', '/home/user', '/bin/bash') + cache.store('name2', 'passwd2', 101, 202, 'gecos2', '/home/user2', '/bin/bash') + self.cache = cache + + def test_by_name(self): + self.assertItemsEqual(self.cache.retrieve(dict(uid='name')), [ + [u'name', u'passwd', 100, 200, u'gecos', u'/home/user', u'/bin/bash'], + ]) + + def test_by_unknown_name(self): + self.assertItemsEqual(self.cache.retrieve(dict(uid='notfound')), []) + + def test_by_number(self): + self.assertItemsEqual(self.cache.retrieve(dict(uidNumber=100)), [ + [u'name', u'passwd', 100, 200, u'gecos', u'/home/user', u'/bin/bash'], + ]) + self.assertItemsEqual(self.cache.retrieve(dict(uidNumber=101)), [ + ['name2', 'passwd2', 101, 202, 'gecos2', '/home/user2', '/bin/bash'], + ]) + + def test_all(self): + self.assertItemsEqual(self.cache.retrieve({}), [ + [u'name', u'passwd', 100, 200, u'gecos', u'/home/user', u'/bin/bash'], + [u'name2', u'passwd2', 101, 202, u'gecos2', u'/home/user2', u'/bin/bash'], + ]) + + +class TestProtocol(unittest.TestCase): + + def setUp(self): + import protocol + cache = protocol.Cache() + cache.store('protocol1', ['alias1', 'alias2'], 100) + cache.store('protocol2', ['alias3', ], 200) + cache.store('protocol3', [], 300) + self.cache = cache + + def test_by_name(self): + self.assertItemsEqual(self.cache.retrieve(dict(cn='protocol1')), [ + ['protocol1', ['alias1', 'alias2'], 100], + ]) + self.assertItemsEqual(self.cache.retrieve(dict(cn='protocol2')), [ + ['protocol2', ['alias3', ], 200], + ]) + self.assertItemsEqual(self.cache.retrieve(dict(cn='protocol3')), [ + ['protocol3', [], 300], + ]) + + def test_by_unknown_name(self): + self.assertItemsEqual(self.cache.retrieve(dict(cn='notfound')), []) + + def test_by_number(self): + self.assertItemsEqual(self.cache.retrieve(dict(ipProtocolNumber=100)), [ + ['protocol1', ['alias1', 'alias2'], 100], + ]) + self.assertItemsEqual(self.cache.retrieve(dict(ipProtocolNumber=200)), [ + ['protocol2', ['alias3', ], 200], + ]) + + def test_by_alias(self): + self.assertItemsEqual(self.cache.retrieve(dict(cn='alias1')), [ + ['protocol1', ['alias1', 'alias2'], 100], + ]) + self.assertItemsEqual(self.cache.retrieve(dict(cn='alias3')), [ + ['protocol2', ['alias3', ], 200], + ]) + + def test_all(self): + self.assertItemsEqual(self.cache.retrieve({}), [ + ['protocol1', ['alias1', 'alias2'], 100], + ['protocol2', ['alias3'], 200], + ['protocol3', [], 300], + ]) + + +class TestRpc(unittest.TestCase): + + def setUp(self): + import rpc + cache = rpc.Cache() + cache.store('rpc1', ['alias1', 'alias2'], 100) + cache.store('rpc2', ['alias3', ], 200) + cache.store('rpc3', [], 300) + self.cache = cache + + def test_by_name(self): + self.assertItemsEqual(self.cache.retrieve(dict(cn='rpc1')), [ + ['rpc1', ['alias1', 'alias2'], 100], + ]) + self.assertItemsEqual(self.cache.retrieve(dict(cn='rpc2')), [ + ['rpc2', ['alias3', ], 200], + ]) + self.assertItemsEqual(self.cache.retrieve(dict(cn='rpc3')), [ + ['rpc3', [], 300], + ]) + + def test_by_unknown_name(self): + self.assertItemsEqual(self.cache.retrieve(dict(cn='notfound')), []) + + def test_by_number(self): + self.assertItemsEqual(self.cache.retrieve(dict(oncRpcNumber=100)), [ + ['rpc1', ['alias1', 'alias2'], 100], + ]) + self.assertItemsEqual(self.cache.retrieve(dict(oncRpcNumber=200)), [ + ['rpc2', ['alias3', ], 200], + ]) + + def test_by_alias(self): + self.assertItemsEqual(self.cache.retrieve(dict(cn='alias1')), [ + ['rpc1', ['alias1', 'alias2'], 100], + ]) + self.assertItemsEqual(self.cache.retrieve(dict(cn='alias3')), [ + ['rpc2', ['alias3', ], 200], + ]) + + def test_all(self): + self.assertItemsEqual(self.cache.retrieve({}), [ + ['rpc1', ['alias1', 'alias2'], 100], + ['rpc2', ['alias3'], 200], + ['rpc3', [], 300], + ]) + + +class TestService(unittest.TestCase): + + def setUp(self): + import service + cache = service.Cache() + cache.store('service1', ['alias1', 'alias2'], 100, 'tcp') + cache.store('service1', ['alias1', 'alias2'], 100, 'udp') + cache.store('service2', ['alias3', ], 200, 'udp') + cache.store('service3', [], 300, 'udp') + self.cache = cache + + def test_by_name(self): + self.assertItemsEqual(self.cache.retrieve(dict(cn='service1')), [ + ['service1', ['alias1', 'alias2'], 100, 'tcp'], + ['service1', ['alias1', 'alias2'], 100, 'udp'], + ]) + self.assertItemsEqual(self.cache.retrieve(dict(cn='service2')), [ + ['service2', ['alias3', ], 200, 'udp'], + ]) + self.assertItemsEqual(self.cache.retrieve(dict(cn='service3')), [ + ['service3', [], 300, 'udp'], + ]) + + def test_by_name_and_protocol(self): + self.assertItemsEqual(self.cache.retrieve(dict(cn='service1', ipServiceProtocol='udp')), [ + ['service1', ['alias1', 'alias2'], 100, 'udp'], + ]) + self.assertItemsEqual(self.cache.retrieve(dict(cn='service1', ipServiceProtocol='tcp')), [ + ['service1', ['alias1', 'alias2'], 100, 'tcp'], + ]) + self.assertItemsEqual(self.cache.retrieve(dict(cn='service2', ipServiceProtocol='udp')), [ + ['service2', ['alias3', ], 200, 'udp'], + ]) + self.assertItemsEqual(self.cache.retrieve(dict(cn='service2', ipServiceProtocol='tcp')), [ + ]) + + def test_by_unknown_name(self): + self.assertItemsEqual(self.cache.retrieve(dict(cn='notfound')), []) + + def test_by_number(self): + self.assertItemsEqual(self.cache.retrieve(dict(ipServicePort=100)), [ + ['service1', ['alias1', 'alias2'], 100, 'tcp'], + ['service1', ['alias1', 'alias2'], 100, 'udp'], + ]) + self.assertItemsEqual(self.cache.retrieve(dict(ipServicePort=200)), [ + ['service2', ['alias3', ], 200, 'udp'], + ]) + + def test_by_number_and_protocol(self): + self.assertItemsEqual(self.cache.retrieve(dict(ipServicePort=100, ipServiceProtocol='udp')), [ + ['service1', ['alias1', 'alias2'], 100, 'udp'], + ]) + self.assertItemsEqual(self.cache.retrieve(dict(ipServicePort=100, ipServiceProtocol='tcp')), [ + ['service1', ['alias1', 'alias2'], 100, 'tcp'], + ]) + self.assertItemsEqual(self.cache.retrieve(dict(ipServicePort=200, ipServiceProtocol='udp')), [ + ['service2', ['alias3', ], 200, 'udp'], + ]) + self.assertItemsEqual(self.cache.retrieve(dict(ipServicePort=200, ipServiceProtocol='tcp')), [ + ]) + + def test_by_alias(self): + self.assertItemsEqual(self.cache.retrieve(dict(cn='alias1')), [ + ['service1', ['alias1', 'alias2'], 100, 'udp'], + ['service1', ['alias1', 'alias2'], 100, 'tcp'], + ]) + self.assertItemsEqual(self.cache.retrieve(dict(cn='alias3')), [ + ['service2', ['alias3', ], 200, 'udp'], + ]) + + def test_all(self): + self.assertItemsEqual(self.cache.retrieve({}), [ + ['service1', ['alias1', 'alias2'], 100, 'tcp'], + ['service1', ['alias1', 'alias2'], 100, 'udp'], + ['service2', ['alias3', ], 200, 'udp'], + ['service3', [], 300, 'udp'], + ]) + + +class Testshadow(unittest.TestCase): + + def setUp(self): + import shadow + cache = shadow.Cache() + cache.store('name', 'passwd', 15639, 0, 7, -1, -1, -1, 0) + cache.store('name2', 'passwd2', 15639, 0, 7, -1, -1, -1, 0) + self.cache = cache + + def test_by_name(self): + self.assertItemsEqual(self.cache.retrieve(dict(uid='name')), [ + [u'name', u'passwd', 15639, 0, 7, -1, -1, -1, 0], + ]) + self.assertItemsEqual(self.cache.retrieve(dict(uid='name2')), [ + [u'name2', u'passwd2', 15639, 0, 7, -1, -1, -1, 0], + ]) + + def test_by_unknown_name(self): + self.assertItemsEqual(self.cache.retrieve(dict(uid='notfound')), []) + + def test_all(self): + self.assertItemsEqual(self.cache.retrieve({}), [ + [u'name', u'passwd', 15639, 0, 7, -1, -1, -1, 0], + [u'name2', u'passwd2', 15639, 0, 7, -1, -1, -1, 0], + ]) + + +if __name__ == '__main__': + unittest.main() |