Skip to content

Commit 7c1e277

Browse files
committed
Fix some issues in the cross-module-inlining complexity estimates.
1 parent d86b8d5 commit 7c1e277

File tree

2 files changed

+174
-135
lines changed

2 files changed

+174
-135
lines changed

typed_python/compiler/native_compiler/native_ast_to_llvm.py

Lines changed: 173 additions & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -28,136 +28,6 @@
2828
import os
2929

3030

31-
def computeFunctionComplexity(functionBody):
32-
if functionBody is None or isinstance(functionBody, str):
33-
return 0
34-
35-
if functionBody.matches.External:
36-
return 0
37-
38-
if functionBody.matches.Internal:
39-
return computeFunctionComplexity(functionBody.body)
40-
41-
if functionBody.matches.Comment:
42-
return computeFunctionComplexity(functionBody.expr)
43-
44-
if functionBody.matches.Load:
45-
return computeFunctionComplexity(functionBody.ptr)
46-
47-
if functionBody.matches.Store:
48-
return (
49-
computeFunctionComplexity(functionBody.ptr)
50-
+ computeFunctionComplexity(functionBody.val)
51-
)
52-
53-
if functionBody.matches.AtomicAdd:
54-
return (
55-
computeFunctionComplexity(functionBody.ptr)
56-
+ computeFunctionComplexity(functionBody.val)
57-
)
58-
59-
if functionBody.matches.Cast:
60-
return computeFunctionComplexity(functionBody.left)
61-
62-
if functionBody.matches.Binop:
63-
return (
64-
computeFunctionComplexity(functionBody.left)
65-
+ computeFunctionComplexity(functionBody.right)
66-
)
67-
68-
if functionBody.matches.Unaryop:
69-
return computeFunctionComplexity(functionBody.operand)
70-
71-
if functionBody.matches.StructElementByIndex:
72-
return computeFunctionComplexity(functionBody.left)
73-
74-
if functionBody.matches.ElementPtr:
75-
return computeFunctionComplexity(functionBody.left) + sum(
76-
computeFunctionComplexity(o) for o in functionBody.offsets
77-
)
78-
79-
if functionBody.matches.Call:
80-
return sum(
81-
computeFunctionComplexity(o) for o in functionBody.args
82-
)
83-
84-
if functionBody.matches.MakeStruct:
85-
return sum(
86-
computeFunctionComplexity(o[1]) for o in functionBody.args
87-
)
88-
89-
if functionBody.matches.Branch:
90-
return (
91-
computeFunctionComplexity(functionBody.cond)
92-
+ computeFunctionComplexity(functionBody.true)
93-
+ computeFunctionComplexity(functionBody.false)
94-
)
95-
96-
if functionBody.matches.Throw:
97-
return (
98-
computeFunctionComplexity(functionBody.expr)
99-
)
100-
101-
if functionBody.matches.TryCatch:
102-
return (
103-
computeFunctionComplexity(functionBody.expr)
104-
+ computeFunctionComplexity(functionBody.handler)
105-
)
106-
107-
if functionBody.matches.ExceptionPropagator:
108-
return (
109-
computeFunctionComplexity(functionBody.expr)
110-
+ computeFunctionComplexity(functionBody.handler)
111-
)
112-
113-
if functionBody.matches.While:
114-
return (
115-
computeFunctionComplexity(functionBody.cond)
116-
+ computeFunctionComplexity(functionBody.while_true)
117-
+ computeFunctionComplexity(functionBody.orelse)
118-
)
119-
120-
if functionBody.matches.Return:
121-
return (
122-
computeFunctionComplexity(functionBody.arg)
123-
)
124-
125-
if functionBody.matches.Let:
126-
return (
127-
computeFunctionComplexity(functionBody.val)
128-
+ computeFunctionComplexity(functionBody.within)
129-
)
130-
131-
if functionBody.matches.Finally:
132-
return (
133-
computeFunctionComplexity(functionBody.expr)
134-
+ sum(
135-
computeFunctionComplexity(o) for o in functionBody.teardowns
136-
)
137-
)
138-
139-
if functionBody.matches.Sequence:
140-
return sum(
141-
computeFunctionComplexity(o) for o in functionBody.vals
142-
)
143-
144-
if functionBody.matches.ApplyIntermediates:
145-
return (
146-
computeFunctionComplexity(functionBody.base)
147-
+ sum(
148-
computeFunctionComplexity(o) for o in functionBody.intermediates
149-
)
150-
)
151-
152-
# Teardown
153-
if functionBody.matches.ByTag or functionBody.matches.Always:
154-
return (
155-
computeFunctionComplexity(functionBody.expr)
156-
)
157-
158-
return 1
159-
160-
16131
class NativeAstToLlvmConverter:
16232
def __init__(self):
16333
object.__init__(self)
@@ -174,6 +44,170 @@ def __init__(self):
17444
self._printAllNativeCalls = os.getenv("TP_COMPILER_LOG_NATIVE_CALLS")
17545
self.verbose = False
17646

