6 changes: 3 additions & 3 deletions src/qonnx/core/modelwrapper.py
@@ -39,7 +39,7 @@
import qonnx.util.basic as util
import qonnx.util.onnx as onnxutil
from qonnx.core.datatype import DataType
from qonnx.custom_op.registry import getCustomOp
from qonnx.custom_op.registry import getCustomOp, is_custom_op
from qonnx.transformation.double_to_single_float import DoubleToSingleFloat
from qonnx.transformation.general import (
RemoveStaticGraphInputs,
@@ -632,11 +632,11 @@ def get_nodes_by_op_type(self, op_type):

def get_finn_nodes(self):
"""Returns a list of nodes where domain == 'qonnx.*'."""
return list(filter(lambda x: util.is_finn_op(x.domain), self.graph.node))
return list(filter(lambda x: is_custom_op(x.domain), self.graph.node))

def get_non_finn_nodes(self):
"""Returns a list of nodes where domain != 'qonnx.*'."""
return list(filter(lambda x: not util.is_finn_op(x.domain), self.graph.node))
return list(filter(lambda x: not is_custom_op(x.domain), self.graph.node))

def get_node_index(self, node):
"""Returns current index of given node, or None if not found."""
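A brief usage sketch of the two helpers above after the switch to is_custom_op; the model path is a placeholder and the comments reflect an assumption about the new behavior, not verified output.

# Usage sketch (hypothetical model file): split a graph's nodes into
# custom-op and standard-ONNX partitions with the helpers shown above.
from qonnx.core.modelwrapper import ModelWrapper

model = ModelWrapper("model.onnx")  # placeholder path
custom_nodes = model.get_finn_nodes()        # nodes from registered custom-op domains
standard_nodes = model.get_non_finn_nodes()  # plain ai.onnx nodes
print(f"{len(custom_nodes)} custom-op nodes, {len(standard_nodes)} standard nodes")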
4 changes: 2 additions & 2 deletions src/qonnx/core/onnx_exec.py
@@ -35,10 +35,10 @@

import qonnx.analysis.topology as ta
import qonnx.core.execute_custom_node as ex_cu_node
from qonnx.custom_op.registry import is_custom_op
from qonnx.util.basic import (
get_preferred_qonnx_opset,
get_sanitize_quant_tensors,
is_finn_op,
qonnx_make_model,
sanitize_quant_values,
)
@@ -49,7 +49,7 @@ def execute_node(node, context, graph, opset_version, return_full_exec_context=F

Input/output provided via context."""

if is_finn_op(node.domain):
if is_custom_op(node.domain, node.op_type):
ex_cu_node.execute_custom_node(node, context, graph, onnx_opset_version=opset_version)
else:
# onnxruntime unfortunately does not implement run_node as defined by ONNX,
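Since onnxruntime offers no run_node, one common workaround is to wrap the standard node in a throwaway single-node model and execute that with an InferenceSession. The sketch below illustrates the idea only; run_single_node is a hypothetical helper and not the actual continuation of execute_node, and it assumes float32 tensors for brevity.

# Sketch of the single-node-model workaround for the missing run_node().
import numpy as np
import onnxruntime as ort
from onnx import TensorProto, helper

def run_single_node(node, input_dict):
    """Execute one standard ONNX node by wrapping it in a minimal model."""
    inputs = [
        helper.make_tensor_value_info(name, TensorProto.FLOAT, input_dict[name].shape)
        for name in node.input
    ]
    outputs = [
        helper.make_tensor_value_info(name, TensorProto.FLOAT, None)
        for name in node.output
    ]
    graph = helper.make_graph([node], "single_node_exec", inputs, outputs)
    model = helper.make_model(graph)
    sess = ort.InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"])
    results = sess.run(None, input_dict)
    return dict(zip(node.output, results))

relu = helper.make_node("Relu", ["x"], ["y"])
print(run_single_node(relu, {"x": np.array([[-1.0, 2.0]], dtype=np.float32)}))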
22 changes: 21 additions & 1 deletion src/qonnx/custom_op/base.py
@@ -36,7 +36,27 @@
class CustomOp(ABC):
"""CustomOp class all custom op nodes are based on. Contains different functions
every custom node should have. Some as abstract methods, these have to be
filled when writing a new custom op node."""
filled when writing a new custom op node.

Opset Version Support:
CustomOp classes use "since version" semantics matching ONNX operators.
Version is determined by the class name using _vN suffix convention:

- No suffix (e.g., IntQuant): Version 1 (default)
- _vN suffix (e.g., IntQuant_v2): Version N

The registry automatically selects the highest version <= requested opset.

Example:
class IntQuant(CustomOp):
pass # Version 1 (no suffix)

class IntQuant_v2(CustomOp):
pass # Version 2, covers opset v2-v3 (if no v3 exists)

class IntQuant_v4(CustomOp):
pass # Version 4, covers opset v4+
"""

def __init__(self, onnx_node, onnx_opset_version=get_preferred_qonnx_opset()):
super().__init__()
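A minimal sketch of the "highest version <= requested opset" resolution described in the docstring above; resolve_custom_op and the candidates mapping are illustrative and not the real qonnx.custom_op.registry API.

# Illustrative "since version" resolution per the _vN naming convention.
import re

def class_version(cls_name):
    """IntQuant -> 1, IntQuant_v2 -> 2 (no suffix means version 1)."""
    match = re.search(r"_v(\d+)$", cls_name)
    return int(match.group(1)) if match else 1

def resolve_custom_op(candidates, op_type, requested_opset):
    """Pick the class whose since-version is the highest one <= requested_opset."""
    versions = {
        class_version(name): cls
        for name, cls in candidates.items()
        if re.fullmatch(re.escape(op_type) + r"(_v\d+)?", name)
    }
    usable = [v for v in versions if v <= requested_opset]
    if not usable:
        raise KeyError(f"no implementation of {op_type} for opset {requested_opset}")
    return versions[max(usable)]

# With IntQuant, IntQuant_v2 and IntQuant_v4 registered, requesting opset 3
# resolves to IntQuant_v2 and opset 5 resolves to IntQuant_v4.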
27 changes: 2 additions & 25 deletions src/qonnx/custom_op/channels_last/__init__.py
@@ -1,29 +1,6 @@
# Importing registers CustomOps in qonnx.custom_op.channels_last domain
from qonnx.custom_op.channels_last.batch_normalization import BatchNormalization
from qonnx.custom_op.channels_last.conv import Conv
from qonnx.custom_op.channels_last.max_pool import MaxPool

# channels-last ops are defined by the underlying ONNX standard op
# thus, we can define them for any version of the original op
# so we emulate a custom op dictionary that mimics the support for any
# {ChannelsLastOp}_vX instead of hardcoding what versions are supported


class ChannelsLastCustomOpDict(dict):
def __init__(self):
self._custom_ops = {"Conv": Conv, "MaxPool": MaxPool, "BatchNormalization": BatchNormalization}

def __getitem__(self, key):
base_key = key.split("_v")[0] # Extract base key (e.g., Conv from Conv_v13)
if base_key in self._custom_ops:
return self._custom_ops[base_key]
raise KeyError(f"Channels-last CustomOp '{key}' not found.")

def __contains__(self, key):
base_key = key.split("_v")[0]
return base_key in self._custom_ops

def keys(self):
return self._custom_ops.keys()


custom_op = ChannelsLastCustomOpDict()
__all__ = ["Conv", "MaxPool", "BatchNormalization"]
Collaborator comment on lines -5 to +6:
This does not respect the semantics of the deleted code (channels-last ops being implicitly defined for any version of the underlying standard op); instead, all channels-last ops will now be registered as v1, meaning they will instantiate ai.onnx v1 standard nodes under the hood. Can you address this?

If this is tricky to fix in the new system, one band-aid we can apply is to register these channels-last ops as v11 instead, which is the current preferred ONNX opset version. That way they would do something slightly more reasonable.
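One possible shape of the suggested band-aid, assuming the new registry discovers versioned implementations via the _vN class-name convention from base.py; this is a sketch of the reviewer's proposal, not code from the PR.

# Hypothetical channels_last/__init__.py with the ops aliased at version 11,
# so they would wrap ai.onnx v11 standard nodes instead of v1.
from qonnx.custom_op.channels_last.batch_normalization import BatchNormalization
from qonnx.custom_op.channels_last.conv import Conv
from qonnx.custom_op.channels_last.max_pool import MaxPool

Conv_v11 = Conv
MaxPool_v11 = MaxPool
BatchNormalization_v11 = BatchNormalization

__all__ = ["Conv", "MaxPool", "BatchNormalization",
           "Conv_v11", "MaxPool_v11", "BatchNormalization_v11"]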

1 change: 1 addition & 0 deletions src/qonnx/custom_op/channels_last/batch_normalization.py
@@ -33,6 +33,7 @@


class BatchNormalization(ChannelsLastWrappedOp):

def get_nodeattr_types(self):
"""Returns a dict of permitted attributes for node, where:
ret_dict[attribute_name] = (dtype, require, default_value, <allowed_values>)
1 change: 1 addition & 0 deletions src/qonnx/custom_op/channels_last/conv.py
@@ -34,6 +34,7 @@


class Conv(ChannelsLastWrappedOp):

def get_nodeattr_types(self):
"""Returns a dict of permitted attributes for node, where:
ret_dict[attribute_name] = (dtype, require, default_value, <allowed_values>)
1 change: 1 addition & 0 deletions src/qonnx/custom_op/channels_last/max_pool.py
@@ -34,6 +34,7 @@


class MaxPool(ChannelsLastWrappedOp):

def get_nodeattr_types(self):
"""Returns a dict of permitted attributes for node, where:
ret_dict[attribute_name] = (dtype, require, default_value, <allowed_values>)
43 changes: 16 additions & 27 deletions src/qonnx/custom_op/general/__init__.py
@@ -26,6 +26,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# Importing registers CustomOps in qonnx.custom_op.general domain
from qonnx.custom_op.general.bipolar_quant import BipolarQuant
from qonnx.custom_op.general.debugmarker import DebugMarker
from qonnx.custom_op.general.floatquant import FloatQuant
@@ -35,33 +36,21 @@
from qonnx.custom_op.general.maxpoolnhwc import MaxPoolNHWC
from qonnx.custom_op.general.multithreshold import MultiThreshold
from qonnx.custom_op.general.quantavgpool2d import QuantAvgPool2d
from qonnx.custom_op.general.quant import Quant
from qonnx.custom_op.general.trunc import Trunc
from qonnx.custom_op.general.xnorpopcount import XnorPopcountMatMul

custom_op = dict()

custom_op["DebugMarker"] = DebugMarker
custom_op["QuantAvgPool2d"] = QuantAvgPool2d
custom_op["MaxPoolNHWC"] = MaxPoolNHWC
custom_op["GenericPartition"] = GenericPartition
custom_op["MultiThreshold"] = MultiThreshold
custom_op["XnorPopcountMatMul"] = XnorPopcountMatMul
custom_op["Im2Col"] = Im2Col
custom_op["IntQuant"] = IntQuant
custom_op["Quant"] = IntQuant
custom_op["Trunc"] = Trunc
custom_op["BipolarQuant"] = BipolarQuant
custom_op["FloatQuant"] = FloatQuant

custom_op["DebugMarker_v1"] = DebugMarker
custom_op["QuantAvgPool2d_v1"] = QuantAvgPool2d
custom_op["MaxPoolNHWC_v1"] = MaxPoolNHWC
custom_op["GenericPartition_v1"] = GenericPartition
custom_op["MultiThreshold_v1"] = MultiThreshold
custom_op["XnorPopcountMatMul_v1"] = XnorPopcountMatMul
custom_op["Im2Col_v1"] = Im2Col
custom_op["IntQuant_v1"] = IntQuant
custom_op["Quant_v1"] = IntQuant
custom_op["Trunc_v1"] = Trunc
custom_op["BipolarQuant_v1"] = BipolarQuant
custom_op["FloatQuant_v1"] = FloatQuant
__all__ = [
"BipolarQuant",
"DebugMarker",
"FloatQuant",
"GenericPartition",
"Im2Col",
"IntQuant",
"MaxPoolNHWC",
"MultiThreshold",
"Quant",
"QuantAvgPool2d",
"Trunc",
"XnorPopcountMatMul",
]
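The comment at the top of this module states that importing it registers the ops for the domain; a plausible sketch of how a name-based registry could discover them is shown below. discover_custom_ops is illustrative only and not the actual qonnx.custom_op.registry implementation.

# Illustrative discovery of CustomOp subclasses in a domain module.
import importlib
import inspect

from qonnx.custom_op.base import CustomOp

def discover_custom_ops(domain):
    """Map class names (e.g. 'IntQuant', 'IntQuant_v2') to CustomOp subclasses."""
    module = importlib.import_module(domain)  # e.g. "qonnx.custom_op.general"
    return {
        name: obj
        for name, obj in inspect.getmembers(module, inspect.isclass)
        if issubclass(obj, CustomOp) and obj is not CustomOp
    }

# e.g. discover_custom_ops("qonnx.custom_op.general")["Quant"]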
1 change: 1 addition & 0 deletions src/qonnx/custom_op/general/debugmarker.py
@@ -32,6 +32,7 @@


class DebugMarker(CustomOp):

def get_nodeattr_types(self):
return {"export_debug_name": ("s", True, "")}

1 change: 1 addition & 0 deletions src/qonnx/custom_op/general/im2col.py
@@ -141,6 +141,7 @@ def im2col_indices_nchw(


class Im2Col(CustomOp):

def get_nodeattr_types(self):
return {
# stride and shape of convolution kernel