diff --git a/src/qonnx/core/execute_custom_node.py b/src/qonnx/core/execute_custom_node.py index 7acf3792..cd6bb605 100644 --- a/src/qonnx/core/execute_custom_node.py +++ b/src/qonnx/core/execute_custom_node.py @@ -27,10 +27,9 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import qonnx.custom_op.registry as registry -from qonnx.util.basic import get_preferred_onnx_opset -def execute_custom_node(node, context, graph, onnx_opset_version=get_preferred_onnx_opset()): +def execute_custom_node(node, context, graph, onnx_opset_version): """Call custom implementation to execute a single custom node. Input/output provided via context.""" op_type = node.op_type diff --git a/src/qonnx/core/onnx_exec.py b/src/qonnx/core/onnx_exec.py index 3a686f7e..893504de 100644 --- a/src/qonnx/core/onnx_exec.py +++ b/src/qonnx/core/onnx_exec.py @@ -36,15 +36,10 @@ import qonnx.analysis.topology as ta import qonnx.core.execute_custom_node as ex_cu_node from qonnx.custom_op.registry import is_custom_op -from qonnx.util.basic import ( - get_preferred_onnx_opset, - get_sanitize_quant_tensors, - qonnx_make_model, - sanitize_quant_values, -) +from qonnx.util.basic import get_preferred_qonnx_opset, get_sanitize_quant_tensors, qonnx_make_model, sanitize_quant_values -def execute_node(node, context, graph, return_full_exec_context=False, opset_version=get_preferred_onnx_opset()): +def execute_node(node, context, graph, opset_version, return_full_exec_context=False): """Executes a single node by using onnxruntime or with a custom function. Input/output provided via context.""" @@ -158,7 +153,7 @@ def execute_onnx(model, input_dict, return_full_exec_context=False, start_node=N model_exec_mode = model.get_metadata_prop("exec_mode") if (model_exec_mode is None) or (model_exec_mode == ""): # extract opset version for node-by-node execution - opset_version = model.model.opset_import[0].version + opset_imports = model.get_opset_imports() # execute the model node by node # we can simply walk down the list since the ONNX spec guarantees that it is # topologically sorted @@ -176,7 +171,11 @@ def execute_onnx(model, input_dict, return_full_exec_context=False, start_node=N if get_sanitize_quant_tensors() != 0: # round input values to match quantization annotation execution_context = sanitize_quant_values(model, node.input, execution_context) - execute_node(node, execution_context, graph, return_full_exec_context, opset_version) + if node.domain in opset_imports: + opset_version = opset_imports[node.domain] + else: + opset_version = get_preferred_qonnx_opset() + execute_node(node, execution_context, graph, opset_version, return_full_exec_context) if get_sanitize_quant_tensors() != 0: # round output values to quantization annotation execution_context = sanitize_quant_values(model, node.output, execution_context) diff --git a/src/qonnx/custom_op/base.py b/src/qonnx/custom_op/base.py index 775d9f95..383e453d 100644 --- a/src/qonnx/custom_op/base.py +++ b/src/qonnx/custom_op/base.py @@ -30,15 +30,35 @@ import onnx.numpy_helper as np_helper from abc import ABC, abstractmethod -from qonnx.util.basic import get_by_name, get_preferred_onnx_opset +from qonnx.util.basic import get_by_name, get_preferred_qonnx_opset class CustomOp(ABC): """CustomOp class all custom op nodes are based on. Contains different functions every custom node should have. Some as abstract methods, these have to be - filled when writing a new custom op node.""" + filled when writing a new custom op node. - def __init__(self, onnx_node, onnx_opset_version=get_preferred_onnx_opset()): + Opset Version Support: + CustomOp classes use "since version" semantics matching ONNX operators. + Version is determined by the class name using _vN suffix convention: + + - No suffix (e.g., IntQuant): Version 1 (default) + - _vN suffix (e.g., IntQuant_v2): Version N + + The registry automatically selects the highest version <= requested opset. + + Example: + class IntQuant(CustomOp): + pass # Version 1 (no suffix) + + class IntQuant_v2(CustomOp): + pass # Version 2, covers opset v2-v3 (if no v3 exists) + + class IntQuant_v4(CustomOp): + pass # Version 4, covers opset v4+ + """ + + def __init__(self, onnx_node, onnx_opset_version=get_preferred_qonnx_opset()): super().__init__() self.onnx_node = onnx_node self.onnx_opset_version = onnx_opset_version diff --git a/src/qonnx/custom_op/channels_last/__init__.py b/src/qonnx/custom_op/channels_last/__init__.py index 77a048e7..390b0030 100644 --- a/src/qonnx/custom_op/channels_last/__init__.py +++ b/src/qonnx/custom_op/channels_last/__init__.py @@ -1,11 +1,17 @@ # Importing registers CustomOps in qonnx.custom_op.channels_last domain -from qonnx.custom_op.channels_last.batch_normalization import BatchNormalization -from qonnx.custom_op.channels_last.conv import Conv -from qonnx.custom_op.channels_last.max_pool import MaxPool +from qonnx.custom_op.channels_last.batch_normalization import ( + BatchNormalization_v1, + BatchNormalization_v9, + BatchNormalization_v14, +) +from qonnx.custom_op.channels_last.conv import Conv_v1 +from qonnx.custom_op.channels_last.max_pool import MaxPool_v1, MaxPool_v10 -# Legacy dictionary for backward compatibility -custom_op = { - "Conv": Conv, - "MaxPool": MaxPool, - "BatchNormalization": BatchNormalization, -} \ No newline at end of file +__all__ = [ + "Conv_v1", + "MaxPool_v1", + "MaxPool_v10", + "BatchNormalization_v1", + "BatchNormalization_v9", + "BatchNormalization_v14", +] diff --git a/src/qonnx/custom_op/channels_last/batch_normalization.py b/src/qonnx/custom_op/channels_last/batch_normalization.py index f3b3f872..a49591f4 100644 --- a/src/qonnx/custom_op/channels_last/batch_normalization.py +++ b/src/qonnx/custom_op/channels_last/batch_normalization.py @@ -32,7 +32,7 @@ from qonnx.custom_op.channels_last.base_wrapped_op import ChannelsLastWrappedOp -class BatchNormalization(ChannelsLastWrappedOp): +class BatchNormalization_v1(ChannelsLastWrappedOp): def get_nodeattr_types(self): """Returns a dict of permitted attributes for node, where: ret_dict[attribute_name] = (dtype, require, default_value, ) @@ -133,3 +133,13 @@ def verify_node(self): ) return info_messages + + +class BatchNormalization_v9(BatchNormalization_v1): + # no relevant changes for channels-last wrapper + pass + + +class BatchNormalization_v14(BatchNormalization_v9): + # no relevant changes for channels-last wrapper + pass diff --git a/src/qonnx/custom_op/channels_last/conv.py b/src/qonnx/custom_op/channels_last/conv.py index b0ff237b..9d74dd59 100644 --- a/src/qonnx/custom_op/channels_last/conv.py +++ b/src/qonnx/custom_op/channels_last/conv.py @@ -33,7 +33,7 @@ from qonnx.custom_op.general.im2col import compute_conv_output_dim -class Conv(ChannelsLastWrappedOp): +class Conv_v1(ChannelsLastWrappedOp): def get_nodeattr_types(self): """Returns a dict of permitted attributes for node, where: ret_dict[attribute_name] = (dtype, require, default_value, ) diff --git a/src/qonnx/custom_op/channels_last/max_pool.py b/src/qonnx/custom_op/channels_last/max_pool.py index 383f3008..21a39d1d 100644 --- a/src/qonnx/custom_op/channels_last/max_pool.py +++ b/src/qonnx/custom_op/channels_last/max_pool.py @@ -33,7 +33,7 @@ from qonnx.custom_op.general.maxpoolnhwc import compute_pool_output_dim -class MaxPool(ChannelsLastWrappedOp): +class MaxPool_v1(ChannelsLastWrappedOp): def get_nodeattr_types(self): """Returns a dict of permitted attributes for node, where: ret_dict[attribute_name] = (dtype, require, default_value, ) @@ -171,3 +171,8 @@ def verify_node(self): ) return info_messages + + +class MaxPool_v10(MaxPool_v1): + # no relevant changes for channels-last wrapper + pass diff --git a/src/qonnx/custom_op/general/__init__.py b/src/qonnx/custom_op/general/__init__.py index 2f3896de..becc30b3 100644 --- a/src/qonnx/custom_op/general/__init__.py +++ b/src/qonnx/custom_op/general/__init__.py @@ -35,23 +35,22 @@ from qonnx.custom_op.general.intquant import IntQuant from qonnx.custom_op.general.maxpoolnhwc import MaxPoolNHWC from qonnx.custom_op.general.multithreshold import MultiThreshold -from qonnx.custom_op.general.quantavgpool2d import QuantAvgPool2d from qonnx.custom_op.general.quant import Quant +from qonnx.custom_op.general.quantavgpool2d import QuantAvgPool2d from qonnx.custom_op.general.trunc import Trunc from qonnx.custom_op.general.xnorpopcount import XnorPopcountMatMul -# Legacy dictionary for backward compatibility -custom_op = { - "DebugMarker": DebugMarker, - "QuantAvgPool2d": QuantAvgPool2d, - "MaxPoolNHWC": MaxPoolNHWC, - "GenericPartition": GenericPartition, - "MultiThreshold": MultiThreshold, - "XnorPopcountMatMul": XnorPopcountMatMul, - "Im2Col": Im2Col, - "IntQuant": IntQuant, - "Quant": IntQuant, # Alias - "Trunc": Trunc, - "BipolarQuant": BipolarQuant, - "FloatQuant": FloatQuant, -} \ No newline at end of file +__all__ = [ + "BipolarQuant", + "DebugMarker", + "FloatQuant", + "GenericPartition", + "Im2Col", + "IntQuant", + "MaxPoolNHWC", + "MultiThreshold", + "Quant", + "QuantAvgPool2d", + "Trunc", + "XnorPopcountMatMul", +] diff --git a/src/qonnx/custom_op/general/maxpoolnhwc.py b/src/qonnx/custom_op/general/maxpoolnhwc.py index eb964fc4..93c6012d 100644 --- a/src/qonnx/custom_op/general/maxpoolnhwc.py +++ b/src/qonnx/custom_op/general/maxpoolnhwc.py @@ -97,10 +97,7 @@ def execute_node(self, context, graph): inp_vi = helper.make_tensor_value_info(inp_name, TensorProto.FLOAT, inp.shape) out_vi = helper.make_tensor_value_info(out_name, TensorProto.FLOAT, dummy_out.shape) tmp_graph = helper.make_graph(nodes=[node], name="tmp_graph", inputs=[inp_vi], outputs=[out_vi]) - opset_version = self.onnx_opset_version - opset_imports = [helper.make_opsetid("", opset_version)] - onnx_kwargs = {"opset_imports": opset_imports} - tmp_model = qonnx_make_model(tmp_graph, producer_name="finn", **onnx_kwargs) + tmp_model = qonnx_make_model(tmp_graph, producer_name="finn") tmp_model = ModelWrapper(tmp_model) new_ctx = {inp_name: inp} from qonnx.core.onnx_exec import execute_onnx diff --git a/src/qonnx/custom_op/general/quantavgpool2d.py b/src/qonnx/custom_op/general/quantavgpool2d.py index c0e24071..00617dcf 100644 --- a/src/qonnx/custom_op/general/quantavgpool2d.py +++ b/src/qonnx/custom_op/general/quantavgpool2d.py @@ -33,7 +33,7 @@ from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp from qonnx.custom_op.general.maxpoolnhwc import compute_pool_output_dim -from qonnx.util.basic import qonnx_make_model +from qonnx.util.basic import get_preferred_onnx_opset, qonnx_make_model class QuantAvgPool2d(CustomOp): @@ -132,7 +132,7 @@ def execute_node(self, context, graph): outputs=[outp], ) - opset_version = self.onnx_opset_version + opset_version = get_preferred_onnx_opset() opset_imports = [helper.make_opsetid("", opset_version)] onnx_kwargs = {"opset_imports": opset_imports} model_avgpool = qonnx_make_model(graph_avgpool, **onnx_kwargs) diff --git a/src/qonnx/custom_op/registry.py b/src/qonnx/custom_op/registry.py index b116f9e1..e9f6f0e7 100644 --- a/src/qonnx/custom_op/registry.py +++ b/src/qonnx/custom_op/registry.py @@ -28,14 +28,15 @@ import importlib import inspect +import warnings from threading import RLock from typing import Dict, List, Optional, Tuple, Type from qonnx.custom_op.base import CustomOp -from qonnx.util.basic import get_preferred_onnx_opset -# Registry keyed by original ONNX domain: (domain, op_type) -> CustomOp class -_OP_REGISTRY: Dict[Tuple[str, str], Type[CustomOp]] = {} +# Nested registry for O(1) lookups: domain -> op_type -> version -> CustomOp class +# Uses "since version" semantics: version N covers opset N until a higher version exists +_OP_REGISTRY: Dict[str, Dict[str, Dict[int, Type[CustomOp]]]] = {} _REGISTRY_LOCK = RLock() @@ -68,92 +69,335 @@ def resolve_domain(domain: str) -> str: return _DOMAIN_ALIASES.get(domain, domain) -def add_op_to_domain(domain: str, op_class: Type[CustomOp]) -> None: - """Register a custom op directly to a domain at runtime. +def _get_op_type_for_class(cls: Type[CustomOp]) -> str: + """Extract the op_type from a CustomOp class name, stripping _vN suffix if present. - The op_type is automatically derived from the class name. - Useful for testing and experimentation. For production, define CustomOps - in the appropriate module file. + Args: + cls: CustomOp class + + Returns: + op_type string (e.g., "IntQuant_v2" -> "IntQuant") + """ + name = cls.__name__ + # Strip _vN suffix if present + if "_v" in name: + parts = name.split("_v") + if len(parts) == 2 and parts[1].isdigit(): + return parts[0] # IntQuant_v2 -> IntQuant + return name + + +def _get_op_version_for_class(cls: Type[CustomOp]) -> int: + """Extract version from a CustomOp class name. Args: - domain: ONNX domain name (e.g., "qonnx.custom_op.general") - op_class: CustomOp subclass + cls: CustomOp class - Example: - add_op_to_domain("qonnx.custom_op.general", MyTestOp) + Returns: + Opset version (defaults to 1 if no _vN suffix present) """ - if not issubclass(op_class, CustomOp): - raise ValueError(f"{op_class} must be a subclass of CustomOp") + name = cls.__name__ + if "_v" in name: + parts = name.rsplit("_v", 1) + if len(parts) == 2 and parts[1].isdigit(): + return int(parts[1]) + return 1 - op_type = op_class.__name__ - with _REGISTRY_LOCK: - _OP_REGISTRY[(domain, op_type)] = op_class +def _discover_from_custom_op_dict(module, op_type: str, domain: str) -> Dict[int, Type[CustomOp]]: + """Extract CustomOp versions from legacy custom_op dict (backward compatibility). + Supports the old registration pattern: + custom_op = dict() + custom_op["IntQuant"] = IntQuant + custom_op["IntQuant_v2"] = IntQuant_v2 -def _discover_custom_op(domain: str, op_type: str) -> bool: - """Discover and register a single custom op. + Args: + module: The imported module to check + op_type: The specific op type to discover + domain: The domain name (for warnings) + + Returns: + Dict mapping version -> CustomOp class + """ + versions = {} + + if not (hasattr(module, "custom_op") and isinstance(module.custom_op, dict)): + return versions + + # Iterate all dict entries, filter by op_type + for key, obj in module.custom_op.items(): + # Check if this dict key matches the requested op_type + base_name = key.split("_v")[0] if "_v" in key else key + if base_name != op_type: + continue + + if not (inspect.isclass(obj) and issubclass(obj, CustomOp) and obj is not CustomOp): + continue + + try: + version = _get_op_version_for_class(obj) + except ValueError as e: + warnings.warn(str(e)) + continue + + if version in versions: + warnings.warn( + f"Multiple classes found for {domain}.{op_type} version {version}: " + f"{versions[version].__name__} and {obj.__name__}. Using {obj.__name__}." + ) + versions[version] = obj + + return versions + + +def _discover_custom_op_versions(domain: str, op_type: str) -> Dict[int, Type[CustomOp]]: + """Discover all versions of a SPECIFIC custom op without loading entire domain. + + Uses __all__ when available for efficient filtering, otherwise falls back to + full module inspection. Only loads classes matching the requested op_type. Args: domain: The ONNX domain name op_type: The specific op type to discover Returns: - True if op was found and registered, False otherwise + Dict mapping version -> CustomOp class """ module_path = resolve_domain(domain) + versions = {} try: module = importlib.import_module(module_path) except ModuleNotFoundError: - return False + return versions + + # Fast path: use __all__ to find only matching classes + if hasattr(module, "__all__"): + # Filter __all__ to find all versions of THIS op_type + # e.g., op_type="IntQuant" matches ["IntQuant", "IntQuant_v2", "IntQuant_v4"] + candidates = [] + for name in module.__all__: + # Strip _vN suffix to check if it matches + base_name = name.split("_v")[0] if "_v" in name else name + if base_name == op_type: + candidates.append(name) + + # Import ONLY the matching classes (lazy loading) + for name in candidates: + try: + obj = getattr(module, name) + except AttributeError: + continue + + if not (inspect.isclass(obj) and issubclass(obj, CustomOp) and obj is not CustomOp): + continue + + try: + version = _get_op_version_for_class(obj) + except ValueError as e: + warnings.warn(str(e)) + continue + + if version in versions: + warnings.warn( + f"Multiple classes found for {domain}.{op_type} version {version}: " + f"{versions[version].__name__} and {obj.__name__}. Using {obj.__name__}." + ) + versions[version] = obj + + # Backward compatibility: if __all__ didn't have the op, try custom_op dict + if not versions: + versions = _discover_from_custom_op_dict(module, op_type, domain) + + else: + # No __all__ - try legacy dict first (O(1) check, cheaper than full scan) + versions = _discover_from_custom_op_dict(module, op_type, domain) + + # Still nothing? Fallback to full module scan (for external modules) + if not versions: + for name, obj in inspect.getmembers(module, inspect.isclass): + if not issubclass(obj, CustomOp) or obj is CustomOp: + continue + + class_op_type = _get_op_type_for_class(obj) + if class_op_type != op_type: + continue + + try: + version = _get_op_version_for_class(obj) + except ValueError as e: + warnings.warn(str(e)) + continue + + if version in versions: + warnings.warn( + f"Multiple classes found for {domain}.{op_type} version {version}: " + f"{versions[version].__name__} and {obj.__name__}. Using {obj.__name__}." + ) + versions[version] = obj + + return versions + + +def _resolve_version( + available_versions: Dict[int, Type[CustomOp]], requested_version: Optional[int] +) -> Tuple[int, Type[CustomOp]]: + """Resolve which version to use given available and requested versions. + + Uses "since version" semantics: highest version <= requested is selected. + + Resolution strategy: + 1. If requested is None, use highest available version + 2. Try exact match + 3. Use highest version <= requested + 4. Raise KeyError if no suitable version - # Try namespace lookup - op_class = getattr(module, op_type, None) - if inspect.isclass(op_class) and issubclass(op_class, CustomOp): - _OP_REGISTRY[(domain, op_type)] = op_class - return True + Args: + available_versions: Dict of available versions -> CustomOp classes + requested_version: Requested opset version, or None for highest - # Try legacy dict - custom_op_dict = getattr(module, 'custom_op', None) - if isinstance(custom_op_dict, dict): - op_class = custom_op_dict.get(op_type) - if inspect.isclass(op_class) and issubclass(op_class, CustomOp): - _OP_REGISTRY[(domain, op_type)] = op_class - return True + Returns: + Tuple of (resolved_version, CustomOp_class) - return False + Raises: + KeyError: If no suitable version found + """ + if not available_versions: + raise KeyError("No versions available") + + # Strategy 1: If no specific version requested, use highest + if requested_version is None: + highest = max(available_versions.keys()) + return highest, available_versions[highest] + + # Strategy 2: Try exact match + if requested_version in available_versions: + return requested_version, available_versions[requested_version] + + # Strategy 3: Use highest version <= requested (since version semantics) + suitable = [v for v in available_versions.keys() if v <= requested_version] + if suitable: + selected = max(suitable) + return selected, available_versions[selected] + + # Strategy 4: No suitable version found + available_list = sorted(available_versions.keys()) + raise KeyError( + f"No suitable version found. Requested: {requested_version}, " + f"Available: {available_list}. Lowest available version is {available_list[0]}." + ) -def getCustomOp(node, onnx_opset_version=get_preferred_onnx_opset()): +def add_op_to_domain(domain: str, op_class: Type[CustomOp]) -> None: + """Register a custom op directly to a domain at runtime. + + The op_type and version are automatically derived from the class name. + Useful for testing and experimentation. For production, define CustomOps + in the appropriate module file. + + Args: + domain: ONNX domain name (e.g., "qonnx.custom_op.general") + op_class: CustomOp subclass (version inferred from name) + + Example: + add_op_to_domain("qonnx.custom_op.general", MyTestOp) # v1 + add_op_to_domain("qonnx.custom_op.general", MyTestOp_v2) # v2 + """ + if not issubclass(op_class, CustomOp): + raise ValueError(f"{op_class} must be a subclass of CustomOp") + + op_type = _get_op_type_for_class(op_class) + op_version = _get_op_version_for_class(op_class) + + with _REGISTRY_LOCK: + # Ensure nested dict structure exists + if domain not in _OP_REGISTRY: + _OP_REGISTRY[domain] = {} + if op_type not in _OP_REGISTRY[domain]: + _OP_REGISTRY[domain][op_type] = {} + + _OP_REGISTRY[domain][op_type][op_version] = op_class + + +def getCustomOp(node, onnx_opset_version=None): """Get a custom op instance for an ONNX node. + Uses "since version" semantics: selects highest version <= requested opset. + Lazy loads only the requested op_type using __all__ for efficiency. + Args: node: ONNX node with domain and op_type attributes - onnx_opset_version: ONNX opset version to use + onnx_opset_version: Opset version from model's opset_import, or None for highest Returns: CustomOp instance for the node Raises: - KeyError: If op_type not found in domain + KeyError: If op_type not found in domain or no suitable version available """ op_type = node.op_type domain = node.domain - key = (domain, op_type) with _REGISTRY_LOCK: - if key in _OP_REGISTRY: - return _OP_REGISTRY[key](node, onnx_opset_version=onnx_opset_version) + # O(1) nested dict lookup to check cache + if domain in _OP_REGISTRY and op_type in _OP_REGISTRY[domain]: + cached_versions = _OP_REGISTRY[domain][op_type] + else: + # Cache miss: discover THIS op only (lazy, uses __all__ for speed) + cached_versions = _discover_custom_op_versions(domain, op_type) + + if not cached_versions: + module_path = resolve_domain(domain) + raise KeyError( + f"Op '{op_type}' not found in domain '{domain}' (module: {module_path}). " + f"Ensure it's defined in the module with proper naming (OpName or OpName_vN)." + ) + + # Cache it in nested structure + if domain not in _OP_REGISTRY: + _OP_REGISTRY[domain] = {} + _OP_REGISTRY[domain][op_type] = cached_versions + + # Resolve which version to use + resolved_version, op_class = _resolve_version(cached_versions, onnx_opset_version) + + # Instantiate and return + return op_class(node, onnx_opset_version=resolved_version) + + +def get_supported_versions(domain: str, op_type: str) -> List[int]: + """Get list of supported opset versions for a custom op. + + Returns all "since versions" where the operator was introduced or changed. + + Args: + domain: ONNX domain name + op_type: Operation type name - if _discover_custom_op(domain, op_type): - return _OP_REGISTRY[key](node, onnx_opset_version=onnx_opset_version) + Returns: + Sorted list of opset versions + + Raises: + KeyError: If op not found + """ + with _REGISTRY_LOCK: + # O(1) check if cached + if domain in _OP_REGISTRY and op_type in _OP_REGISTRY[domain]: + return sorted(_OP_REGISTRY[domain][op_type].keys()) - module_path = resolve_domain(domain) - raise KeyError( - f"Op '{op_type}' not found in domain '{domain}' (module: {module_path}). " - f"Ensure it's exported in the module namespace or in the custom_op dict." - ) + # Not cached: discover this op + versions_dict = _discover_custom_op_versions(domain, op_type) + + if not versions_dict: + raise KeyError(f"Op '{op_type}' not found in domain '{domain}'") + + # Cache discovered versions + if domain not in _OP_REGISTRY: + _OP_REGISTRY[domain] = {} + _OP_REGISTRY[domain][op_type] = versions_dict + + return sorted(versions_dict.keys()) def is_custom_op(domain: str, op_type: Optional[str] = None) -> bool: @@ -173,14 +417,15 @@ def is_custom_op(domain: str, op_type: Optional[str] = None) -> bool: with _REGISTRY_LOCK: if op_type is not None: - # Check for specific op - key = (domain, op_type) - if key in _OP_REGISTRY: + # Check for specific op - O(1) with nested dict + if domain in _OP_REGISTRY and op_type in _OP_REGISTRY[domain]: return True - return _discover_custom_op(domain, op_type) + # Try to discover + versions = _discover_custom_op_versions(domain, op_type) + return len(versions) > 0 else: # Check if domain has any registered ops - if any(d == domain for d, _ in _OP_REGISTRY.keys()): + if domain in _OP_REGISTRY and _OP_REGISTRY[domain]: return True # Try to import the domain module as fallback module_path = resolve_domain(domain) @@ -203,12 +448,10 @@ def hasCustomOp(domain: str, op_type: str) -> bool: Returns: True if the op exists, False otherwise """ - import warnings warnings.warn( - "hasCustomOp is deprecated and will be removed in QONNX v1.0. " - "Use is_custom_op instead.", + "hasCustomOp is deprecated and will be removed in QONNX v1.0. " "Use is_custom_op instead.", DeprecationWarning, - stacklevel=2 + stacklevel=2, ) return is_custom_op(domain, op_type) @@ -216,6 +459,9 @@ def hasCustomOp(domain: str, op_type: str) -> bool: def get_ops_in_domain(domain: str) -> List[Tuple[str, Type[CustomOp]]]: """Get all CustomOp classes available in a domain. + Note: Returns unique op_types. If multiple versions exist, returns the highest version. + This function eagerly loads all ops in the domain. + Args: domain: ONNX domain name (e.g., "qonnx.custom_op.general") @@ -227,34 +473,49 @@ def get_ops_in_domain(domain: str) -> List[Tuple[str, Type[CustomOp]]]: for op_name, op_class in ops: print(f"{op_name}: {op_class}") """ - ops = [] module_path = resolve_domain(domain) + ops_dict = {} with _REGISTRY_LOCK: - # Strategy 1: Get cached ops (fast path) - for (d, op_type), op_class in _OP_REGISTRY.items(): - if d == domain: - ops.append((op_type, op_class)) + # Strategy 1: Get cached ops (fast path) - use highest version + if domain in _OP_REGISTRY: + for op_type, versions in _OP_REGISTRY[domain].items(): + if versions: + highest_version = max(versions.keys()) + ops_dict[op_type] = versions[highest_version] # Strategy 2: Discover from module (for uncached ops) + # This uses full scan since we want ALL ops try: module = importlib.import_module(module_path) - # Check namespace exports - for name, obj in inspect.getmembers(module): - if (inspect.isclass(obj) and - issubclass(obj, CustomOp) and - obj is not CustomOp and - not name.startswith('_') and - not any(op[0] == name for op in ops)): - ops.append((name, obj)) - - # Check legacy custom_op dict - if hasattr(module, 'custom_op') and isinstance(module.custom_op, dict): - for name, cls in module.custom_op.items(): - if not any(op[0] == name for op in ops): - ops.append((name, cls)) + # Use __all__ if available for efficiency + if hasattr(module, "__all__"): + candidates = [(name, getattr(module, name, None)) for name in module.__all__] + candidates = [(n, obj) for n, obj in candidates if obj is not None] + else: + candidates = inspect.getmembers(module, inspect.isclass) + + for name, obj in candidates: + if not (inspect.isclass(obj) and issubclass(obj, CustomOp) and obj is not CustomOp): + continue + + op_type = _get_op_type_for_class(obj) + try: + version = _get_op_version_for_class(obj) + except ValueError: + continue + + # Keep highest version only + if op_type not in ops_dict: + ops_dict[op_type] = obj + else: + # Check if this version is higher + existing_version = _get_op_version_for_class(ops_dict[op_type]) + if version > existing_version: + ops_dict[op_type] = obj + except ModuleNotFoundError: pass # Domain doesn't exist as module, return cached ops only - return ops + return list(ops_dict.items()) diff --git a/src/qonnx/transformation/channels_last.py b/src/qonnx/transformation/channels_last.py index 175af058..a00f8a9c 100644 --- a/src/qonnx/transformation/channels_last.py +++ b/src/qonnx/transformation/channels_last.py @@ -32,8 +32,8 @@ from onnx import TensorProto, helper from qonnx.core.modelwrapper import ModelWrapper -from qonnx.custom_op import channels_last from qonnx.custom_op.channels_last.base_wrapped_op import to_channels_first_args, to_channels_last_args +from qonnx.custom_op.registry import get_ops_in_domain from qonnx.transformation.base import Transformation from qonnx.transformation.fold_constants import FoldConstants from qonnx.transformation.general import SortGraph @@ -44,7 +44,7 @@ from qonnx.util.onnx import is_eltwise_optype # Standard ONNX nodes which require a ChannelsLast data format to function properly -_channelsLast_node_types = list(channels_last.custom_op.keys()) +_channelsLast_node_types = list([x[0] for x in get_ops_in_domain("qonnx.custom_op.channels_last")]) # Nodes, which do not modify the shape of the tensor # And modify all values in the same way. @@ -270,8 +270,15 @@ def apply(self, model): # Attach to original node n.output[i] = outp_trans_in - # Modify domain + # Modify node domain n.domain = "qonnx.custom_op.channels_last" + opset_imports = model.get_opset_imports() + # Ensure channels_last domain is imported in model + if "qonnx.custom_op.channels_last" not in opset_imports: + # use the same opset for channels last ops as the standard ONNX opset + # (since they are defined based on the standard ops under the hood) + onnx_opset = opset_imports[""] if "" in opset_imports.keys() else opset_imports["ai.onnx"] + model.model.opset_import.append(helper.make_opsetid("qonnx.custom_op.channels_last", onnx_opset)) # Set modified flag graph_modified = True diff --git a/src/qonnx/transformation/fixedpt_quantize.py b/src/qonnx/transformation/fixedpt_quantize.py index 127fa4b1..3b3357ed 100644 --- a/src/qonnx/transformation/fixedpt_quantize.py +++ b/src/qonnx/transformation/fixedpt_quantize.py @@ -41,19 +41,15 @@ def default_op_filter(op): class FixedPointQuantizeParamsFromDict(Transformation): """ - Quantize model parameters to a given fixed-point representation. - The self.max_err dictionary stores the maximum error for each quantized input after calling. - Parameters: - fixedpt_dict: Dictionary containing tensor names and their corresponding target fixed-point - <<<<<<< HEAD - data type or its canonical name - ======= - data type or its canonical name - >>>>>>> 7dfc4b8 ([Lint] rerun linter, fix errors) - rounding_mode: Rounding mode used for conversion into fixed point. - Default is "ROUND", - possible values: ["ROUND", "HALF_EVEN", "CEIL", "FLOOR", "UP", "DOWN", - "HALF_UP", "HALF_DOWN"] + Quantize model parameters to a given fixed-point representation. + The self.max_err dictionary stores the maximum error for each quantized input after calling. + Parameters: + fixedpt_dict: Dictionary containing tensor names and their corresponding target fixed-point + data type or its canonical name + rounding_mode: Rounding mode used for conversion into fixed point. + Default is "ROUND", + possible values: ["ROUND", "HALF_EVEN", "CEIL", "FLOOR", "UP", "DOWN", + "HALF_UP", "HALF_DOWN"] """ def __init__(self, fixedpt_dict, rounding_mode="ROUND"): diff --git a/src/qonnx/util/basic.py b/src/qonnx/util/basic.py index 17957d12..cef4f67b 100644 --- a/src/qonnx/util/basic.py +++ b/src/qonnx/util/basic.py @@ -78,13 +78,15 @@ def is_finn_op(op_type): Use the registry-based is_custom_op for better accuracy and extensibility. """ import warnings + warnings.warn( "is_finn_op is deprecated and will be removed in QONNX v1.0. " "Use 'from qonnx.custom_op.registry import is_custom_op' instead.", DeprecationWarning, - stacklevel=2 + stacklevel=2, ) from qonnx.custom_op.registry import is_custom_op + return is_custom_op(op_type) diff --git a/tests/core/test_custom_onnx_exec.py b/tests/core/test_custom_onnx_exec.py index 8eec7156..54b71754 100644 --- a/tests/core/test_custom_onnx_exec.py +++ b/tests/core/test_custom_onnx_exec.py @@ -32,6 +32,8 @@ import qonnx.core.execute_custom_node as ex_cu_node from qonnx.custom_op.registry import getCustomOp +mt_node_version = 1 + def test_execute_custom_node_multithreshold(): inputs = np.ndarray( @@ -155,7 +157,7 @@ def test_execute_custom_node_multithreshold(): execution_context["v"] = inputs execution_context["thresholds"] = threshold_values - ex_cu_node.execute_custom_node(node_def, execution_context, graph_def) + ex_cu_node.execute_custom_node(node_def, execution_context, graph_def, mt_node_version) outputs = np.ndarray( shape=(6, 3, 2, 2), @@ -250,7 +252,7 @@ def test_execute_custom_node_multithreshold(): ) graph_def = helper.make_graph([node_def], "test_model", [v, thresholds], [out]) - ex_cu_node.execute_custom_node(node_def, execution_context, graph_def) + ex_cu_node.execute_custom_node(node_def, execution_context, graph_def, mt_node_version) outputs_scaled = 2.0 * outputs - 1.0 assert (execution_context["out"] == outputs_scaled).all() @@ -270,7 +272,7 @@ def test_execute_custom_node_multithreshold(): execution_context["v"] = inputs_nhwc graph_def = helper.make_graph([node_def], "test_model", [v_nhwc, thresholds], [out_nhwc]) - ex_cu_node.execute_custom_node(node_def, execution_context, graph_def) + ex_cu_node.execute_custom_node(node_def, execution_context, graph_def, mt_node_version) assert (execution_context["out"] == outputs_nhwc).all() # check the set of allowed values op_inst = getCustomOp(node_def) diff --git a/tests/custom_op/test_attr.py b/tests/custom_op/test_attr.py index cde5a321..d1d32546 100644 --- a/tests/custom_op/test_attr.py +++ b/tests/custom_op/test_attr.py @@ -29,10 +29,9 @@ import numpy as np import onnx.parser as oprs -import qonnx.custom_op.general as general from qonnx.core.modelwrapper import ModelWrapper from qonnx.custom_op.base import CustomOp -from qonnx.custom_op.registry import getCustomOp +from qonnx.custom_op.registry import add_op_to_domain, getCustomOp class AttrTestOp(CustomOp): @@ -60,7 +59,7 @@ def verify_node(self): def test_attr(): - general.custom_op["AttrTestOp"] = AttrTestOp + add_op_to_domain("qonnx.custom_op.general", AttrTestOp) ishp = (1, 10) wshp = (1, 3) oshp = wshp diff --git a/tests/custom_op/test_customop_version.py b/tests/custom_op/test_customop_version.py new file mode 100644 index 00000000..e0d30c56 --- /dev/null +++ b/tests/custom_op/test_customop_version.py @@ -0,0 +1,137 @@ +# Copyright (c) 2025 Advanced Micro Devices, Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of qonnx nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import onnx.parser as oprs + +from qonnx.core.modelwrapper import ModelWrapper +from qonnx.custom_op.base import CustomOp +from qonnx.custom_op.registry import add_op_to_domain, getCustomOp + + +class VerTestOp_v1(CustomOp): + def get_nodeattr_types(self): + my_attrs = {"v1_attr": ("i", True, 0)} + return my_attrs + + def make_shape_compatible_op(self, model): + ishape = model.get_tensor_shape(self.onnx_node.input[0]) + return super().make_const_shape_op(ishape) + + def infer_node_datatype(self, model): + node = self.onnx_node + # data type stays the same + dtype = model.get_tensor_datatype(node.input[0]) + model.set_tensor_datatype(node.output[0], dtype) + + def execute_node(self, context, graph): + node = self.onnx_node + context[node.output[0]] = context[node.input[0]] + + def verify_node(self): + pass + + +class VerTestOp_v2(VerTestOp_v1): + def get_nodeattr_types(self): + my_attrs = {"v2_attr": ("i", True, 0)} + return my_attrs + + +class VerTestOp_v3(VerTestOp_v2): + def get_nodeattr_types(self): + my_attrs = {"v3_attr": ("i", True, 0)} + return my_attrs + + +def make_vertest_model(vertest_ver, no_opset_import): + ishp = (1, 10) + oshp = ishp + ishp_str = str(list(ishp)) + oshp_str = str(list(oshp)) + if no_opset_import: + opset_import = "" + else: + opset_import = f', "qonnx.custom_op.general" : {vertest_ver}' + input = f""" + < + ir_version: 7, + opset_import: ["" : 9{opset_import}] + > + agraph (float{ishp_str} in0) => (float{oshp_str} out0) + {{ + out0 = qonnx.custom_op.general.VerTestOp< + v{vertest_ver}_attr={vertest_ver} + >(in0) + }} + """ + model = oprs.parse_model(input) + model = ModelWrapper(model) + return model + + +def test_customop_version(): + # Register test ops with the registry + # The _vN suffix will be automatically stripped to get op_type="VerTestOp" + add_op_to_domain("qonnx.custom_op.general", VerTestOp_v1) + add_op_to_domain("qonnx.custom_op.general", VerTestOp_v2) + add_op_to_domain("qonnx.custom_op.general", VerTestOp_v3) + + # if onnx is lacking the opset import, getCustomOp with no version + # should return the highest available version + model = make_vertest_model(1, True) + inst = getCustomOp(model.graph.node[0]) + # With no opset_import, getCustomOp(None) uses highest version -> v3 + assert isinstance(inst, VerTestOp_v3) + # alternatively, when using ModelWrapper.get_customop_wrapper and onnx is + # lacking the opset import, should fall back to the specified version + inst = model.get_customop_wrapper(model.graph.node[0], fallback_customop_version=2) + assert isinstance(inst, VerTestOp_v2) + + for ver in [1, 2, 3]: + model = make_vertest_model(ver, False) + # use ModelWrapper.get_customop_wrapper for implicit + # fetching of op version + inst = model.get_customop_wrapper(model.graph.node[0]) + assert inst.get_nodeattr(f"v{ver}_attr") == ver + assert inst.onnx_opset_version == ver + # explicitly specify onnx_opset_version in getCustomOp + # note: new code should avoid calling getCustomOp directly like this + # and instead use ModelWrapper.get_customop_wrapper + inst = getCustomOp(model.graph.node[0], onnx_opset_version=ver) + assert inst.get_nodeattr(f"v{ver}_attr") == ver + assert inst.onnx_opset_version == ver + # getCustomOp with no version specified uses highest available + model = make_vertest_model(1, False) + inst = getCustomOp(model.graph.node[0]) + assert isinstance(inst, VerTestOp_v3) # highest version + assert inst.onnx_opset_version == 3 + # requesting v4 should return largest available version (v3 in this case) + model = make_vertest_model(3, False) + inst = getCustomOp(model.graph.node[0], onnx_opset_version=4) + assert isinstance(inst, VerTestOp_v3) + assert inst.onnx_opset_version == 3 diff --git a/tests/custom_op/test_floatquant.py b/tests/custom_op/test_floatquant.py index c0f89cde..f792f793 100644 --- a/tests/custom_op/test_floatquant.py +++ b/tests/custom_op/test_floatquant.py @@ -168,7 +168,6 @@ def test_brevitas_vs_qonnx(data): scale = 1.0 exponent_bias = compute_default_exponent_bias(exponent_bit_width) max_val = compute_max_val(exponent_bit_width, mantissa_bit_width, exponent_bias) - xq_t = brevitas_float_quant(x, bit_width, exponent_bit_width, mantissa_bit_width, - exponent_bias, sign, max_val).numpy() + xq_t = brevitas_float_quant(x, bit_width, exponent_bit_width, mantissa_bit_width, exponent_bias, sign, max_val).numpy() xq = qonnx_float_quant(x.numpy(), scale, exponent_bit_width, mantissa_bit_width, exponent_bias, sign, max_val) np.testing.assert_array_equal(xq, xq_t) diff --git a/tests/transformation/test_channelslast.py b/tests/transformation/test_channelslast.py index 24e64b4f..92b4964e 100644 --- a/tests/transformation/test_channelslast.py +++ b/tests/transformation/test_channelslast.py @@ -32,9 +32,8 @@ import qonnx.core.onnx_exec as oxe from qonnx.core.modelwrapper import ModelWrapper -from qonnx.custom_op import channels_last from qonnx.custom_op.channels_last.base_wrapped_op import to_channels_last_args -from qonnx.custom_op.registry import getCustomOp +from qonnx.custom_op.registry import get_ops_in_domain, getCustomOp, is_custom_op from qonnx.transformation.channels_last import ( AbsorbChanFirstIntoMatMul, InsertChannelsLastDomainsAndTrafos, @@ -47,7 +46,6 @@ from qonnx.transformation.infer_shapes import InferShapes from qonnx.transformation.make_input_chanlast import MakeInputChannelsLast from qonnx.transformation.quant_constant_folding import FoldTransposeIntoQuantInit -from qonnx.util.basic import is_finn_op from qonnx.util.test import download_model, get_golden_in_and_output, test_model_details from qonnx.util.to_channels_last import to_channels_last @@ -92,7 +90,7 @@ def analysis_testing_for_chanlast_domain(model): "BatchNormalization": 3, } # Check that all wrapped_ops in the registry have a definition here - chanlast_op_types = list(channels_last.custom_op.keys()) + chanlast_op_types = list([x[0] for x in get_ops_in_domain("qonnx.custom_op.channels_last")]) testable_op_types = list(ChanLast_node_types_and_min_dim_input.keys()) for op_name in chanlast_op_types: assert ( @@ -126,7 +124,7 @@ def analysis_test_for_left_transposes(model, test_model, make_input_channels_las def verify_all_nodes(model): result = dict() for n in model.graph.node: - if is_finn_op(n.domain): + if is_custom_op(n.domain): n_instance = getCustomOp(n) verify_result = n_instance.verify_node() result[n.name] = verify_result