diff --git a/triedb/pathdb/history_trienode.go b/triedb/pathdb/history_trienode.go index 3f45b41117d..1004106af9c 100644 --- a/triedb/pathdb/history_trienode.go +++ b/triedb/pathdb/history_trienode.go @@ -19,9 +19,11 @@ package pathdb import ( "bytes" "encoding/binary" + "errors" "fmt" "iter" "maps" + "math" "slices" "sort" "time" @@ -386,12 +388,26 @@ func decodeSingle(keySection []byte, onValue func([]byte, int, int) error) ([]st } // Resolve the entry from key section nShared, nn := binary.Uvarint(keySection[keyOff:]) // key length shared (varint) + if nn <= 0 { + return nil, fmt.Errorf("corrupted varint encoding for nShared at offset %d", keyOff) + } keyOff += nn nUnshared, nn := binary.Uvarint(keySection[keyOff:]) // key length not shared (varint) + if nn <= 0 { + return nil, fmt.Errorf("corrupted varint encoding for nUnshared at offset %d", keyOff) + } keyOff += nn nValue, nn := binary.Uvarint(keySection[keyOff:]) // value length (varint) + if nn <= 0 { + return nil, fmt.Errorf("corrupted varint encoding for nValue at offset %d", keyOff) + } keyOff += nn + // Validate that the values can fit in an int to prevent overflow on 32-bit systems + if nShared > uint64(math.MaxUint32) || nUnshared > uint64(math.MaxUint32) || nValue > uint64(math.MaxUint32) { + return nil, errors.New("key size too large") + } + // Resolve unshared key if keyOff+int(nUnshared) > len(keySection) { return nil, fmt.Errorf("key length too long, unshared key length: %d, off: %d, section size: %d", nUnshared, keyOff, len(keySection)) diff --git a/triedb/pathdb/history_trienode_test.go b/triedb/pathdb/history_trienode_test.go index d6b80f61f56..be4740a9045 100644 --- a/triedb/pathdb/history_trienode_test.go +++ b/triedb/pathdb/history_trienode_test.go @@ -694,7 +694,10 @@ func TestDecodeSingleCorruptedData(t *testing.T) { // Test with corrupted varint in key section corrupted := make([]byte, len(keySection)) copy(corrupted, keySection) - corrupted[5] = 0xFF // Corrupt varint + // Fill first 10 bytes with 0xFF to create a varint overflow (>64 bits) + for i := range 10 { + corrupted[i] = 0xFF + } _, err = decodeSingle(corrupted, nil) if err == nil { t.Fatal("Expected error for corrupted varint")