47+
def computeFunctionComplexity(self, functionBody, stack=()):
48+
def recurse(subFunc):
49+
return self.computeFunctionComplexity(subFunc, stack)
50+
51+
if functionBody is None or isinstance(functionBody, str):
52+
return 0
53+
54+
if functionBody.matches.External:
55+
return 0
56+
57+
if functionBody.matches.Internal:
58+
return recurse(functionBody.body)
59+
60+
if functionBody.matches.Comment:
61+
return recurse(functionBody.expr)
62+
63+
if functionBody.matches.Load:
64+
return recurse(functionBody.ptr) + 1
65+
66+
if functionBody.matches.Store:
67+
return (
68+
recurse(functionBody.ptr)
69+
+ recurse(functionBody.val)
70+
) + 1
71+
72+
if functionBody.matches.AtomicAdd:
73+
return (
74+
recurse(functionBody.ptr)
75+
+ recurse(functionBody.val)
76+
) + 1
77+
78+
if functionBody.matches.Cast:
79+
return recurse(functionBody.left) + 1
80+
81+
if functionBody.matches.Binop:
82+
return (
83+
recurse(functionBody.left)
84+
+ recurse(functionBody.right)
85+
) + 1
86+
87+
if functionBody.matches.Unaryop:
88+
return recurse(functionBody.operand) + 1
89+
90+
if functionBody.matches.StructElementByIndex:
91+
return recurse(functionBody.left) + 1
92+
93+
if functionBody.matches.ElementPtr:
94+
return recurse(functionBody.left) + sum(
95+
recurse(o) for o in functionBody.offsets
96+
) + 1
97+
98+
if functionBody.matches.Call:
99+
if functionBody.target.matches.Pointer:
100+
calleeComplexity = recurse(functionBody.target.expr)
101+
elif functionBody.target.target.external:
102+
calleeComplexity = 1
103+
else:
104+
if functionBody.target.target.name in stack:
105+
calleeComplexity = 1e6
106+
else:
107+
calleeComplexity = self.totalFunctionComplexity(functionBody.target.target.name, stack)
108+
109+
return sum(
110+
recurse(o) for o in functionBody.args
111+
) + calleeComplexity + 1
112+
113+
if functionBody.matches.MakeStruct:
114+
return sum(
115+
recurse(o[1]) for o in functionBody.args
116+
) + 1
117+
118+
if functionBody.matches.Branch:
119+
return (
120+
recurse(functionBody.cond)
121+
+ recurse(functionBody.true)
122+
+ recurse(functionBody.false)
123+
) + 1
124+
125+
if functionBody.matches.Throw:
126+
return (
127+
recurse(functionBody.expr)
128+
) + 1
129+
130+
if functionBody.matches.TryCatch:
131+
return (
132+
recurse(functionBody.expr)
133+
+ recurse(functionBody.handler)
134+
) + 1
135+
136+
if functionBody.matches.ExceptionPropagator:
137+
return (
138+
recurse(functionBody.expr)
139+
+ recurse(functionBody.handler)
140+
) + 1
141+
142+
if functionBody.matches.While:
143+
return (
144+
recurse(functionBody.cond)
145+
+ recurse(functionBody.while_true)
146+
+ recurse(functionBody.orelse)
147+
) + 1
148+
149+
if functionBody.matches.Return:
150+
return (
151+
recurse(functionBody.arg)
152+
) + 1
153+
154+
if functionBody.matches.Let:
155+
return (
156+
recurse(functionBody.val)
157+
+ recurse(functionBody.within)
158+
) + 1
159+
160+
if functionBody.matches.Finally:
161+
return (
162+
recurse(functionBody.expr)
163+
+ sum(
164+
recurse(o) for o in functionBody.teardowns
165+
)
166+
) + 1
167+
168+
if functionBody.matches.Sequence:
169+
return sum(
170+
recurse(o) for o in functionBody.vals
171+
) + 1
172+
173+
if functionBody.matches.ApplyIntermediates:
174+
return (
175+
recurse(functionBody.base)
176+
+ sum(
177+
recurse(o) for o in functionBody.intermediates
178+
)
179+
) + 1
180+
181+
if functionBody.matches.ActivatesTeardown:
182+
return 1
183+
184+
if functionBody.matches.GlobalVariable:
185+
return 1
186+
187+
if functionBody.matches.StackSlot:
188+
if hasattr(functionBody, 'expr'):
189+
return recurse(functionBody.expr)
190+
else:
191+
return 1
192+
193+
# ExpressionIntermediates
194+
if functionBody.matches.Effect:
195+
return recurse(functionBody.expr)
196+
197+
if functionBody.matches.Terminal:
198+
return recurse(functionBody.expr)
199+
200+
if functionBody.matches.Simple:
201+
return recurse(functionBody.expr)
202+
203+
# Teardown
204+
if functionBody.matches.ByTag or functionBody.matches.Always:
205+
return (
206+
recurse(functionBody.expr)
207+
) + 1
208+
209+
return 1
210+
177211
def addExternallyProvidedFunctions(self, functionNameToDefinition):
178212
"""Provide type signatures for a set of external functions."""
179213

