summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorArthur de Jong <arthur@arthurdejong.org>2013-08-17 12:32:07 +0200
committerArthur de Jong <arthur@arthurdejong.org>2013-08-17 12:32:07 +0200
commit8a3f0f51b2406e6ee9537fdc96cadc0d3fa2194c (patch)
tree49608eb4f63bbe85ff8c730b90828430f4cf9567
parent84d22e608b03c154d11e54ff34d7b87bf1d78cfa (diff)
parenta066bcb17e1b99a42a5834d1ace6feba7c9b60b7 (diff)
Improvements to pynslcd caching functionality
This fixes most of the existing caching functionality. Cache expiry, negative hits and entries going away remain to be implemented.
-rw-r--r--configure.ac1
-rw-r--r--pynslcd/alias.py41
-rw-r--r--pynslcd/cache.py362
-rw-r--r--pynslcd/ether.py9
-rw-r--r--pynslcd/group.py45
-rw-r--r--pynslcd/host.py69
-rw-r--r--pynslcd/netgroup.py36
-rw-r--r--pynslcd/network.py69
-rw-r--r--pynslcd/passwd.py14
-rw-r--r--pynslcd/protocol.py39
-rw-r--r--pynslcd/rpc.py39
-rw-r--r--pynslcd/service.py89
-rw-r--r--pynslcd/shadow.py15
-rw-r--r--tests/Makefile.am8
-rwxr-xr-xtests/test_pynslcd_cache.py459
15 files changed, 941 insertions, 354 deletions
diff --git a/configure.ac b/configure.ac
index 9774e3a..a586173 100644
--- a/configure.ac
+++ b/configure.ac
@@ -61,6 +61,7 @@ AM_PROG_CC_C_O
AC_USE_SYSTEM_EXTENSIONS
AC_PROG_LN_S
AM_PATH_PYTHON(2.5,, [:])
+AM_CONDITIONAL([HAVE_PYTHON], [test "$PYTHON" != ":"])
AM_PROG_AR
# checks for tool to convert docbook to man
diff --git a/pynslcd/alias.py b/pynslcd/alias.py
index 46c4d6b..371ac2e 100644
--- a/pynslcd/alias.py
+++ b/pynslcd/alias.py
@@ -37,19 +37,40 @@ class Search(search.LDAPSearch):
class Cache(cache.Cache):
+ tables = ('alias_cache', 'alias_member_cache')
+
+ create_sql = '''
+ CREATE TABLE IF NOT EXISTS `alias_cache`
+ ( `cn` TEXT PRIMARY KEY COLLATE NOCASE,
+ `mtime` TIMESTAMP NOT NULL );
+ CREATE TABLE IF NOT EXISTS `alias_member_cache`
+ ( `alias` TEXT NOT NULL COLLATE NOCASE,
+ `rfc822MailMember` TEXT NOT NULL,
+ FOREIGN KEY(`alias`) REFERENCES `alias_cache`(`cn`)
+ ON DELETE CASCADE ON UPDATE CASCADE );
+ CREATE INDEX IF NOT EXISTS `alias_member_idx` ON `alias_member_cache`(`alias`);
+ '''
+
retrieve_sql = '''
SELECT `alias_cache`.`cn` AS `cn`,
- `alias_1_cache`.`rfc822MailMember` AS `rfc822MailMember`
+ `alias_member_cache`.`rfc822MailMember` AS `rfc822MailMember`,
+ `alias_cache`.`mtime` AS `mtime`
FROM `alias_cache`
- LEFT JOIN `alias_1_cache`
- ON `alias_1_cache`.`alias` = `alias_cache`.`cn`
- '''
-
- def retrieve(self, parameters):
- query = cache.Query(self.retrieve_sql, parameters)
- # return results, returning the members as a list
- for row in cache.RowGrouper(query.execute(self.con), ('cn', ), ('rfc822MailMember', )):
- yield row['cn'], row['rfc822MailMember']
+ LEFT JOIN `alias_member_cache`
+ ON `alias_member_cache`.`alias` = `alias_cache`.`cn`
+ '''
+
+ retrieve_by = dict(
+ rfc822MailMember='''
+ `cn` IN (
+ SELECT `a`.`alias`
+ FROM `alias_member_cache` `a`
+ WHERE `a`.`rfc822MailMember` = ?)
+ ''',
+ )
+
+ group_by = (0, ) # cn
+ group_columns = (1, ) # rfc822MailMember
class AliasRequest(common.Request):
diff --git a/pynslcd/cache.py b/pynslcd/cache.py
index 7089d41..3974fef 100644
--- a/pynslcd/cache.py
+++ b/pynslcd/cache.py
@@ -27,184 +27,53 @@ import sqlite3
# TODO: probably create a config table
-
-
-# FIXME: store the cache in the right place and make it configurable
-filename = '/tmp/cache.sqlite'
-dirname = os.path.dirname(filename)
-if not os.path.isdir(dirname):
- os.mkdir(dirname)
-con = sqlite3.connect(filename,
- detect_types=sqlite3.PARSE_DECLTYPES, check_same_thread=False)
-con.row_factory = sqlite3.Row
-
# FIXME: have some way to remove stale entries from the cache if all items from LDAP are queried (perhas use TTL from all request)
-# set up the database
-con.executescript('''
-
- -- store temporary tables in memory
- PRAGMA temp_store = MEMORY;
-
- -- disable sync() on database (corruption on disk failure)
- PRAGMA synchronous = OFF;
-
- -- put journal in memory (corruption if crash during transaction)
- PRAGMA journal_mode = MEMORY;
-
- -- tables for alias cache
- CREATE TABLE IF NOT EXISTS `alias_cache`
- ( `cn` TEXT PRIMARY KEY COLLATE NOCASE,
- `mtime` TIMESTAMP NOT NULL );
- CREATE TABLE IF NOT EXISTS `alias_1_cache`
- ( `alias` TEXT NOT NULL COLLATE NOCASE,
- `rfc822MailMember` TEXT NOT NULL,
- FOREIGN KEY(`alias`) REFERENCES `alias_cache`(`cn`)
- ON DELETE CASCADE ON UPDATE CASCADE );
- CREATE INDEX IF NOT EXISTS `alias_1_idx` ON `alias_1_cache`(`alias`);
-
- -- table for ethernet cache
- CREATE TABLE IF NOT EXISTS `ether_cache`
- ( `cn` TEXT NOT NULL COLLATE NOCASE,
- `macAddress` TEXT NOT NULL COLLATE NOCASE,
- `mtime` TIMESTAMP NOT NULL,
- UNIQUE (`cn`, `macAddress`) );
-
- -- table for group cache
- CREATE TABLE IF NOT EXISTS `group_cache`
- ( `cn` TEXT PRIMARY KEY,
- `userPassword` TEXT,
- `gidNumber` INTEGER NOT NULL UNIQUE,
- `mtime` TIMESTAMP NOT NULL );
- CREATE TABLE IF NOT EXISTS `group_3_cache`
- ( `group` TEXT NOT NULL,
- `memberUid` TEXT NOT NULL,
- FOREIGN KEY(`group`) REFERENCES `group_cache`(`cn`)
- ON DELETE CASCADE ON UPDATE CASCADE );
- CREATE INDEX IF NOT EXISTS `group_3_idx` ON `group_3_cache`(`group`);
-
- -- tables for host cache
- CREATE TABLE IF NOT EXISTS `host_cache`
- ( `cn` TEXT PRIMARY KEY COLLATE NOCASE,
- `mtime` TIMESTAMP NOT NULL );
- CREATE TABLE IF NOT EXISTS `host_1_cache`
- ( `host` TEXT NOT NULL COLLATE NOCASE,
- `cn` TEXT NOT NULL COLLATE NOCASE,
- FOREIGN KEY(`host`) REFERENCES `host_cache`(`cn`)
- ON DELETE CASCADE ON UPDATE CASCADE );
- CREATE INDEX IF NOT EXISTS `host_1_idx` ON `host_1_cache`(`host`);
- CREATE TABLE IF NOT EXISTS `host_2_cache`
- ( `host` TEXT NOT NULL COLLATE NOCASE,
- `ipHostNumber` TEXT NOT NULL,
- FOREIGN KEY(`host`) REFERENCES `host_cache`(`cn`)
- ON DELETE CASCADE ON UPDATE CASCADE );
- CREATE INDEX IF NOT EXISTS `host_2_idx` ON `host_2_cache`(`host`);
- -- FIXME: this does not work as entries are never removed from the cache
- CREATE TABLE IF NOT EXISTS `netgroup_cache`
- ( `cn` TEXT NOT NULL,
- `member` TEXT NOT NULL,
- `mtime` TIMESTAMP NOT NULL,
- UNIQUE (`cn`, `member`) );
+class regroup(object):
- -- tables for network cache
- CREATE TABLE IF NOT EXISTS `network_cache`
- ( `cn` TEXT PRIMARY KEY COLLATE NOCASE,
- `mtime` TIMESTAMP NOT NULL );
- CREATE TABLE IF NOT EXISTS `network_1_cache`
- ( `network` TEXT NOT NULL COLLATE NOCASE,
- `cn` TEXT NOT NULL COLLATE NOCASE,
- FOREIGN KEY(`network`) REFERENCES `network_cache`(`cn`)
- ON DELETE CASCADE ON UPDATE CASCADE );
- CREATE INDEX IF NOT EXISTS `network_1_idx` ON `network_1_cache`(`network`);
- CREATE TABLE IF NOT EXISTS `network_2_cache`
- ( `network` TEXT NOT NULL,
- `ipNetworkNumber` TEXT NOT NULL,
- FOREIGN KEY(`network`) REFERENCES `network_cache`(`cn`)
- ON DELETE CASCADE ON UPDATE CASCADE );
- CREATE INDEX IF NOT EXISTS `network_2_idx` ON `network_2_cache`(`network`);
+ def __init__(self, results, group_by=None, group_column=None):
+ """Regroup the results in the group column by the key columns."""
+ self.group_by = tuple(group_by)
+ self.group_column = group_column
+ self.it = iter(results)
+ self.tgtkey = self.currkey = self.currvalue = object()
- -- table for passwd cache
- CREATE TABLE IF NOT EXISTS `passwd_cache`
- ( `uid` TEXT PRIMARY KEY,
- `userPassword` TEXT,
- `uidNumber` INTEGER NOT NULL UNIQUE,
- `gidNumber` INTEGER NOT NULL,
- `gecos` TEXT,
- `homeDirectory` TEXT,
- `loginShell` TEXT,
- `mtime` TIMESTAMP NOT NULL );
-
- -- table for protocol cache
- CREATE TABLE IF NOT EXISTS `protocol_cache`
- ( `cn` TEXT PRIMARY KEY,
- `ipProtocolNumber` INTEGER NOT NULL,
- `mtime` TIMESTAMP NOT NULL );
- CREATE TABLE IF NOT EXISTS `protocol_1_cache`
- ( `protocol` TEXT NOT NULL,
- `cn` TEXT NOT NULL,
- FOREIGN KEY(`protocol`) REFERENCES `protocol_cache`(`cn`)
- ON DELETE CASCADE ON UPDATE CASCADE );
- CREATE INDEX IF NOT EXISTS `protocol_1_idx` ON `protocol_1_cache`(`protocol`);
-
- -- table for rpc cache
- CREATE TABLE IF NOT EXISTS `rpc_cache`
- ( `cn` TEXT PRIMARY KEY,
- `oncRpcNumber` INTEGER NOT NULL,
- `mtime` TIMESTAMP NOT NULL );
- CREATE TABLE IF NOT EXISTS `rpc_1_cache`
- ( `rpc` TEXT NOT NULL,
- `cn` TEXT NOT NULL,
- FOREIGN KEY(`rpc`) REFERENCES `rpc_cache`(`cn`)
- ON DELETE CASCADE ON UPDATE CASCADE );
- CREATE INDEX IF NOT EXISTS `rpc_1_idx` ON `rpc_1_cache`(`rpc`);
-
- -- tables for service cache
- CREATE TABLE IF NOT EXISTS `service_cache`
- ( `cn` TEXT NOT NULL,
- `ipServicePort` INTEGER NOT NULL,
- `ipServiceProtocol` TEXT NOT NULL,
- `mtime` TIMESTAMP NOT NULL,
- UNIQUE (`ipServicePort`, `ipServiceProtocol`) );
- CREATE TABLE IF NOT EXISTS `service_1_cache`
- ( `ipServicePort` INTEGER NOT NULL,
- `ipServiceProtocol` TEXT NOT NULL,
- `cn` TEXT NOT NULL,
- FOREIGN KEY(`ipServicePort`) REFERENCES `service_cache`(`ipServicePort`)
- ON DELETE CASCADE ON UPDATE CASCADE,
- FOREIGN KEY(`ipServiceProtocol`) REFERENCES `service_cache`(`ipServiceProtocol`)
- ON DELETE CASCADE ON UPDATE CASCADE );
- CREATE INDEX IF NOT EXISTS `service_1_idx1` ON `service_1_cache`(`ipServicePort`);
- CREATE INDEX IF NOT EXISTS `service_1_idx2` ON `service_1_cache`(`ipServiceProtocol`);
+ def keyfunc(self, row):
+ return tuple(row[x] for x in self.group_by)
- -- table for shadow cache
- CREATE TABLE IF NOT EXISTS `shadow_cache`
- ( `uid` TEXT PRIMARY KEY,
- `userPassword` TEXT,
- `shadowLastChange` INTEGER,
- `shadowMin` INTEGER,
- `shadowMax` INTEGER,
- `shadowWarning` INTEGER,
- `shadowInactive` INTEGER,
- `shadowExpire` INTEGER,
- `shadowFlag` INTEGER,
- `mtime` TIMESTAMP NOT NULL );
+ def __iter__(self):
+ return self
- ''')
+ def next(self):
+ # find a start row
+ while self.currkey == self.tgtkey:
+ self.currvalue = next(self.it) # Exit on StopIteration
+ self.currkey = self.keyfunc(self.currvalue)
+ self.tgtkey = self.currkey
+ # turn the result row into a list of columns
+ row = list(self.currvalue)
+ # replace the group column
+ row[self.group_column] = list(self._grouper(self.tgtkey))
+ return row
+
+ def _grouper(self, tgtkey):
+ """Generate the group columns."""
+ while self.currkey == tgtkey:
+ value = self.currvalue[self.group_column]
+ if value is not None:
+ yield value
+ self.currvalue = next(self.it) # Exit on StopIteration
+ self.currkey = self.keyfunc(self.currvalue)
class Query(object):
+ """Helper class to build an SQL query for the cache."""
- def __init__(self, query, parameters=None):
+ def __init__(self, query):
self.query = query
self.wheres = []
self.parameters = []
- if parameters:
- for k, v in parameters.items():
- self.add_where('`%s` = ?' % k, [v])
-
- def add_query(self, query):
- self.query += ' ' + query
def add_where(self, where, parameters):
self.wheres.append(where)
@@ -214,100 +83,113 @@ class Query(object):
query = self.query
if self.wheres:
query += ' WHERE ' + ' AND '.join(self.wheres)
- c = con.cursor()
- return c.execute(query, self.parameters)
-
-
-class CnAliasedQuery(Query):
-
- sql = '''
- SELECT `%(table)s_cache`.*,
- `%(table)s_1_cache`.`cn` AS `alias`
- FROM `%(table)s_cache`
- LEFT JOIN `%(table)s_1_cache`
- ON `%(table)s_1_cache`.`%(table)s` = `%(table)s_cache`.`cn`
- '''
-
- cn_join = '''
- LEFT JOIN `%(table)s_1_cache` `cn_alias`
- ON `cn_alias`.`%(table)s` = `%(table)s_cache`.`cn`
- '''
-
- def __init__(self, table, parameters):
- args = dict(table=table)
- super(CnAliasedQuery, self).__init__(self.sql % args)
- for k, v in parameters.items():
- if k == 'cn':
- self.add_query(self.cn_join % args)
- self.add_where('(`%(table)s_cache`.`cn` = ? OR `cn_alias`.`cn` = ?)' % args, [v, v])
- else:
- self.add_where('`%s` = ?' % k, [v])
-
-
-class RowGrouper(object):
- """Pass in query results and group the results by a certain specified
- list of columns."""
-
- def __init__(self, results, groupby, columns):
- self.groupby = groupby
- self.columns = columns
- self.results = itertools.groupby(results, key=self.keyfunc)
-
- def __iter__(self):
- return self
-
- def keyfunc(self, row):
- return tuple(row[x] for x in self.groupby)
-
- def next(self):
- groupcols, rows = self.results.next()
- tmp = dict((x, list()) for x in self.columns)
- for row in rows:
- for col in self.columns:
- if row[col] is not None:
- tmp[col].append(row[col])
- result = dict(row)
- result.update(tmp)
- return result
+ cursor = con.cursor()
+ return cursor.execute(query, self.parameters)
class Cache(object):
+ """The description of the cache."""
+
+ retrieve_sql = None
+ retrieve_by = dict()
+ group_by = ()
+ group_columns = ()
def __init__(self):
- self.con = con
- self.table = sys.modules[self.__module__].__name__
+ self.con = _get_connection()
+ self.db = sys.modules[self.__module__].__name__
+ if not hasattr(self, 'tables'):
+ self.tables = ['%s_cache' % self.db]
+ self.create()
+
+ def create(self):
+ """Create the needed tables if neccesary."""
+ self.con.executescript(self.create_sql)
def store(self, *values):
- """Store the values in the cache for the specified table."""
+ """Store the values in the cache for the specified table.
+ The order of the values is the order returned by the Reques.convert()
+ function."""
+ # split the values into simple (flat) values and one-to-many values
simple_values = []
- multi_values = {}
- for n, v in enumerate(values):
+ multi_values = []
+ for v in values:
if isinstance(v, (list, tuple, set)):
- multi_values[n] = v
+ multi_values.append(v)
else:
simple_values.append(v)
+ # insert the simple values
simple_values.append(datetime.datetime.now())
args = ', '.join(len(simple_values) * ('?', ))
- con.execute('''
- INSERT OR REPLACE INTO %s_cache
+ self.con.execute('''
+ INSERT OR REPLACE INTO %s
VALUES
(%s)
- ''' % (self.table, args), simple_values)
- for n, vlist in multi_values.items():
- con.execute('''
- DELETE FROM %s_%d_cache
+ ''' % (self.tables[0], args), simple_values)
+ # insert the one-to-many values
+ for n, vlist in enumerate(multi_values):
+ self.con.execute('''
+ DELETE FROM %s
WHERE `%s` = ?
- ''' % (self.table, n, self.table), (values[0], ))
- con.executemany('''
- INSERT INTO %s_%d_cache
+ ''' % (self.tables[n + 1], self.db), (values[0], ))
+ self.con.executemany('''
+ INSERT INTO %s
VALUES
(?, ?)
- ''' % (self.table, n), ((values[0], x) for x in vlist))
+ ''' % (self.tables[n + 1]), ((values[0], x) for x in vlist))
def retrieve(self, parameters):
- """Retrieve all items from the cache based on the parameters supplied."""
- query = Query('''
+ """Retrieve all items from the cache based on the parameters
+ supplied."""
+ query = Query(self.retrieve_sql or '''
SELECT *
- FROM %s_cache
- ''' % self.table, parameters)
- return (list(x)[:-1] for x in query.execute(self.con))
+ FROM %s
+ ''' % self.tables[0])
+ if parameters:
+ for k, v in parameters.items():
+ where = self.retrieve_by.get(k, '`%s`.`%s` = ?' % (self.tables[0], k))
+ query.add_where(where, where.count('?') * [v])
+ # group by
+ # FIXME: find a nice way to turn group_by and group_columns into names
+ results = query.execute(self.con)
+ group_by = list(self.group_by + self.group_columns)
+ for column in self.group_columns[::-1]:
+ group_by.pop()
+ results = regroup(results, group_by, column)
+ # strip the mtime from the results
+ return (list(x)[:-1] for x in results)
+
+ def __enter__(self):
+ return self.con.__enter__();
+
+ def __exit__(self, *args):
+ return self.con.__exit__(*args);
+
+
+# the connection to the sqlite database
+_connection = None
+
+
+# FIXME: make tread safe (is this needed the way the caches are initialised?)
+def _get_connection():
+ global _connection
+ if _connection is None:
+ filename = '/tmp/pynslcd_cache.sqlite'
+ dirname = os.path.dirname(filename)
+ if not os.path.isdir(dirname):
+ os.mkdir(dirname)
+ connection = sqlite3.connect(
+ filename, detect_types=sqlite3.PARSE_DECLTYPES,
+ check_same_thread=False)
+ connection.row_factory = sqlite3.Row
+ # initialise connection properties
+ connection.executescript('''
+ -- store temporary tables in memory
+ PRAGMA temp_store = MEMORY;
+ -- disable sync() on database (corruption on disk failure)
+ PRAGMA synchronous = OFF;
+ -- put journal in memory (corruption if crash during transaction)
+ PRAGMA journal_mode = MEMORY;
+ ''')
+ _connection = connection
+ return _connection
diff --git a/pynslcd/ether.py b/pynslcd/ether.py
index d5d8c06..e5060ca 100644
--- a/pynslcd/ether.py
+++ b/pynslcd/ether.py
@@ -59,7 +59,14 @@ class Search(search.LDAPSearch):
class Cache(cache.Cache):
- pass
+
+ create_sql = '''
+ CREATE TABLE IF NOT EXISTS `ether_cache`
+ ( `cn` TEXT NOT NULL COLLATE NOCASE,
+ `macAddress` TEXT NOT NULL COLLATE NOCASE,
+ `mtime` TIMESTAMP NOT NULL,
+ UNIQUE (`cn`, `macAddress`) );
+ '''
class EtherRequest(common.Request):
diff --git a/pynslcd/group.py b/pynslcd/group.py
index 2868d96..10e3423 100644
--- a/pynslcd/group.py
+++ b/pynslcd/group.py
@@ -75,20 +75,41 @@ class Search(search.LDAPSearch):
class Cache(cache.Cache):
+ tables = ('group_cache', 'group_member_cache')
+
+ create_sql = '''
+ CREATE TABLE IF NOT EXISTS `group_cache`
+ ( `cn` TEXT PRIMARY KEY,
+ `userPassword` TEXT,
+ `gidNumber` INTEGER NOT NULL UNIQUE,
+ `mtime` TIMESTAMP NOT NULL );
+ CREATE TABLE IF NOT EXISTS `group_member_cache`
+ ( `group` TEXT NOT NULL,
+ `memberUid` TEXT NOT NULL,
+ FOREIGN KEY(`group`) REFERENCES `group_cache`(`cn`)
+ ON DELETE CASCADE ON UPDATE CASCADE );
+ CREATE INDEX IF NOT EXISTS `group_member_idx` ON `group_member_cache`(`group`);
+ '''
+
retrieve_sql = '''
- SELECT `cn`, `userPassword`, `gidNumber`, `memberUid`
+ SELECT `group_cache`.`cn` AS `cn`, `userPassword`, `gidNumber`,
+ `memberUid`, `mtime`
FROM `group_cache`
- LEFT JOIN `group_3_cache`
- ON `group_3_cache`.`group` = `group_cache`.`cn`
- '''
-
- def retrieve(self, parameters):
- query = cache.Query(self.retrieve_sql, parameters)
- # return results returning the members as a set
- q = itertools.groupby(query.execute(self.con),
- key=lambda x: (x['cn'], x['userPassword'], x['gidNumber']))
- for k, v in q:
- yield k + (set(x['memberUid'] for x in v if x['memberUid'] is not None), )
+ LEFT JOIN `group_member_cache`
+ ON `group_member_cache`.`group` = `group_cache`.`cn`
+ '''
+
+ retrieve_by = dict(
+ memberUid='''
+ `cn` IN (
+ SELECT `a`.`group`
+ FROM `group_member_cache` `a`
+ WHERE `a`.`memberUid` = ?)
+ ''',
+ )
+
+ group_by = (0, ) # cn
+ group_columns = (3, ) # memberUid
class GroupRequest(common.Request):
diff --git a/pynslcd/host.py b/pynslcd/host.py
index ffd9588..04f5337 100644
--- a/pynslcd/host.py
+++ b/pynslcd/host.py
@@ -34,29 +34,58 @@ class Search(search.LDAPSearch):
required = ('cn', )
-class HostQuery(cache.CnAliasedQuery):
+class Cache(cache.Cache):
- sql = '''
+ tables = ('host_cache', 'host_alias_cache', 'host_address_cache')
+
+ create_sql = '''
+ CREATE TABLE IF NOT EXISTS `host_cache`
+ ( `cn` TEXT PRIMARY KEY COLLATE NOCASE,
+ `mtime` TIMESTAMP NOT NULL );
+ CREATE TABLE IF NOT EXISTS `host_alias_cache`
+ ( `host` TEXT NOT NULL COLLATE NOCASE,
+ `cn` TEXT NOT NULL COLLATE NOCASE,
+ FOREIGN KEY(`host`) REFERENCES `host_cache`(`cn`)
+ ON DELETE CASCADE ON UPDATE CASCADE );
+ CREATE INDEX IF NOT EXISTS `host_alias_idx` ON `host_alias_cache`(`host`);
+ CREATE TABLE IF NOT EXISTS `host_address_cache`
+ ( `host` TEXT NOT NULL COLLATE NOCASE,
+ `ipHostNumber` TEXT NOT NULL,
+ FOREIGN KEY(`host`) REFERENCES `host_cache`(`cn`)
+ ON DELETE CASCADE ON UPDATE CASCADE );
+ CREATE INDEX IF NOT EXISTS `host_address_idx` ON `host_address_cache`(`host`);
+ '''
+
+ retrieve_sql = '''
SELECT `host_cache`.`cn` AS `cn`,
- `host_1_cache`.`cn` AS `alias`,
- `host_2_cache`.`ipHostNumber` AS `ipHostNumber`
+ `host_alias_cache`.`cn` AS `alias`,
+ `host_address_cache`.`ipHostNumber` AS `ipHostNumber`,
+ `host_cache`.`mtime` AS `mtime`
FROM `host_cache`
- LEFT JOIN `host_1_cache`
- ON `host_1_cache`.`host` = `host_cache`.`cn`
- LEFT JOIN `host_2_cache`
- ON `host_2_cache`.`host` = `host_cache`.`cn`
- '''
-
- def __init__(self, parameters):
- super(HostQuery, self).__init__('host', parameters)
-
-
-class Cache(cache.Cache):
-
- def retrieve(self, parameters):
- query = HostQuery(parameters)
- for row in cache.RowGrouper(query.execute(self.con), ('cn', ), ('alias', 'ipHostNumber', )):
- yield row['cn'], row['alias'], row['ipHostNumber']
+ LEFT JOIN `host_alias_cache`
+ ON `host_alias_cache`.`host` = `host_cache`.`cn`
+ LEFT JOIN `host_address_cache`
+ ON `host_address_cache`.`host` = `host_cache`.`cn`
+ '''
+
+ retrieve_by = dict(
+ cn='''
+ ( `host_cache`.`cn` = ? OR
+ `host_cache`.`cn` IN (
+ SELECT `by_alias`.`host`
+ FROM `host_alias_cache` `by_alias`
+ WHERE `by_alias`.`cn` = ?))
+ ''',
+ ipHostNumber='''
+ `host_cache`.`cn` IN (
+ SELECT `by_ipHostNumber`.`host`
+ FROM `host_address_cache` `by_ipHostNumber`
+ WHERE `by_ipHostNumber`.`ipHostNumber` = ?)
+ ''',
+ )
+
+ group_by = (0, ) # cn
+ group_columns = (1, 2) # alias, ipHostNumber
class HostRequest(common.Request):
diff --git a/pynslcd/netgroup.py b/pynslcd/netgroup.py
index 20f8779..d86e38c 100644
--- a/pynslcd/netgroup.py
+++ b/pynslcd/netgroup.py
@@ -42,7 +42,41 @@ class Search(search.LDAPSearch):
class Cache(cache.Cache):
- pass
+
+ tables = ('netgroup_cache', 'netgroup_triple_cache', 'netgroup_member_cache')
+
+ create_sql = '''
+ CREATE TABLE IF NOT EXISTS `netgroup_cache`
+ ( `cn` TEXT PRIMARY KEY COLLATE NOCASE,
+ `mtime` TIMESTAMP NOT NULL );
+ CREATE TABLE IF NOT EXISTS `netgroup_triple_cache`
+ ( `netgroup` TEXT NOT NULL COLLATE NOCASE,
+ `nisNetgroupTriple` TEXT NOT NULL COLLATE NOCASE,
+ FOREIGN KEY(`netgroup`) REFERENCES `netgroup_cache`(`cn`)
+ ON DELETE CASCADE ON UPDATE CASCADE );
+ CREATE INDEX IF NOT EXISTS `netgroup_triple_idx` ON `netgroup_triple_cache`(`netgroup`);
+ CREATE TABLE IF NOT EXISTS `netgroup_member_cache`
+ ( `netgroup` TEXT NOT NULL COLLATE NOCASE,
+ `memberNisNetgroup` TEXT NOT NULL,
+ FOREIGN KEY(`netgroup`) REFERENCES `netgroup_cache`(`cn`)
+ ON DELETE CASCADE ON UPDATE CASCADE );
+ CREATE INDEX IF NOT EXISTS `netgroup_membe_idx` ON `netgroup_member_cache`(`netgroup`);
+ '''
+
+ retrieve_sql = '''
+ SELECT `netgroup_cache`.`cn` AS `cn`,
+ `netgroup_triple_cache`.`nisNetgroupTriple` AS `nisNetgroupTriple`,
+ `netgroup_member_cache`.`memberNisNetgroup` AS `memberNisNetgroup`,
+ `netgroup_cache`.`mtime` AS `mtime`
+ FROM `netgroup_cache`
+ LEFT JOIN `netgroup_triple_cache`
+ ON `netgroup_triple_cache`.`netgroup` = `netgroup_cache`.`cn`
+ LEFT JOIN `netgroup_member_cache`
+ ON `netgroup_member_cache`.`netgroup` = `netgroup_cache`.`cn`
+ '''
+
+ group_by = (0, ) # cn
+ group_columns = (1, 2) # nisNetgroupTriple, memberNisNetgroup
class NetgroupRequest(common.Request):
diff --git a/pynslcd/network.py b/pynslcd/network.py
index dc91d68..01bf6c2 100644
--- a/pynslcd/network.py
+++ b/pynslcd/network.py
@@ -35,29 +35,58 @@ class Search(search.LDAPSearch):
required = ('cn', )
-class NetworkQuery(cache.CnAliasedQuery):
+class Cache(cache.Cache):
- sql = '''
+ tables = ('network_cache', 'network_alias_cache', 'network_address_cache')
+
+ create_sql = '''
+ CREATE TABLE IF NOT EXISTS `network_cache`
+ ( `cn` TEXT PRIMARY KEY COLLATE NOCASE,
+ `mtime` TIMESTAMP NOT NULL );
+ CREATE TABLE IF NOT EXISTS `network_alias_cache`
+ ( `network` TEXT NOT NULL COLLATE NOCASE,
+ `cn` TEXT NOT NULL COLLATE NOCASE,
+ FOREIGN KEY(`network`) REFERENCES `network_cache`(`cn`)
+ ON DELETE CASCADE ON UPDATE CASCADE );
+ CREATE INDEX IF NOT EXISTS `network_alias_idx` ON `network_alias_cache`(`network`);
+ CREATE TABLE IF NOT EXISTS `network_address_cache`
+ ( `network` TEXT NOT NULL COLLATE NOCASE,
+ `ipNetworkNumber` TEXT NOT NULL,
+ FOREIGN KEY(`network`) REFERENCES `network_cache`(`cn`)
+ ON DELETE CASCADE ON UPDATE CASCADE );
+ CREATE INDEX IF NOT EXISTS `network_address_idx` ON `network_address_cache`(`network`);
+ '''
+
+ retrieve_sql = '''
SELECT `network_cache`.`cn` AS `cn`,
- `network_1_cache`.`cn` AS `alias`,
- `network_2_cache`.`ipNetworkNumber` AS `ipNetworkNumber`
+ `network_alias_cache`.`cn` AS `alias`,
+ `network_address_cache`.`ipNetworkNumber` AS `ipNetworkNumber`,
+ `network_cache`.`mtime` AS `mtime`
FROM `network_cache`
- LEFT JOIN `network_1_cache`
- ON `network_1_cache`.`network` = `network_cache`.`cn`
- LEFT JOIN `network_2_cache`
- ON `network_2_cache`.`network` = `network_cache`.`cn`
- '''
-
- def __init__(self, parameters):
- super(NetworkQuery, self).__init__('network', parameters)
-
-
-class Cache(cache.Cache):
-
- def retrieve(self, parameters):
- query = NetworkQuery(parameters)
- for row in cache.RowGrouper(query.execute(self.con), ('cn', ), ('alias', 'ipNetworkNumber', )):
- yield row['cn'], row['alias'], row['ipNetworkNumber']
+ LEFT JOIN `network_alias_cache`
+ ON `network_alias_cache`.`network` = `network_cache`.`cn`
+ LEFT JOIN `network_address_cache`
+ ON `network_address_cache`.`network` = `network_cache`.`cn`
+ '''
+
+ retrieve_by = dict(
+ cn='''
+ ( `network_cache`.`cn` = ? OR
+ `network_cache`.`cn` IN (
+ SELECT `by_alias`.`network`
+ FROM `network_alias_cache` `by_alias`
+ WHERE `by_alias`.`cn` = ?))
+ ''',
+ ipNetworkNumber='''
+ `network_cache`.`cn` IN (
+ SELECT `by_ipNetworkNumber`.`network`
+ FROM `network_address_cache` `by_ipNetworkNumber`
+ WHERE `by_ipNetworkNumber`.`ipNetworkNumber` = ?)
+ ''',
+ )
+
+ group_by = (0, ) # cn
+ group_columns = (1, 2) # alias, ipNetworkNumber
class NetworkRequest(common.Request):
diff --git a/pynslcd/passwd.py b/pynslcd/passwd.py
index 7504961..a8e407f 100644
--- a/pynslcd/passwd.py
+++ b/pynslcd/passwd.py
@@ -47,7 +47,18 @@ class Search(search.LDAPSearch):
class Cache(cache.Cache):
- pass
+
+ create_sql = '''
+ CREATE TABLE IF NOT EXISTS `passwd_cache`
+ ( `uid` TEXT PRIMARY KEY,
+ `userPassword` TEXT,
+ `uidNumber` INTEGER NOT NULL UNIQUE,
+ `gidNumber` INTEGER NOT NULL,
+ `gecos` TEXT,
+ `homeDirectory` TEXT,
+ `loginShell` TEXT,
+ `mtime` TIMESTAMP NOT NULL );
+ '''
class PasswdRequest(common.Request):
@@ -106,7 +117,6 @@ class PasswdByUidRequest(PasswdRequest):
self.fp.write_int32(constants.NSLCD_RESULT_END)
-
class PasswdAllRequest(PasswdRequest):
action = constants.NSLCD_ACTION_PASSWD_ALL
diff --git a/pynslcd/protocol.py b/pynslcd/protocol.py
index cafda9d..1472c04 100644
--- a/pynslcd/protocol.py
+++ b/pynslcd/protocol.py
@@ -37,10 +37,41 @@ class Search(search.LDAPSearch):
class Cache(cache.Cache):
- def retrieve(self, parameters):
- query = cache.CnAliasedQuery('protocol', parameters)
- for row in cache.RowGrouper(query.execute(self.con), ('cn', ), ('alias', )):
- yield row['cn'], row['alias'], row['ipProtocolNumber']
+ tables = ('protocol_cache', 'protocol_alias_cache')
+
+ create_sql = '''
+ CREATE TABLE IF NOT EXISTS `protocol_cache`
+ ( `cn` TEXT PRIMARY KEY,
+ `ipProtocolNumber` INTEGER NOT NULL,
+ `mtime` TIMESTAMP NOT NULL );
+ CREATE TABLE IF NOT EXISTS `protocol_alias_cache`
+ ( `protocol` TEXT NOT NULL,
+ `cn` TEXT NOT NULL,
+ FOREIGN KEY(`protocol`) REFERENCES `protocol_cache`(`cn`)
+ ON DELETE CASCADE ON UPDATE CASCADE );
+ CREATE INDEX IF NOT EXISTS `protocol_alias_idx` ON `protocol_alias_cache`(`protocol`);
+ '''
+
+ retrieve_sql = '''
+ SELECT `protocol_cache`.`cn` AS `cn`, `protocol_alias_cache`.`cn` AS `alias`,
+ `ipProtocolNumber`, `mtime`
+ FROM `protocol_cache`
+ LEFT JOIN `protocol_alias_cache`
+ ON `protocol_alias_cache`.`protocol` = `protocol_cache`.`cn`
+ '''
+
+ retrieve_by = dict(
+ cn='''
+ ( `protocol_cache`.`cn` = ? OR
+ `protocol_cache`.`cn` IN (
+ SELECT `by_alias`.`protocol`
+ FROM `protocol_alias_cache` `by_alias`
+ WHERE `by_alias`.`cn` = ?))
+ ''',
+ )
+
+ group_by = (0, ) # cn
+ group_columns = (1, ) # alias
class ProtocolRequest(common.Request):
diff --git a/pynslcd/rpc.py b/pynslcd/rpc.py
index f20960e..2a241fd 100644
--- a/pynslcd/rpc.py
+++ b/pynslcd/rpc.py
@@ -37,10 +37,41 @@ class Search(search.LDAPSearch):
class Cache(cache.Cache):
- def retrieve(self, parameters):
- query = cache.CnAliasedQuery('rpc', parameters)
- for row in cache.RowGrouper(query.execute(self.con), ('cn', ), ('alias', )):
- yield row['cn'], row['alias'], row['oncRpcNumber']
+ tables = ('rpc_cache', 'rpc_alias_cache')
+
+ create_sql = '''
+ CREATE TABLE IF NOT EXISTS `rpc_cache`
+ ( `cn` TEXT PRIMARY KEY,
+ `oncRpcNumber` INTEGER NOT NULL,
+ `mtime` TIMESTAMP NOT NULL );
+ CREATE TABLE IF NOT EXISTS `rpc_alias_cache`
+ ( `rpc` TEXT NOT NULL,
+ `cn` TEXT NOT NULL,
+ FOREIGN KEY(`rpc`) REFERENCES `rpc_cache`(`cn`)
+ ON DELETE CASCADE ON UPDATE CASCADE );
+ CREATE INDEX IF NOT EXISTS `rpc_alias_idx` ON `rpc_alias_cache`(`rpc`);
+ '''
+
+ retrieve_sql = '''
+ SELECT `rpc_cache`.`cn` AS `cn`, `rpc_alias_cache`.`cn` AS `alias`,
+ `oncRpcNumber`, `mtime`
+ FROM `rpc_cache`
+ LEFT JOIN `rpc_alias_cache`
+ ON `rpc_alias_cache`.`rpc` = `rpc_cache`.`cn`
+ '''
+
+ retrieve_by = dict(
+ cn='''
+ ( `rpc_cache`.`cn` = ? OR
+ `rpc_cache`.`cn` IN (
+ SELECT `by_alias`.`rpc`
+ FROM `rpc_alias_cache` `by_alias`
+ WHERE `by_alias`.`cn` = ?))
+ ''',
+ )
+
+ group_by = (0, ) # cn
+ group_columns = (1, ) # alias
class RpcRequest(common.Request):
diff --git a/pynslcd/service.py b/pynslcd/service.py
index 19f941d..c27f485 100644
--- a/pynslcd/service.py
+++ b/pynslcd/service.py
@@ -40,56 +40,73 @@ class Search(search.LDAPSearch):
required = ('cn', 'ipServicePort', 'ipServiceProtocol')
-class ServiceQuery(cache.CnAliasedQuery):
+class Cache(cache.Cache):
- sql = '''
- SELECT `service_cache`.*,
- `service_1_cache`.`cn` AS `alias`
+ tables = ('service_cache', 'service_alias_cache')
+
+ create_sql = '''
+ CREATE TABLE IF NOT EXISTS `service_cache`
+ ( `cn` TEXT NOT NULL,
+ `ipServicePort` INTEGER NOT NULL,
+ `ipServiceProtocol` TEXT NOT NULL,
+ `mtime` TIMESTAMP NOT NULL,
+ UNIQUE (`ipServicePort`, `ipServiceProtocol`) );
+ CREATE TABLE IF NOT EXISTS `service_alias_cache`
+ ( `ipServicePort` INTEGER NOT NULL,
+ `ipServiceProtocol` TEXT NOT NULL,
+ `cn` TEXT NOT NULL,
+ FOREIGN KEY(`ipServicePort`) REFERENCES `service_cache`(`ipServicePort`)
+ ON DELETE CASCADE ON UPDATE CASCADE,
+ FOREIGN KEY(`ipServiceProtocol`) REFERENCES `service_cache`(`ipServiceProtocol`)
+ ON DELETE CASCADE ON UPDATE CASCADE );
+ CREATE INDEX IF NOT EXISTS `service_alias_idx1` ON `service_alias_cache`(`ipServicePort`);
+ CREATE INDEX IF NOT EXISTS `service_alias_idx2` ON `service_alias_cache`(`ipServiceProtocol`);
+ '''
+
+ retrieve_sql = '''
+ SELECT `service_cache`.`cn` AS `cn`,
+ `service_alias_cache`.`cn` AS `alias`,
+ `service_cache`.`ipServicePort`,
+ `service_cache`.`ipServiceProtocol`,
+ `mtime`
FROM `service_cache`
- LEFT JOIN `service_1_cache`
- ON `service_1_cache`.`ipServicePort` = `service_cache`.`ipServicePort`
- AND `service_1_cache`.`ipServiceProtocol` = `service_cache`.`ipServiceProtocol`
- '''
-
- cn_join = '''
- LEFT JOIN `service_1_cache` `cn_alias`
- ON `cn_alias`.`ipServicePort` = `service_cache`.`ipServicePort`
- AND `cn_alias`.`ipServiceProtocol` = `service_cache`.`ipServiceProtocol`
- '''
-
- def __init__(self, parameters):
- super(ServiceQuery, self).__init__('service', {})
- for k, v in parameters.items():
- if k == 'cn':
- self.add_query(self.cn_join)
- self.add_where('(`service_cache`.`cn` = ? OR `cn_alias`.`cn` = ?)', [v, v])
- else:
- self.add_where('`service_cache`.`%s` = ?' % k, [v])
-
-
-class Cache(cache.Cache):
+ LEFT JOIN `service_alias_cache`
+ ON `service_alias_cache`.`ipServicePort` = `service_cache`.`ipServicePort`
+ AND `service_alias_cache`.`ipServiceProtocol` = `service_cache`.`ipServiceProtocol`
+ '''
+
+ retrieve_by = dict(
+ cn='''
+ ( `service_cache`.`cn` = ? OR
+ 0 < (
+ SELECT COUNT(*)
+ FROM `service_alias_cache` `by_alias`
+ WHERE `by_alias`.`cn` = ?
+ AND `by_alias`.`ipServicePort` = `service_cache`.`ipServicePort`
+ AND `by_alias`.`ipServiceProtocol` = `service_cache`.`ipServiceProtocol`
+ ))
+ ''',
+ )
+
+ group_by = (0, 2, 3) # cn, ipServicePort, ipServiceProtocol
+ group_columns = (1, ) # alias
def store(self, name, aliases, port, protocol):
self.con.execute('''
INSERT OR REPLACE INTO `service_cache`
VALUES
(?, ?, ?, ?)
- ''', (name, port, protocol, datetime.datetime.now()))
+ ''', (name, port, protocol, datetime.datetime.now()))
self.con.execute('''
- DELETE FROM `service_1_cache`
+ DELETE FROM `service_alias_cache`
WHERE `ipServicePort` = ?
AND `ipServiceProtocol` = ?
- ''', (port, protocol))
+ ''', (port, protocol))
self.con.executemany('''
- INSERT INTO `service_1_cache`
+ INSERT INTO `service_alias_cache`
VALUES
(?, ?, ?)
- ''', ((port, protocol, alias) for alias in aliases))
-
- def retrieve(self, parameters):
- query = ServiceQuery(parameters)
- for row in cache.RowGrouper(query.execute(self.con), ('cn', 'ipServicePort', 'ipServiceProtocol'), ('alias', )):
- yield row['cn'], row['alias'], row['ipServicePort'], row['ipServiceProtocol']
+ ''', ((port, protocol, alias) for alias in aliases))
class ServiceRequest(common.Request):
diff --git a/pynslcd/shadow.py b/pynslcd/shadow.py
index bedac50..5fd0aa9 100644
--- a/pynslcd/shadow.py
+++ b/pynslcd/shadow.py
@@ -44,7 +44,20 @@ class Search(search.LDAPSearch):
class Cache(cache.Cache):
- pass
+
+ create_sql = '''
+ CREATE TABLE IF NOT EXISTS `shadow_cache`
+ ( `uid` TEXT PRIMARY KEY,
+ `userPassword` TEXT,
+ `shadowLastChange` INTEGER,
+ `shadowMin` INTEGER,
+ `shadowMax` INTEGER,
+ `shadowWarning` INTEGER,
+ `shadowInactive` INTEGER,
+ `shadowExpire` INTEGER,
+ `shadowFlag` INTEGER,
+ `mtime` TIMESTAMP NOT NULL );
+ '''
class ShadowRequest(common.Request):
diff --git a/tests/Makefile.am b/tests/Makefile.am
index 184b79e..3c6bfce 100644
--- a/tests/Makefile.am
+++ b/tests/Makefile.am
@@ -19,8 +19,10 @@
# 02110-1301 USA
TESTS = test_dict test_set test_tio test_expr test_getpeercred test_cfg \
- test_myldap.sh test_common test_nsscmds.sh test_pamcmds.sh \
- test_pycompile.sh
+ test_myldap.sh test_common test_nsscmds.sh test_pamcmds.sh
+if HAVE_PYTHON
+TESTS += test_pycompile.sh test_pynslcd_cache.py
+endif
AM_TESTS_ENVIRONMENT = PYTHON='@PYTHON@'; export PYTHON;
@@ -30,7 +32,7 @@ check_PROGRAMS = test_dict test_set test_tio test_expr test_getpeercred \
EXTRA_DIST = nslcd-test.conf test_myldap.sh test_nsscmds.sh test_pamcmds.sh \
test_pycompile.sh in_testenv.sh test_pamcmds.expect \
- usernames.txt
+ usernames.txt test_pynslcd_cache.py
CLEANFILES = $(EXTRA_PROGRAMS)
diff --git a/tests/test_pynslcd_cache.py b/tests/test_pynslcd_cache.py
new file mode 100755
index 0000000..5c15b01
--- /dev/null
+++ b/tests/test_pynslcd_cache.py
@@ -0,0 +1,459 @@
+#!/usr/bin/env python
+
+# test_pynslcd_cache.py - tests for the pynslcd caching functionality
+#
+# Copyright (C) 2013 Arthur de Jong
+#
+# This library is free software; you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public
+# License as published by the Free Software Foundation; either
+# version 2.1 of the License, or (at your option) any later version.
+#
+# This library is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public
+# License along with this library; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
+# 02110-1301 USA
+
+import os
+import os.path
+import sys
+import unittest
+
+# fix the Python path
+sys.path.insert(1, os.path.abspath(os.path.join(sys.path[0], '..', 'pynslcd')))
+sys.path.insert(2, os.path.abspath(os.path.join('..', 'pynslcd')))
+
+
+# TODO: think about case-sesitivity of cache searches (have tests for that)
+
+
+class TestAlias(unittest.TestCase):
+
+ def setUp(self):
+ import alias
+ cache = alias.Cache()
+ cache.store('alias1', ['member1', 'member2'])
+ cache.store('alias2', ['member1', 'member3'])
+ cache.store('alias3', [])
+ self.cache = cache
+
+ def test_by_name(self):
+ self.assertItemsEqual(self.cache.retrieve(dict(cn='alias1')), [
+ ['alias1', ['member1', 'member2']],
+ ])
+
+ def test_by_member(self):
+ self.assertItemsEqual(self.cache.retrieve(dict(rfc822MailMember='member1')), [
+ ['alias1', ['member1', 'member2']],
+ ['alias2', ['member1', 'member3']],
+ ])
+
+ def test_all(self):
+ self.assertItemsEqual(self.cache.retrieve({}), [
+ ['alias1', ['member1', 'member2']],
+ ['alias2', ['member1', 'member3']],
+ ['alias3', []],
+ ])
+
+
+class TestEther(unittest.TestCase):
+
+ def setUp(self):
+ import ether
+ cache = ether.Cache()
+ cache.store('name1', '0:18:8a:54:1a:11')
+ cache.store('name2', '0:18:8a:54:1a:22')
+ self.cache = cache
+
+ def test_by_name(self):
+ self.assertItemsEqual(self.cache.retrieve(dict(cn='name1')), [
+ ['name1', '0:18:8a:54:1a:11'],
+ ])
+ self.assertItemsEqual(self.cache.retrieve(dict(cn='name2')), [
+ ['name2', '0:18:8a:54:1a:22'],
+ ])
+
+ def test_by_ether(self):
+ # ideally we should also support alternate representations
+ self.assertItemsEqual(self.cache.retrieve(dict(macAddress='0:18:8a:54:1a:22')), [
+ ['name2', '0:18:8a:54:1a:22'],
+ ])
+
+ def test_all(self):
+ self.assertItemsEqual(self.cache.retrieve({}), [
+ ['name1', '0:18:8a:54:1a:11'],
+ ['name2', '0:18:8a:54:1a:22'],
+ ])
+
+
+class TestGroup(unittest.TestCase):
+
+ def setUp(self):
+ import group
+ cache = group.Cache()
+ cache.store('group1', 'pass1', 10, ['user1', 'user2'])
+ cache.store('group2', 'pass2', 20, ['user1', 'user2', 'user3'])
+ cache.store('group3', 'pass3', 30, [])
+ cache.store('group4', 'pass4', 40, ['user2', ])
+ self.cache = cache
+
+ def test_by_name(self):
+ self.assertItemsEqual(self.cache.retrieve(dict(cn='group1')), [
+ ['group1', 'pass1', 10, ['user1', 'user2']],
+ ])
+ self.assertItemsEqual(self.cache.retrieve(dict(cn='group3')), [
+ ['group3', 'pass3', 30, []],
+ ])
+
+ def test_by_gid(self):
+ self.assertItemsEqual(self.cache.retrieve(dict(gidNumber=10)), [
+ ['group1', 'pass1', 10, ['user1', 'user2']],
+ ])
+ self.assertItemsEqual(self.cache.retrieve(dict(gidNumber=40)), [
+ ['group4', 'pass4', 40, ['user2']],
+ ])
+
+ def test_all(self):
+ self.assertItemsEqual(self.cache.retrieve({}), [
+ ['group1', 'pass1', 10, ['user1', 'user2']],
+ ['group2', 'pass2', 20, ['user1', 'user2', 'user3']],
+ ['group3', 'pass3', 30, []],
+ ['group4', 'pass4', 40, ['user2']],
+ ])
+
+ def test_bymember(self):
+ self.assertItemsEqual(self.cache.retrieve(dict(memberUid='user1')), [
+ ['group1', 'pass1', 10, ['user1', 'user2']],
+ ['group2', 'pass2', 20, ['user1', 'user2', 'user3']],
+ ])
+ self.assertItemsEqual(self.cache.retrieve(dict(memberUid='user2')), [
+ ['group1', 'pass1', 10, ['user1', 'user2']],
+ ['group2', 'pass2', 20, ['user1', 'user2', 'user3']],
+ ['group4', 'pass4', 40, ['user2']],
+ ])
+ self.assertItemsEqual(self.cache.retrieve(dict(memberUid='user3')), [
+ ['group2', 'pass2', 20, ['user1', 'user2', 'user3']],
+ ])
+
+
+class TestHost(unittest.TestCase):
+
+ def setUp(self):
+ import host
+ cache = host.Cache()
+ cache.store('hostname1', [], ['127.0.0.1', ])
+ cache.store('hostname2', ['alias1', 'alias2'], ['127.0.0.2', '127.0.0.3'])
+ self.cache = cache
+
+ def test_by_name(self):
+ self.assertItemsEqual(self.cache.retrieve(dict(cn='hostname1')), [
+ ['hostname1', [], ['127.0.0.1']],
+ ])
+ self.assertItemsEqual(self.cache.retrieve(dict(cn='hostname2')), [
+ ['hostname2', ['alias1', 'alias2'], ['127.0.0.2', '127.0.0.3']],
+ ])
+
+ def test_by_alias(self):
+ self.assertItemsEqual(self.cache.retrieve(dict(cn='alias1')), [
+ ['hostname2', ['alias1', 'alias2'], ['127.0.0.2', '127.0.0.3']],
+ ])
+ self.assertItemsEqual(self.cache.retrieve(dict(cn='alias2')), [
+ ['hostname2', ['alias1', 'alias2'], ['127.0.0.2', '127.0.0.3']],
+ ])
+
+ def test_by_address(self):
+ self.assertItemsEqual(self.cache.retrieve(dict(ipHostNumber='127.0.0.3')), [
+ ['hostname2', ['alias1', 'alias2'], ['127.0.0.2', '127.0.0.3']],
+ ])
+
+
+class TestNetgroup(unittest.TestCase):
+
+ def setUp(self):
+ import netgroup
+ cache = netgroup.Cache()
+ cache.store('netgroup1', ['(host1, user1,)', '(host1, user2,)', '(host2, user1,)'], ['netgroup2', ])
+ cache.store('netgroup2', ['(host3, user1,)', '(host3, user3,)'], [])
+ self.cache = cache
+
+ def test_by_name(self):
+ self.assertItemsEqual(self.cache.retrieve(dict(cn='netgroup1')), [
+ ['netgroup1', ['(host1, user1,)', '(host1, user2,)', '(host2, user1,)'], ['netgroup2', ]],
+ ])
+ self.assertItemsEqual(self.cache.retrieve(dict(cn='netgroup2')), [
+ ['netgroup2', ['(host3, user1,)', '(host3, user3,)'], []],
+ ])
+
+
+class TestNetwork(unittest.TestCase):
+
+ def setUp(self):
+ import network
+ cache = network.Cache()
+ cache.store('networkname1', [], ['127.0.0.1', ])
+ cache.store('networkname2', ['alias1', 'alias2'], ['127.0.0.2', '127.0.0.3'])
+ self.cache = cache
+
+ def test_by_name(self):
+ self.assertItemsEqual(self.cache.retrieve(dict(cn='networkname1')), [
+ ['networkname1', [], ['127.0.0.1']],
+ ])
+ self.assertItemsEqual(self.cache.retrieve(dict(cn='networkname2')), [
+ ['networkname2', ['alias1', 'alias2'], ['127.0.0.2', '127.0.0.3']],
+ ])
+
+ def test_by_alias(self):
+ self.assertItemsEqual(self.cache.retrieve(dict(cn='alias1')), [
+ ['networkname2', ['alias1', 'alias2'], ['127.0.0.2', '127.0.0.3']],
+ ])
+ self.assertItemsEqual(self.cache.retrieve(dict(cn='alias2')), [
+ ['networkname2', ['alias1', 'alias2'], ['127.0.0.2', '127.0.0.3']],
+ ])
+
+ def test_by_address(self):
+ self.assertItemsEqual(self.cache.retrieve(dict(ipNetworkNumber='127.0.0.3')), [
+ ['networkname2', ['alias1', 'alias2'], ['127.0.0.2', '127.0.0.3']],
+ ])
+
+
+class TestPasswd(unittest.TestCase):
+
+ def setUp(self):
+ import passwd
+ cache = passwd.Cache()
+ cache.store('name', 'passwd', 100, 200, 'gecos', '/home/user', '/bin/bash')
+ cache.store('name2', 'passwd2', 101, 202, 'gecos2', '/home/user2', '/bin/bash')
+ self.cache = cache
+
+ def test_by_name(self):
+ self.assertItemsEqual(self.cache.retrieve(dict(uid='name')), [
+ [u'name', u'passwd', 100, 200, u'gecos', u'/home/user', u'/bin/bash'],
+ ])
+
+ def test_by_unknown_name(self):
+ self.assertItemsEqual(self.cache.retrieve(dict(uid='notfound')), [])
+
+ def test_by_number(self):
+ self.assertItemsEqual(self.cache.retrieve(dict(uidNumber=100)), [
+ [u'name', u'passwd', 100, 200, u'gecos', u'/home/user', u'/bin/bash'],
+ ])
+ self.assertItemsEqual(self.cache.retrieve(dict(uidNumber=101)), [
+ ['name2', 'passwd2', 101, 202, 'gecos2', '/home/user2', '/bin/bash'],
+ ])
+
+ def test_all(self):
+ self.assertItemsEqual(self.cache.retrieve({}), [
+ [u'name', u'passwd', 100, 200, u'gecos', u'/home/user', u'/bin/bash'],
+ [u'name2', u'passwd2', 101, 202, u'gecos2', u'/home/user2', u'/bin/bash'],
+ ])
+
+
+class TestProtocol(unittest.TestCase):
+
+ def setUp(self):
+ import protocol
+ cache = protocol.Cache()
+ cache.store('protocol1', ['alias1', 'alias2'], 100)
+ cache.store('protocol2', ['alias3', ], 200)
+ cache.store('protocol3', [], 300)
+ self.cache = cache
+
+ def test_by_name(self):
+ self.assertItemsEqual(self.cache.retrieve(dict(cn='protocol1')), [
+ ['protocol1', ['alias1', 'alias2'], 100],
+ ])
+ self.assertItemsEqual(self.cache.retrieve(dict(cn='protocol2')), [
+ ['protocol2', ['alias3', ], 200],
+ ])
+ self.assertItemsEqual(self.cache.retrieve(dict(cn='protocol3')), [
+ ['protocol3', [], 300],
+ ])
+
+ def test_by_unknown_name(self):
+ self.assertItemsEqual(self.cache.retrieve(dict(cn='notfound')), [])
+
+ def test_by_number(self):
+ self.assertItemsEqual(self.cache.retrieve(dict(ipProtocolNumber=100)), [
+ ['protocol1', ['alias1', 'alias2'], 100],
+ ])
+ self.assertItemsEqual(self.cache.retrieve(dict(ipProtocolNumber=200)), [
+ ['protocol2', ['alias3', ], 200],
+ ])
+
+ def test_by_alias(self):
+ self.assertItemsEqual(self.cache.retrieve(dict(cn='alias1')), [
+ ['protocol1', ['alias1', 'alias2'], 100],
+ ])
+ self.assertItemsEqual(self.cache.retrieve(dict(cn='alias3')), [
+ ['protocol2', ['alias3', ], 200],
+ ])
+
+ def test_all(self):
+ self.assertItemsEqual(self.cache.retrieve({}), [
+ ['protocol1', ['alias1', 'alias2'], 100],
+ ['protocol2', ['alias3'], 200],
+ ['protocol3', [], 300],
+ ])
+
+
+class TestRpc(unittest.TestCase):
+
+ def setUp(self):
+ import rpc
+ cache = rpc.Cache()
+ cache.store('rpc1', ['alias1', 'alias2'], 100)
+ cache.store('rpc2', ['alias3', ], 200)
+ cache.store('rpc3', [], 300)
+ self.cache = cache
+
+ def test_by_name(self):
+ self.assertItemsEqual(self.cache.retrieve(dict(cn='rpc1')), [
+ ['rpc1', ['alias1', 'alias2'], 100],
+ ])
+ self.assertItemsEqual(self.cache.retrieve(dict(cn='rpc2')), [
+ ['rpc2', ['alias3', ], 200],
+ ])
+ self.assertItemsEqual(self.cache.retrieve(dict(cn='rpc3')), [
+ ['rpc3', [], 300],
+ ])
+
+ def test_by_unknown_name(self):
+ self.assertItemsEqual(self.cache.retrieve(dict(cn='notfound')), [])
+
+ def test_by_number(self):
+ self.assertItemsEqual(self.cache.retrieve(dict(oncRpcNumber=100)), [
+ ['rpc1', ['alias1', 'alias2'], 100],
+ ])
+ self.assertItemsEqual(self.cache.retrieve(dict(oncRpcNumber=200)), [
+ ['rpc2', ['alias3', ], 200],
+ ])
+
+ def test_by_alias(self):
+ self.assertItemsEqual(self.cache.retrieve(dict(cn='alias1')), [
+ ['rpc1', ['alias1', 'alias2'], 100],
+ ])
+ self.assertItemsEqual(self.cache.retrieve(dict(cn='alias3')), [
+ ['rpc2', ['alias3', ], 200],
+ ])
+
+ def test_all(self):
+ self.assertItemsEqual(self.cache.retrieve({}), [
+ ['rpc1', ['alias1', 'alias2'], 100],
+ ['rpc2', ['alias3'], 200],
+ ['rpc3', [], 300],
+ ])
+
+
+class TestService(unittest.TestCase):
+
+ def setUp(self):
+ import service
+ cache = service.Cache()
+ cache.store('service1', ['alias1', 'alias2'], 100, 'tcp')
+ cache.store('service1', ['alias1', 'alias2'], 100, 'udp')
+ cache.store('service2', ['alias3', ], 200, 'udp')
+ cache.store('service3', [], 300, 'udp')
+ self.cache = cache
+
+ def test_by_name(self):
+ self.assertItemsEqual(self.cache.retrieve(dict(cn='service1')), [
+ ['service1', ['alias1', 'alias2'], 100, 'tcp'],
+ ['service1', ['alias1', 'alias2'], 100, 'udp'],
+ ])
+ self.assertItemsEqual(self.cache.retrieve(dict(cn='service2')), [
+ ['service2', ['alias3', ], 200, 'udp'],
+ ])
+ self.assertItemsEqual(self.cache.retrieve(dict(cn='service3')), [
+ ['service3', [], 300, 'udp'],
+ ])
+
+ def test_by_name_and_protocol(self):
+ self.assertItemsEqual(self.cache.retrieve(dict(cn='service1', ipServiceProtocol='udp')), [
+ ['service1', ['alias1', 'alias2'], 100, 'udp'],
+ ])
+ self.assertItemsEqual(self.cache.retrieve(dict(cn='service1', ipServiceProtocol='tcp')), [
+ ['service1', ['alias1', 'alias2'], 100, 'tcp'],
+ ])
+ self.assertItemsEqual(self.cache.retrieve(dict(cn='service2', ipServiceProtocol='udp')), [
+ ['service2', ['alias3', ], 200, 'udp'],
+ ])
+ self.assertItemsEqual(self.cache.retrieve(dict(cn='service2', ipServiceProtocol='tcp')), [
+ ])
+
+ def test_by_unknown_name(self):
+ self.assertItemsEqual(self.cache.retrieve(dict(cn='notfound')), [])
+
+ def test_by_number(self):
+ self.assertItemsEqual(self.cache.retrieve(dict(ipServicePort=100)), [
+ ['service1', ['alias1', 'alias2'], 100, 'tcp'],
+ ['service1', ['alias1', 'alias2'], 100, 'udp'],
+ ])
+ self.assertItemsEqual(self.cache.retrieve(dict(ipServicePort=200)), [
+ ['service2', ['alias3', ], 200, 'udp'],
+ ])
+
+ def test_by_number_and_protocol(self):
+ self.assertItemsEqual(self.cache.retrieve(dict(ipServicePort=100, ipServiceProtocol='udp')), [
+ ['service1', ['alias1', 'alias2'], 100, 'udp'],
+ ])
+ self.assertItemsEqual(self.cache.retrieve(dict(ipServicePort=100, ipServiceProtocol='tcp')), [
+ ['service1', ['alias1', 'alias2'], 100, 'tcp'],
+ ])
+ self.assertItemsEqual(self.cache.retrieve(dict(ipServicePort=200, ipServiceProtocol='udp')), [
+ ['service2', ['alias3', ], 200, 'udp'],
+ ])
+ self.assertItemsEqual(self.cache.retrieve(dict(ipServicePort=200, ipServiceProtocol='tcp')), [
+ ])
+
+ def test_by_alias(self):
+ self.assertItemsEqual(self.cache.retrieve(dict(cn='alias1')), [
+ ['service1', ['alias1', 'alias2'], 100, 'udp'],
+ ['service1', ['alias1', 'alias2'], 100, 'tcp'],
+ ])
+ self.assertItemsEqual(self.cache.retrieve(dict(cn='alias3')), [
+ ['service2', ['alias3', ], 200, 'udp'],
+ ])
+
+ def test_all(self):
+ self.assertItemsEqual(self.cache.retrieve({}), [
+ ['service1', ['alias1', 'alias2'], 100, 'tcp'],
+ ['service1', ['alias1', 'alias2'], 100, 'udp'],
+ ['service2', ['alias3', ], 200, 'udp'],
+ ['service3', [], 300, 'udp'],
+ ])
+
+
+class Testshadow(unittest.TestCase):
+
+ def setUp(self):
+ import shadow
+ cache = shadow.Cache()
+ cache.store('name', 'passwd', 15639, 0, 7, -1, -1, -1, 0)
+ cache.store('name2', 'passwd2', 15639, 0, 7, -1, -1, -1, 0)
+ self.cache = cache
+
+ def test_by_name(self):
+ self.assertItemsEqual(self.cache.retrieve(dict(uid='name')), [
+ [u'name', u'passwd', 15639, 0, 7, -1, -1, -1, 0],
+ ])
+ self.assertItemsEqual(self.cache.retrieve(dict(uid='name2')), [
+ [u'name2', u'passwd2', 15639, 0, 7, -1, -1, -1, 0],
+ ])
+
+ def test_by_unknown_name(self):
+ self.assertItemsEqual(self.cache.retrieve(dict(uid='notfound')), [])
+
+ def test_all(self):
+ self.assertItemsEqual(self.cache.retrieve({}), [
+ [u'name', u'passwd', 15639, 0, 7, -1, -1, -1, 0],
+ [u'name2', u'passwd2', 15639, 0, 7, -1, -1, -1, 0],
+ ])
+
+
+if __name__ == '__main__':
+ unittest.main()