1+ from __future__ import annotations
2+
13import abc
24from contextlib import suppress
5+ from typing import Any , Callable
36
47import cloudpickle
58
69from adaptive .utils import _RequireAttrsABCMeta , load , save
710
811
9- def uses_nth_neighbors (n : int ):
12+ def uses_nth_neighbors (n : int ) -> Callable :
1013 """Decorator to specify how many neighboring intervals the loss function uses.
1114
1215 Wraps loss functions to indicate that they expect intervals together
@@ -82,10 +85,15 @@ class BaseLearner(metaclass=_RequireAttrsABCMeta):
8285 """
8386
8487 data : dict
85- npoints : int
8688 pending_points : set
89+ function : Callable
90+
91+ @property
92+ @abc .abstractmethod
93+ def npoints (self ) -> int :
94+ """Number of learned points."""
8795
88- def tell (self , x , y ):
96+ def tell (self , x : Any , y ) -> None :
8997 """Tell the learner about a single value.
9098
9199 Parameters
@@ -95,7 +103,7 @@ def tell(self, x, y):
95103 """
96104 self .tell_many ([x ], [y ])
97105
98- def tell_many (self , xs , ys ) :
106+ def tell_many (self , xs : Any , ys : Any ) -> None :
99107 """Tell the learner about some values.
100108
101109 Parameters
@@ -116,7 +124,7 @@ def remove_unfinished(self):
116124 """Remove uncomputed data from the learner."""
117125
118126 @abc .abstractmethod
119- def loss (self , real = True ):
127+ def loss (self , real : bool = True ) -> float :
120128 """Return the loss for the current state of the learner.
121129
122130 Parameters
@@ -128,7 +136,7 @@ def loss(self, real=True):
128136 """
129137
130138 @abc .abstractmethod
131- def ask (self , n , tell_pending = True ):
139+ def ask (self , n : int , tell_pending : bool = True ):
132140 """Choose the next 'n' points to evaluate.
133141
134142 Parameters
@@ -146,7 +154,7 @@ def _get_data(self):
146154 pass
147155
148156 @abc .abstractmethod
149- def _set_data (self ):
157+ def _set_data (self , data : Any ):
150158 pass
151159
152160 @abc .abstractmethod
@@ -164,7 +172,7 @@ def copy_from(self, other):
164172 """
165173 self ._set_data (other ._get_data ())
166174
167- def save (self , fname , compress = True ):
175+ def save (self , fname : str , compress : bool = True ) -> None :
168176 """Save the data of the learner into a pickle file.
169177
170178 Parameters
@@ -178,7 +186,7 @@ def save(self, fname, compress=True):
178186 data = self ._get_data ()
179187 save (fname , data , compress )
180188
181- def load (self , fname , compress = True ):
189+ def load (self , fname : str , compress : bool = True ) -> None :
182190 """Load the data of a learner from a pickle file.
183191
184192 Parameters
@@ -193,8 +201,8 @@ def load(self, fname, compress=True):
193201 data = load (fname , compress )
194202 self ._set_data (data )
195203
196- def __getstate__ (self ):
204+ def __getstate__ (self ) -> dict [ str , Any ] :
197205 return cloudpickle .dumps (self .__dict__ )
198206
199- def __setstate__ (self , state ) :
207+ def __setstate__ (self , state : dict [ str , Any ]) -> None :
200208 self .__dict__ = cloudpickle .loads (state )
0 commit comments