@@ -200,7 +234,7 @@ def addExternallyProvidedFunctions(self, functionNameToDefinition):
200234
self._functions_by_name[name].linkage = 'external'
201235
self._function_definitions[name] = function
202236

203-
def totalFunctionComplexity(self, name):
237+
def totalFunctionComplexity(self, name, stack=()):
204238
"""Return the total number of instructions contained in a function.
205239
206240
The function must already have been defined in a prior parss. We use this
@@ -209,8 +243,9 @@ def totalFunctionComplexity(self, name):
209243
if name in self._function_complexity:
210244
return self._function_complexity[name]
211245

212-
self._function_complexity[name] = computeFunctionComplexity(
213-
self._function_definitions[name].body
246+
self._function_complexity[name] = self.computeFunctionComplexity(
247+
self._function_definitions[name].body,
248+
stack + (name,)
214249
)
215250

216251
return self._function_complexity[name]
@@ -295,6 +330,7 @@ def add_functions(self, names_to_definitions):
295330
globalDefinitions = {}
296331
globalDefinitionsLlvmValues = {}
297332
extraDefinitions = {}
333+
totalInlines = 0
298334

299335
while names_to_definitions:
300336
for name in sorted(names_to_definitions):
@@ -357,7 +393,10 @@ def add_functions(self, names_to_definitions):
357393
# each function listed here was deemed 'inlinable', which means that we
358394
# want to repeat its definition in this particular module.
359395
for name in self._inlineRequests:
360-
names_to_definitions[name] = self._function_definitions[name]
396+
if name not in functionsDefinedHere:
397+
names_to_definitions[name] = self._function_definitions[name]
398+
totalInlines += 1
399+
361400
self._inlineRequests.clear()
362401

363402
# define a function that accepts a pointer and fills it out with a table of pointer values

typed_python/compiler/native_compiler/native_ast_to_llvm_function_converter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
pointer_size = 8
3333

3434

35-
CROSS_MODULE_INLINE_COMPLEXITY = 8
35+
CROSS_MODULE_INLINE_COMPLEXITY = 40
3636

3737

3838
_type_to_identified_type_cache = {}

0 commit comments

Comments
 (0)