summaryrefslogtreecommitdiff
path: root/tools/notsd-fixup--includes
diff options
context:
space:
mode:
Diffstat (limited to 'tools/notsd-fixup--includes')
-rwxr-xr-xtools/notsd-fixup--includes46
1 files changed, 31 insertions, 15 deletions
diff --git a/tools/notsd-fixup--includes b/tools/notsd-fixup--includes
index 196fd488a3..d16a3e6500 100755
--- a/tools/notsd-fixup--includes
+++ b/tools/notsd-fixup--includes
@@ -10,6 +10,8 @@
# (easy since `tools/notsd-move` runs it on the entire repo and
# puts the results in git history).
+import atexit
+import filecmp
import json
import os
import re
@@ -225,7 +227,7 @@ class IncludeSection:
################################################################
# The main program loop
-def phase0(cache, filename, line):
+def phase0(cache, filename, line, file=sys.stdout):
global phase
phase = phase0
@@ -234,9 +236,9 @@ def phase0(cache, filename, line):
includes = IncludeSection()
phase1(cache, filename, line)
else:
- print(line)
+ print(line, file=file)
-def phase1(cache, filename, line):
+def phase1(cache, filename, line, file=sys.stdout):
global phase, includes
phase = phase1
@@ -254,25 +256,39 @@ def phase1(cache, filename, line):
includes.trailing_nl = ''
includes.typedef.append(line)
else:
- includes.print()
+ includes.print(file=file)
includes = None
- phase0(cache, filename, line)
+ phase0(cache, filename, line, file=file)
includes = None
phase = phase0
def main(argv):
- filename = argv[1]
- print(' => {0} {1}'.format(
- shlex.quote(__file__),
- shlex.quote(filename),
- ), file=sys.stderr)
cache = Cache(__file__+'.cache')
- with open(filename) as f:
- for line in f:
- phase(cache, filename, line.rstrip('\n'))
- if includes:
- includes.print()
+ tmpfilename = ''
+ def cleanup():
+ if tmpfilename != '':
+ try:
+ os.unlink(tmpfilename)
+ except FileNotFoundError:
+ pass
+ atexit.register(cleanup)
+ for filename in argv[1:]:
+ tmpfilename = os.path.join(os.path.dirname(filename), '.tmp.'+os.path.basename(filename)+'.tmp')
+ print(' => {0} {1}'.format(
+ shlex.quote(__file__),
+ shlex.quote(filename),
+ ), file=sys.stderr)
+ with open(tmpfilename, 'w') as tmpfile:
+ with open(filename) as f:
+ for line in f:
+ phase(cache, filename, line.rstrip('\n'), file=tmpfile)
+ if includes:
+ includes.print(file=tmpfile)
+ if not filecmp.cmp(filename, tmpfilename):
+ os.rename(tmpfilename, filename)
+ cleanup()
+ tmpfilename = ''
cache.save(__file__+'.cache')
if __name__ == '__main__':