summaryrefslogtreecommitdiff
path: root/tools/notsd-fixup--includes
diff options
context:
space:
mode:
Diffstat (limited to 'tools/notsd-fixup--includes')
-rwxr-xr-xtools/notsd-fixup--includes305
1 files changed, 305 insertions, 0 deletions
diff --git a/tools/notsd-fixup--includes b/tools/notsd-fixup--includes
new file mode 100755
index 0000000000..a636c78be6
--- /dev/null
+++ b/tools/notsd-fixup--includes
@@ -0,0 +1,305 @@
+#!/usr/bin/env python3
+
+# If you are thinking "this file looks gross!", it is. It
+# started out as a set of Bash one-liners. Which got turned
+# into a script. Which grew somewhat organically. Not huge,
+# but given that it started as some one liners, that's not a
+# very pretty several hunderd lines. Then got fairly litterally
+# translated into this, for speed. So yes, it is gross.
+# Rewrites welcome; just don't introduce any behavioral changes
+# (easy since `tools/notsd-move` runs it on the entire repo and
+# puts the results in git history).
+
+import atexit
+import filecmp
+import json
+import os
+import re
+import shlex
+import subprocess
+import sys
+
+################################################################
+# Everything else in this program is just fluff and bookkeeping
+# around around calling classify().
+
+# Return a tuple of (class/group, path); which is a class that
+# the header path belongs to, and a normalized path for it.
+#
+# There are a fixed number of classes that it may put a header
+# in; in order of most-public to most-private:
+#
+# system
+# linux
+# public
+# protected
+# private
+def classify(expensive, current_file, path):
+ if re.fullmatch('.*/include(-staging)?/.*/.*', current_file):
+ lib = os.path.basename(os.path.dirname(current_file))
+ if path.startswith(lib+'/'):
+ path = re.sub('^'+lib+'/', path)
+ if path.startswith('linux/'):
+ return 'linux', path
+ elif expensive.exists(os.path.join(os.path.dirname(current_file), path)):
+ return 'private', path
+ elif not path.startswith('systemd/') and path != 'libudev.h' and expensive.cpp(path):
+ return 'system', path
+ else:
+ if path.endswith('-to-name.h') or path.endswith('-from-name.h'):
+ base = re.fullmatch('(.*)-(to|from)-name\.h', os.path.basename(path)).group(1)
+ d={
+ 'dns_type' : 'src/grp-resolve/systemd-resolved',
+ 'keyboard-keys' : 'src/grp-udev/libudev-core',
+ 'af' : 'src/libsystemd-basic/src',
+ 'arphrd' : 'src/libsystemd-basic/src',
+ 'cap' : 'src/libsystemd-basic/src',
+ 'errno' : 'src/libsystemd-basic/src',
+ 'audit_type' : 'src/libsystemd/src/sd-journal',
+ }
+ file = os.path.join(d[base], os.path.basename(path))
+ if current_file.startswith(d[base]):
+ return 'private', os.path.basename(file)
+ elif '/include/' in file:
+ return 'protected', re.sub('.*/include/', '', file)
+ else:
+ return 'protected', os.path.basename(file)
+ elif path in [ 'asm/sgidefs.h', 'dbus/dbus.h', 'efi.h', 'efilib.h', 'gio/gio.h', 'glib.h', 'libmount.h' ]:
+ return 'system', path
+ elif os.path.basename(path) == 'util.h':
+ if '/systemd-boot/' in current_file:
+ return 'private', 'util.h'
+ else:
+ return 'protected', 'systemd-basic/util.h'
+ else:
+ find = expensive.find(os.path.basename(path))
+ if len(find) == 1:
+ file = find[0]
+ if '/src/' in file:
+ if os.path.dirname(current_file) == os.path.dirname(file):
+ return 'private', os.path.basename(file)
+ else:
+ return 'protected', re.sub('.*/src/', '', file)
+ elif ('/libsystemd/include/' in file) or ('/libudev/include/' in file):
+ return 'public', re.sub('.*/include/', '', file)
+ elif '/include/' in file:
+ return 'protected', re.sub('.*/include/', '', file)
+ elif '/include-staging/' in file:
+ return 'protected', re.sub('.*/include-staging/', '', file)
+ else:
+ if os.path.dirname(current_file) == os.path.dirname(file):
+ return 'private', os.path.basename(file)
+ else:
+ return 'protected', os.path.basename(file)
+ else:
+ sys.exit('Cannot figure out: {0}'.format(path))
+
+################################################################
+# Cache expensive things
+
+class Cache:
+ def __init__(self, filename):
+ self.cache = {
+ 'find': None,
+ 'cpp': {}
+ }
+ self.dirty = True
+
+ if os.path.isfile(filename):
+ with open(filename) as file:
+ self.cache = json.load(file)
+ self.dirty = False
+
+ def save(self, filename):
+ if self.dirty:
+ with open(filename, 'w') as file:
+ json.dump(self.cache, file)
+
+ def real_cpp(path):
+ # `cpp -include "$path" <<<'' &>/dev/null`
+ print(' -> cpp({0})'.format(path), file=sys.stderr)
+ with subprocess.Popen(['cpp', '-include', path],
+ stdin=subprocess.PIPE,
+ stdout=subprocess.DEVNULL,
+ stderr=subprocess.DEVNULL) as proc:
+ proc.stdin.close()
+ return proc.wait() == 0
+
+ def real_find():
+ # This can probably be done with os.walk or something,
+ # but since it is only called once, it isn't a good
+ # place to start optimizing.
+ #
+ # `find src -name '*.h' \( -type l -printf 'l %p\n' -o -type f -printf 'f %p\n' \)`
+ print(' -> find()', file=sys.stderr)
+ ret = {}
+ with subprocess.Popen(['find', 'src', '-name', '*.h', '(', '-type', 'l', '-printf', 'l %p\n', '-o', '-type', 'f', '-printf', 'f %p\n', ')'],
+ stdin=subprocess.DEVNULL,
+ stdout=subprocess.PIPE, universal_newlines=True,
+ stderr=subprocess.DEVNULL) as proc:
+ for line in proc.stdout:
+ t, p = line.rstrip('\n').split(' ', 1)
+ ret[p]=t
+ return ret
+
+ def cpp(self, path):
+ # `cpp -include "$path" <<<'' &>/dev/null`
+ if path not in self.cache['cpp']:
+ self.cache['cpp'][path] = Cache.real_cpp(path)
+ self.dirty = True
+ return self.cache['cpp'][path]
+
+ def exists(self, path):
+ # `test -f "$path"`
+ if not self.cache['find']:
+ self.cache['find'] = Cache.real_find()
+ self.dirty = True
+ return path in self.cache['find']
+
+ def find(self, name):
+ # `find src -type f -name "$name"`
+ if not self.cache['find']:
+ self.cache['find'] = Cache.real_find()
+ self.dirty = True
+ return [p for p in self.cache['find'].keys() if self.cache['find'][p]=='f' and os.path.basename(p) == name]
+
+################################################################
+# Data structure for storing a chunk of `#include` lines.
+
+class IncludeSection:
+ def __init__(self):
+ self.trailing_nl = ''
+ self.system = []
+ self.linux = []
+ self.public = []
+ self.protected = []
+ self.typedef = []
+ self.typedef_last = True
+ self.private = []
+ def print(self, file=sys.stdout):
+ b=''
+ if len(self.system) > 0:
+ for line in sorted(set(self.system)):
+ print(line, file=file)
+ b='\n'
+ if len(self.linux) > 0:
+ print(b, end='', file=file)
+ for line in self.linux:
+ print(line, file=file)
+ b='\n'
+ if len(self.public) > 0:
+ print(b, end='', file=file)
+ for line in sorted(set(self.public)):
+ print(line, file=file)
+ b='\n'
+ if len(self.protected) > 0:
+ print(b, end='', file=file)
+ for line in sorted(set(self.protected)):
+ print(line, file=file)
+ b='\n'
+ if len(self.typedef) > 0 and not self.typedef_last:
+ print(b, end='', file=file)
+ for line in sorted(set(self.typedef)):
+ print(line, file=file)
+ b='\n'
+ if len(self.private) > 0:
+ print(b, end='', file=file)
+ for line in sorted(set(self.private)):
+ print(line, file=file)
+ b='\n'
+ if len(self.typedef) > 0 and self.typedef_last:
+ print(b, end='', file=file)
+ for line in self.typedef:
+ print(line, file=file)
+ print(self.trailing_nl, end='', file=file)
+ def add(self, group, path, extra):
+ if group == 'system':
+ self.system.append('#include <{0}>{1}'.format(path, extra))
+ elif group == 'linux':
+ self.linux.append('#include <{0}>{1}'.format(path, extra))
+ elif group == 'public':
+ self.public.append('#include <{0}>{1}'.format(path, extra))
+ elif group == 'protected':
+ self.protected.append('#include "{0}"{1}'.format(path, extra))
+ elif group == 'private':
+ if len(self.typedef) > 0:
+ self.typedef_last = False
+ self.private.append('#include "{0}"{1}'.format(path, extra))
+ else:
+ sys.exit('panic: unrecognized line class: {0}'.format(group))
+
+################################################################
+# The main program loop
+
+class Parser:
+ def __init__(self, cache, ifilename, ofilename):
+ self.cache = cache
+ self.ifilename = os.path.normpath(ifilename)
+ self.ofilename = ofilename
+
+ self.includes = None
+ self.phase = self.phase0
+
+ def phase0(self, line, ofile):
+ self.phase = self.phase0
+
+ if re.fullmatch('#include.*|typedef .*;', line):
+ self.includes = IncludeSection()
+ self.phase1(line, ofile)
+ else:
+ print(line, file=ofile)
+
+ def phase1(self, line, ofile):
+ self.phase = self.phase1
+
+ if line == '':
+ self.includes.trailing_nl += '\n'
+ elif line.startswith('#include'):
+ self.includes.trailing_nl = ''
+ match = re.fullmatch('^#include [<"]([^">]*)[">](.*)', line)
+ if match:
+ group, path = classify(self.cache, self.ifilename, match.group(1))
+ self.includes.add(group, path, match.group(2))
+ else:
+ sys.exit('panic: malformed #include line')
+ elif re.fullmatch('typedef .*;', line):
+ self.includes.trailing_nl = ''
+ self.includes.typedef.append(line)
+ else:
+ self.includes.print(file=ofile)
+ self.includes = None
+ self.phase0(line, ofile)
+
+ def run(self):
+ print(' => {0} {1}'.format(
+ shlex.quote(__file__),
+ shlex.quote(self.ifilename),
+ ), file=sys.stderr)
+ with open(self.ofilename, 'w') as ofile:
+ with open(self.ifilename) as ifile:
+ for line in ifile:
+ self.phase(line.rstrip('\n'), ofile)
+ if self.includes:
+ self.includes.print(file=ofile)
+
+def main(argv):
+ cache = Cache(__file__+'.cache')
+ tmpfilename = ''
+ def cleanup():
+ if tmpfilename != '':
+ try:
+ os.unlink(tmpfilename)
+ except FileNotFoundError:
+ pass
+ atexit.register(cleanup)
+ for filename in argv[1:]:
+ tmpfilename = os.path.join(os.path.dirname(filename), '.tmp.'+os.path.basename(filename)+'.tmp')
+ Parser(cache, filename, tmpfilename).run()
+ if not filecmp.cmp(filename, tmpfilename):
+ os.rename(tmpfilename, filename)
+ cleanup()
+ tmpfilename = ''
+ cache.save(__file__+'.cache')
+
+if __name__ == '__main__':
+ main(sys.argv)