// Code generated by "stringer -type Asn1BER"; DO NOT EDIT.
package gosnmp
import "strconv"
// _ is a compile-time guard emitted by stringer. Each statement indexes the
// one-element array x with (Constant - expectedValue); that index is only in
// range while the difference is 0, so changing any Asn1BER constant breaks the
// build until `go generate` is re-run. Note EndOfContents and UnknownType both
// have the value 0.
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[EndOfContents-0]
_ = x[UnknownType-0]
_ = x[Boolean-1]
_ = x[Integer-2]
_ = x[BitString-3]
_ = x[OctetString-4]
_ = x[Null-5]
_ = x[ObjectIdentifier-6]
_ = x[ObjectDescription-7]
_ = x[IPAddress-64]
_ = x[Counter32-65]
_ = x[Gauge32-66]
_ = x[TimeTicks-67]
_ = x[Opaque-68]
_ = x[NsapAddress-69]
_ = x[Counter64-70]
_ = x[Uinteger32-71]
_ = x[OpaqueFloat-120]
_ = x[OpaqueDouble-121]
_ = x[NoSuchObject-128]
_ = x[NoSuchInstance-129]
_ = x[EndOfMibView-130]
}
// Each _Asn1BER_name_N holds the names of one contiguous run of Asn1BER values
// concatenated into a single string. The matching _Asn1BER_index_N holds the
// start offset of every name plus a final end offset, so name k of run N is
// _Asn1BER_name_N[_Asn1BER_index_N[k]:_Asn1BER_index_N[k+1]].
const (
_Asn1BER_name_0 = "EndOfContentsBooleanIntegerBitStringOctetStringNullObjectIdentifierObjectDescription"
_Asn1BER_name_1 = "IPAddressCounter32Gauge32TimeTicksOpaqueNsapAddressCounter64Uinteger32"
_Asn1BER_name_2 = "OpaqueFloatOpaqueDouble"
_Asn1BER_name_3 = "NoSuchObjectNoSuchInstanceEndOfMibView"
)
var (
_Asn1BER_index_0 = [...]uint8{0, 13, 20, 27, 36, 47, 51, 67, 84}
_Asn1BER_index_1 = [...]uint8{0, 9, 18, 25, 34, 40, 51, 60, 70}
_Asn1BER_index_2 = [...]uint8{0, 11, 23}
_Asn1BER_index_3 = [...]uint8{0, 12, 26, 38}
)
// String returns the constant name of i. Known values fall into four runs
// (0-7, 64-71, 120-121, 128-130); i is rebased to the start of its run and
// used to slice that run's name table. Value 0 prints as "EndOfContents"
// (UnknownType shares that value). Any other value is rendered as
// "Asn1BER(<decimal>)".
func (i Asn1BER) String() string {
switch {
// Asn1BER is a byte, so i is never negative; only the upper bound is checked.
case i <= 7:
return _Asn1BER_name_0[_Asn1BER_index_0[i]:_Asn1BER_index_0[i+1]]
case 64 <= i && i <= 71:
i -= 64
return _Asn1BER_name_1[_Asn1BER_index_1[i]:_Asn1BER_index_1[i+1]]
case 120 <= i && i <= 121:
i -= 120
return _Asn1BER_name_2[_Asn1BER_index_2[i]:_Asn1BER_index_2[i+1]]
case 128 <= i && i <= 130:
i -= 128
return _Asn1BER_name_3[_Asn1BER_index_3[i]:_Asn1BER_index_3[i+1]]
default:
return "Asn1BER(" + strconv.FormatInt(int64(i), 10) + ")"
}
}
// Copyright 2012 The GoSNMP Authors. All rights reserved. Use of this
// source code is governed by a BSD-style license that can be found in the
// LICENSE file.
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gosnmp
import (
"context"
"crypto/rand"
"fmt"
"math"
"math/big"
"net"
"strconv"
"sync"
"sync/atomic"
"syscall"
"time"
)
// Protocol limits and defaults.
const (
	// MaxOids is the maximum number of OIDs permitted in a single call,
	// otherwise error. Setting it too high can cause remote devices to fail
	// strangely. 60 seems to be a common value that works; override it per
	// client through the GoSNMP.MaxOids field.
	MaxOids = 60

	// MaxObjectSubIdentifierValue is the largest permitted OID sub-identifier
	// value (2^32 - 1), see https://tools.ietf.org/html/rfc2578#section-7.1.3
	MaxObjectSubIdentifierValue = math.MaxUint32

	// defaultMaxRepetitions is the GETBULK max-repetitions used when
	// GoSNMP.MaxRepetitions is not set. Java SNMP uses 50, snmp-net uses 10.
	defaultMaxRepetitions = 50
)

// Transport names; shared as constants so 'goconst' does not complain about
// the repeated literals.
const (
	udp = "udp"
	tcp = "tcp"
)
// GoSNMP represents GoSNMP library state: connection settings, SNMPv3
// security configuration, optional packet hooks and internal request/message
// ID bookkeeping. It contains a sync.Mutex, so use it through a pointer and do
// not copy it after first use.
type GoSNMP struct {
// Conn is net connection to use, typically established using GoSNMP.Connect().
Conn net.Conn
// Target is the agent's address (host name or IP address). It is combined
// with Port via net.JoinHostPort when connecting.
Target string
// Port is the agent's port number (Default uses 161).
Port uint16
// Transport is the transport protocol to use ("udp" or "tcp"); if unset "udp" will be used.
// Connecting may rewrite it to an IPv4/IPv6-only variant such as "udp4".
Transport string
// Community is an SNMP Community string.
Community string
// Version is an SNMP Version.
Version SnmpVersion
// Context allows for overall deadlines and cancellation.
// If nil, it is set to context.Background() when parameters are validated.
Context context.Context
// Timeout is the timeout for one SNMP request/response.
Timeout time.Duration
// Retries is the number of retries to attempt.
Retries int
// ExponentialTimeout doubles the timeout on each retry.
ExponentialTimeout bool
// Logger is the GoSNMP.Logger to use for debugging.
// For verbose logging to stdout:
// x.Logger = NewLogger(log.New(os.Stdout, "", 0))
// For Release builds, you can turn off logging entirely by using the go build tag "gosnmp_nodebug" even if the logger was installed.
Logger Logger

// Message hook functions allow passing in functions at various points in the packet handling.
// For example, this can be used to collect packet timing, add metrics, or implement tracing.

// PreSend is called before a packet is sent.
PreSend func(*GoSNMP)
// OnSent is called when a packet is sent.
OnSent func(*GoSNMP)
// OnRecv is called when a packet is received.
OnRecv func(*GoSNMP)
// OnRetry is called when a retry attempt is done.
OnRetry func(*GoSNMP)
// OnFinish is called when the request completed.
OnFinish func(*GoSNMP)
// MaxOids is the maximum number of oids allowed in a Get().
// (default: MaxOids; a negative value is rejected as invalid)
MaxOids int
// MaxRepetitions sets the GETBULK max-repetitions used by BulkWalk*
// Unless MaxRepetitions is specified it will use defaultMaxRepetitions (50)
// This may cause issues with some devices, if so set MaxRepetitions lower.
// See comments in https://github.com/gosnmp/gosnmp/issues/100
MaxRepetitions uint32
// Deprecated: This parameter is not used and ignored
NonRepeaters int
// UseUnconnectedUDPSocket if set, changes net.Conn to be unconnected UDP socket.
// Some multi-homed network gear isn't smart enough to send SNMP responses
// from the address it received the requests on. To work around that,
// we open unconnected UDP socket and use sendto/recvfrom.
UseUnconnectedUDPSocket bool
// If Control is not nil, it is called after creating the network
// connection but before actually dialing.
//
// Can be used when UseUnconnectedUDPSocket is set to false or when using TCP
// in scenarios where specific options on the underlying socket are needed.
// Refer to https://pkg.go.dev/net#Dialer
Control func(network, address string, c syscall.RawConn) error
// LocalAddr is the local address in the format "address:port" to use when connecting to the Target address.
// If the port parameter is empty or "0", as in
// "127.0.0.1:" or "[::1]:0", a port number is automatically (random) chosen.
LocalAddr string
// netsnmp has '-C APPOPTS - set various application specific behaviours'
//
// - 'c: do not check returned OIDs are increasing' - use AppOpts = map[string]any{"c": true} with
// Walk() or BulkWalk(). The library user needs to implement their own policy for terminating walks.
// - 'p,i,I,t,E' -> pull requests welcome
AppOpts map[string]any
// Internal - used to sync requests to responses. Seeded from random on
// connect and incremented atomically for each encoded packet.
requestID uint32
// Internal - random seed chosen once in connect; initialises requestID and msgID.
random uint32
rxBuf *[rxBufSize]byte // has to be pointer due to https://github.com/golang/go/issues/11728
// MsgFlags is an SNMPV3 MsgFlags.
MsgFlags SnmpV3MsgFlags
// SecurityModel is an SNMPV3 Security Model.
SecurityModel SnmpV3SecurityModel
// SecurityParameters is an SNMPV3 Security Model parameters struct.
SecurityParameters SnmpV3SecurityParameters
// TrapSecurityParametersTable is a mapping of identifiers to corresponding SNMP V3 Security Model parameters
// right now only supported for receiving traps, variable name to make that clear
TrapSecurityParametersTable *SnmpV3SecurityParametersTable
// ContextEngineID is SNMPV3 ContextEngineID in ScopedPDU.
ContextEngineID string
// ContextName is SNMPV3 ContextName in ScopedPDU
ContextName string
// Internal - used to sync requests to responses - snmpv3.
msgID uint32
// Internal - we use to send packets if using unconnected socket.
uaddr *net.UDPAddr
// Internal - guards Conn so that Close can be called repeatedly and concurrently.
mu sync.Mutex
}
// Default holds the default connection settings: SNMP v2c with the "public"
// community over UDP to port 161, a two second timeout with 3 retries whose
// timeout doubles each time (ExponentialTimeout), and at most MaxOids OIDs
// per request.
//
//nolint:gochecknoglobals
var Default = &GoSNMP{
	Version:            Version2c,
	Community:          "public",
	Transport:          udp,
	Port:               161,
	Timeout:            2 * time.Second,
	Retries:            3,
	ExponentialTimeout: true,
	MaxOids:            MaxOids,
}
// SnmpPDU is a single variable binding: an OID name, its ASN.1 BER type and
// its value. It is used to build SNMP SET requests (see Set), and request
// builders such as Get place SnmpPDUs with Type Null and a nil Value into
// SnmpPacket.Variables.
type SnmpPDU struct {
// The value to be set by the SNMP set, or the value when
// sending a trap
Value any
// Name is an oid in string format eg ".1.3.6.1.4.9.27"
Name string
// The type of the value eg Integer
Type Asn1BER
}
// BER identifier-octet building blocks.
const (
	// AsnContext is the BER context-specific class bit pattern (0x80).
	AsnContext = 0x80
	// AsnExtensionID is the all-ones low five bits (0x1F) that mark the BER
	// high-tag-number (extension) form.
	AsnExtensionID = 0x1F
	// AsnExtensionTag combines AsnContext and AsnExtensionID (0x9F).
	AsnExtensionTag = AsnContext | AsnExtensionID
)
//go:generate stringer -type Asn1BER
// Asn1BER is the ASN.1 BER type tag of an SNMP value (see SnmpPDU.Type).
// Its String method is generated by stringer; re-run `go generate` after
// changing any of the constants below.
type Asn1BER byte
// Asn1BER's - http://www.ietf.org/rfc/rfc1442.txt
//
// 0x00-0x07 are universal types, 0x40-0x47 are application types,
// 0x78/0x79 are the opaque float/double variants (presumably the net-snmp
// Opaque extension - confirm), and 0x80-0x82 carry the AsnContext bit.
const (
EndOfContents Asn1BER = 0x00
// UnknownType is an alias of EndOfContents (same value), so it prints as "EndOfContents".
UnknownType Asn1BER = 0x00
Boolean Asn1BER = 0x01
Integer Asn1BER = 0x02
BitString Asn1BER = 0x03
OctetString Asn1BER = 0x04
Null Asn1BER = 0x05
ObjectIdentifier Asn1BER = 0x06
ObjectDescription Asn1BER = 0x07
IPAddress Asn1BER = 0x40
Counter32 Asn1BER = 0x41
Gauge32 Asn1BER = 0x42
TimeTicks Asn1BER = 0x43
Opaque Asn1BER = 0x44
NsapAddress Asn1BER = 0x45
Counter64 Asn1BER = 0x46
Uinteger32 Asn1BER = 0x47
OpaqueFloat Asn1BER = 0x78
OpaqueDouble Asn1BER = 0x79
NoSuchObject Asn1BER = 0x80
NoSuchInstance Asn1BER = 0x81
EndOfMibView Asn1BER = 0x82
)
//go:generate stringer -type SNMPError
// SNMPError is the type for standard SNMP errors (the error-status of a PDU).
// The constants are numbered consecutively from 0 with iota, so their order
// is their wire value and must not be changed.
type SNMPError uint8
// SNMP Errors
//
// NOTE(review): the order appears to mirror the error-status enumeration of
// RFC 3416 (noError(0) .. inconsistentName(18)) - confirm before editing.
const (
NoError SNMPError = iota // No error occurred. This code is also used in all request PDUs, since they have no error status to report.
TooBig // The size of the Response-PDU would be too large to transport.
NoSuchName // The name of a requested object was not found.
BadValue // A value in the request didn't match the structure that the recipient of the request had for the object. For example, an object in the request was specified with an incorrect length or type.
ReadOnly // An attempt was made to set a variable that has an Access value indicating that it is read-only.
GenErr // An error occurred other than one indicated by a more specific error code in this table.
NoAccess // Access was denied to the object for security reasons.
WrongType // The object type in a variable binding is incorrect for the object.
WrongLength // A variable binding specifies a length incorrect for the object.
WrongEncoding // A variable binding specifies an encoding incorrect for the object.
WrongValue // The value given in a variable binding is not possible for the object.
NoCreation // A specified variable does not exist and cannot be created.
InconsistentValue // A variable binding specifies a value that could be held by the variable but cannot be assigned to it at this time.
ResourceUnavailable // An attempt to set a variable required a resource that is not available.
CommitFailed // An attempt to set a particular variable failed.
UndoFailed // An attempt to set a particular variable as part of a group of variables failed, and the attempt to then undo the setting of other variables was not successful.
AuthorizationError // A problem occurred in authorization.
NotWritable // The variable cannot be written or created.
InconsistentName // The name in a variable binding specifies a variable that does not exist.
)
//
// Public Functions (main interface)
//
// Connect creates and opens a socket. Because UDP is a connectionless
// protocol, you won't know if the remote host is responding until you send
// packets. Neither will you know if the host is regularly disappearing and reappearing.
//
// Connect uses Transport as configured ("udp" when empty). It validates the
// client parameters, dials, and seeds the request/message ID counters.
//
// For historical reasons (ie this is part of the public API), the method won't
// be renamed to Dial().
func (x *GoSNMP) Connect() error {
return x.connect("")
}
// ConnectIPv4 is like Connect but forces an IPv4-only connection by selecting
// the "4" variant of Transport (e.g. "udp4").
func (x *GoSNMP) ConnectIPv4() error {
return x.connect("4")
}
// ConnectIPv6 is like Connect but forces an IPv6-only connection by selecting
// the "6" variant of Transport (e.g. "udp6").
func (x *GoSNMP) ConnectIPv6() error {
return x.connect("6")
}
// Close closes the underlying connection and clears Conn.
//
// It is safe to call multiple times and from concurrent goroutines: callers
// are serialised by the internal mutex, only a call that still finds a
// non-nil Conn closes it, and every later call is a no-op returning nil.
func (x *GoSNMP) Close() error {
	x.mu.Lock()
	defer x.mu.Unlock()

	if conn := x.Conn; conn != nil {
		closeErr := conn.Close()
		x.Conn = nil
		return closeErr
	}
	return nil
}
// connect validates the client parameters, opens the network connection and
// seeds the request/message ID counters.
//
// networkSuffix is "" (use Transport as configured), "4" or "6". A non-empty
// suffix selects the IPv4-/IPv6-only variant of the transport ("udp" ->
// "udp4"). An existing "4"/"6" suffix on Transport, for example one left by
// an earlier ConnectIPv4 call, is replaced rather than appended to. This
// stops repeated calls from producing an invalid network name such as "udp44".
//
// https://golang.org/pkg/net/#Dial gives acceptable network values as:
//
//	"tcp", "tcp4" (IPv4-only), "tcp6" (IPv6-only), "udp", "udp4" (IPv4-only),"udp6" (IPv6-only), "ip",
//	"ip4" (IPv4-only), "ip6" (IPv6-only), "unix", "unixgram" and "unixpacket"
func (x *GoSNMP) connect(networkSuffix string) error {
	if err := x.validateParameters(); err != nil {
		return err
	}

	if networkSuffix != "" {
		base := x.Transport
		if n := len(base); n > 0 && (base[n-1] == '4' || base[n-1] == '6') {
			base = base[:n-1]
		}
		x.Transport = base + networkSuffix
	}

	if err := x.netConnect(); err != nil {
		return fmt.Errorf("error establishing connection to host: %w", err)
	}

	if x.random == 0 {
		// rand.Int returns a uniform random value in [0, max), i.e. [0, 2147483646].
		n, err := rand.Int(rand.Reader, big.NewInt(math.MaxInt32))
		if err != nil {
			return fmt.Errorf("error occurred while generating random: %w", err)
		}
		x.random = uint32(n.Uint64()) //nolint:gosec
	}

	// http://tools.ietf.org/html/rfc3412#section-6 - msgID only uses the first 31 bits
	// msgID INTEGER (0..2147483647)
	// Both counters are incremented with sync/atomic elsewhere, so store them
	// atomically as well.
	atomic.StoreUint32(&x.msgID, x.random)
	// RequestID is Integer32 from SNMPV2-SMI and uses all 32 bits
	atomic.StoreUint32(&x.requestID, x.random)
	x.rxBuf = new([rxBufSize]byte)
	return nil
}
// netConnect performs the real socket opening network operation. It can be
// used on its own to reconnect (needed for TCP).
//
// The remote address is Target and Port joined with net.JoinHostPort. For the
// udp/tcp transports, LocalAddr is resolved first. If it carries an IPv4
// address, Transport is narrowed to "udp4"/"tcp4", and this change persists
// on x.
//
// With UseUnconnectedUDPSocket set, a UDP transport stores the resolved remote
// address in x.uaddr and opens an unconnected socket via net.ListenUDP.
// Otherwise the connection is dialed with x.Timeout, the resolved local address
// and x.Control, bounded by x.Context.
//
// On a ListenUDP failure, x.Conn is set to a true nil interface. It never
// holds a nil *net.UDPConn, which would compare non-nil and defeat the
// `x.Conn == nil` check in Close.
func (x *GoSNMP) netConnect() error {
	var err error
	var localAddr net.Addr
	addr := net.JoinHostPort(x.Target, strconv.Itoa(int(x.Port)))
	switch x.Transport {
	case "udp", "udp4", "udp6":
		if localAddr, err = net.ResolveUDPAddr(x.Transport, x.LocalAddr); err != nil {
			return err
		}
		if addr4 := localAddr.(*net.UDPAddr).IP.To4(); addr4 != nil {
			x.Transport = "udp4"
		}
		if x.UseUnconnectedUDPSocket {
			x.uaddr, err = net.ResolveUDPAddr(x.Transport, addr)
			if err != nil {
				return err
			}
			var udpConn *net.UDPConn
			udpConn, err = net.ListenUDP(x.Transport, localAddr.(*net.UDPAddr))
			if err != nil {
				// Assigning the nil *net.UDPConn would create a non-nil interface.
				x.Conn = nil
				return err
			}
			x.Conn = udpConn
			return nil
		}
	case "tcp", "tcp4", "tcp6":
		if localAddr, err = net.ResolveTCPAddr(x.Transport, x.LocalAddr); err != nil {
			return err
		}
		if addr4 := localAddr.(*net.TCPAddr).IP.To4(); addr4 != nil {
			x.Transport = "tcp4"
		}
	}
	// Other transport values fall through with a nil localAddr and are left to
	// the dialer to accept or reject.
	dialer := net.Dialer{Timeout: x.Timeout, LocalAddr: localAddr, Control: x.Control}
	x.Conn, err = dialer.DialContext(x.Context, x.Transport, addr)
	return err
}
// validateParameters fills in defaults and sanity-checks the client settings.
// Transport defaults to udp, a zero MaxOids becomes the MaxOids constant and a
// negative one is an error, and a nil Context becomes context.Background().
// For SNMPv3 it also sets the Reportable message flag, runs
// validateParametersV3 and initialises SecurityParameters with x.Logger.
func (x *GoSNMP) validateParameters() error {
	if len(x.Transport) == 0 {
		x.Transport = udp
	}

	switch {
	case x.MaxOids < 0:
		return fmt.Errorf("field MaxOids cannot be less than 0")
	case x.MaxOids == 0:
		x.MaxOids = MaxOids
	}

	if x.Version == Version3 {
		// TODO: setting the Reportable flag violates rfc3412#6.4 if PDU is of type SNMPv2Trap.
		// See if we can do this smarter and remove bitclear fix from trap.go:57
		x.MsgFlags |= Reportable // tell the snmp server that a report PDU MUST be sent
		if v3Err := x.validateParametersV3(); v3Err != nil {
			return v3Err
		}
		if initErr := x.SecurityParameters.init(x.Logger); initErr != nil {
			return initErr
		}
	}

	if x.Context == nil {
		x.Context = context.Background()
	}
	return nil
}
// MkSnmpPacket builds, but does not send, an SnmpPacket of the given PDU type
// from the client's current settings (version, community, SNMPv3 flags and
// parameters) and the supplied variable bindings. It is the exported form of
// mkSnmpPacket; RequestID/MsgID are not assigned here.
func (x *GoSNMP) MkSnmpPacket(pdutype PDUType, pdus []SnmpPDU, nonRepeaters uint8, maxRepetitions uint32) *SnmpPacket {
return x.mkSnmpPacket(pdutype, pdus, nonRepeaters, maxRepetitions)
}
// mkSnmpPacket assembles an SnmpPacket from the client's settings and the
// given variable bindings. Each packet receives its own Copy of the security
// parameters instead of sharing x.SecurityParameters. Error and ErrorIndex
// are left at zero, and maxRepetitions is masked to its low 31 bits.
func (x *GoSNMP) mkSnmpPacket(pdutype PDUType, pdus []SnmpPDU, nonRepeaters uint8, maxRepetitions uint32) *SnmpPacket {
	var secParams SnmpV3SecurityParameters
	if sp := x.SecurityParameters; sp != nil {
		secParams = sp.Copy()
	}

	pkt := &SnmpPacket{
		Version:            x.Version,
		Community:          x.Community,
		MsgFlags:           x.MsgFlags,
		SecurityModel:      x.SecurityModel,
		SecurityParameters: secParams,
		ContextEngineID:    x.ContextEngineID,
		ContextName:        x.ContextName,
		PDUType:            pdutype,
		NonRepeaters:       nonRepeaters,
		MaxRepetitions:     maxRepetitions & 0x7FFFFFFF,
		Variables:          pdus,
	}
	return pkt
}
// Get sends an SNMP GET request for the given OIDs. It fails without sending
// anything when more than x.MaxOids OIDs are supplied. Each OID becomes a
// variable binding of type Null with no value.
func (x *GoSNMP) Get(oids []string) (result *SnmpPacket, err error) {
	if n := len(oids); n > x.MaxOids {
		return nil, fmt.Errorf("oid count (%d) is greater than MaxOids (%d)",
			n, x.MaxOids)
	}

	pdus := make([]SnmpPDU, len(oids))
	for i, oid := range oids {
		pdus[i] = SnmpPDU{Name: oid, Type: Null, Value: nil}
	}

	return x.send(x.mkSnmpPacket(GetRequest, pdus, 0, 0), true)
}
// Set sends an SNMP SET request for the given variable bindings.
//
// The Type of the first PDU must be one of Integer, OctetString, Gauge32,
// IPAddress, ObjectIdentifier, Counter32, Counter64, Null, TimeTicks,
// Uinteger32, OpaqueFloat or OpaqueDouble; otherwise an error is returned
// and nothing is sent. An empty pdus slice returns an error instead of
// panicking on pdus[0].
func (x *GoSNMP) Set(pdus []SnmpPDU) (result *SnmpPacket, err error) {
	if len(pdus) == 0 {
		return nil, fmt.Errorf("no PDUs supplied for SNMP SET")
	}
	switch pdus[0].Type {
	// TODO test Gauge32
	case Integer, OctetString, Gauge32, IPAddress, ObjectIdentifier, Counter32, Counter64, Null, TimeTicks, Uinteger32, OpaqueFloat, OpaqueDouble:
		// supported; fall out of the switch and send
	default:
		return nil, fmt.Errorf("ERR:gosnmp currently only supports SNMP SETs for Integer, OctetString, Gauge32, IPAddress, ObjectIdentifier, Counter32, Counter64, Null, TimeTicks, Uinteger32, OpaqueFloat, and OpaqueDouble. Not %s", pdus[0].Type)
	}
	return x.send(x.mkSnmpPacket(SetRequest, pdus, 0, 0), true)
}
// GetNext sends an SNMP GETNEXT request for the given OIDs. It fails without
// sending anything when more than x.MaxOids OIDs are supplied. Each OID is
// sent as a Null-typed variable binding.
func (x *GoSNMP) GetNext(oids []string) (result *SnmpPacket, err error) {
	if count := len(oids); count > x.MaxOids {
		return nil, fmt.Errorf("oid count (%d) is greater than MaxOids (%d)",
			count, x.MaxOids)
	}

	pdus := make([]SnmpPDU, len(oids))
	for i := range pdus {
		pdus[i] = SnmpPDU{Name: oids[i], Type: Null, Value: nil}
	}

	request := x.mkSnmpPacket(GetNextRequest, pdus, 0, 0)
	return x.send(request, true)
}
// GetBulk sends an SNMP GETBULK request for the given OIDs with the supplied
// non-repeaters and max-repetitions. GETBULK is rejected for SNMPv1, and more
// than x.MaxOids OIDs is an error; in both cases nothing is sent.
//
// For maxRepetitions greater than 255, use BulkWalk() or BulkWalkAll()
func (x *GoSNMP) GetBulk(oids []string, nonRepeaters uint8, maxRepetitions uint32) (result *SnmpPacket, err error) {
	if x.Version == Version1 {
		return nil, fmt.Errorf("GETBULK not supported in SNMPv1")
	}
	if count := len(oids); count > x.MaxOids {
		return nil, fmt.Errorf("oid count (%d) is greater than MaxOids (%d)",
			count, x.MaxOids)
	}

	pdus := make([]SnmpPDU, len(oids))
	for i, oid := range oids {
		pdus[i] = SnmpPDU{Name: oid, Type: Null, Value: nil}
	}

	request := x.mkSnmpPacket(GetBulkRequest, pdus, nonRepeaters, maxRepetitions)
	return x.send(request, true)
}
// SnmpEncodePacket exposes SNMP packet generation to external callers.
// This is useful for generating traffic for use over separate transport
// stacks and creating traffic samples for test purposes.
//
// It validates the client parameters and builds the packet. It then assigns
// the next request ID and, for SNMPv3, the next message ID, each taken
// atomically and masked to 31 bits. SNMPv3 packets are also initialised with
// initPacket. On any error an empty, non-nil byte slice is returned.
func (x *GoSNMP) SnmpEncodePacket(pdutype PDUType, pdus []SnmpPDU, nonRepeaters uint8, maxRepetitions uint32) ([]byte, error) {
	if vErr := x.validateParameters(); vErr != nil {
		return []byte{}, vErr
	}

	pkt := x.mkSnmpPacket(pdutype, pdus, nonRepeaters, maxRepetitions)

	// Request ID is an atomic counter that wraps to 0 at max int32.
	pkt.RequestID = atomic.AddUint32(&x.requestID, 1) & 0x7FFFFFFF

	if x.Version == Version3 {
		pkt.MsgID = atomic.AddUint32(&x.msgID, 1) & 0x7FFFFFFF
		if initErr := x.initPacket(pkt); initErr != nil {
			return []byte{}, initErr
		}
	}

	encoded, mErr := pkt.marshalMsg()
	if mErr != nil {
		return []byte{}, mErr
	}
	return encoded, nil
}
// SnmpDecodePacket exposes SNMP packet parsing to external callers.
// This is useful for processing traffic from other sources and
// building test harnesses.
//
// The header is decoded first, and SNMPv3 packets are then decrypted before
// the payload is decoded. The returned packet is never nil and may be
// partially filled when an error is returned.
func (x *GoSNMP) SnmpDecodePacket(resp []byte) (*SnmpPacket, error) {
	result := &SnmpPacket{}
	if err := x.validateParameters(); err != nil {
		return result, err
	}

	result.Logger = x.Logger
	if x.SecurityParameters != nil {
		result.SecurityParameters = x.SecurityParameters.Copy()
	}

	cursor, err := x.unmarshalHeader(resp, result)
	if err != nil {
		return result, fmt.Errorf("unable to decode packet header: %w", err)
	}

	if result.Version == Version3 {
		if resp, cursor, err = x.decryptPacket(resp, cursor, result); err != nil {
			return result, err
		}
	}

	if err = x.unmarshalPayload(resp, cursor, result); err != nil {
		return result, fmt.Errorf("unable to decode packet body: %w", err)
	}
	return result, nil
}
// SetRequestID sets the base ID value for future requests. The value is
// masked to 31 bits; SnmpEncodePacket increments the counter before use, so
// the next packet it builds carries reqID+1 (also masked).
func (x *GoSNMP) SetRequestID(reqID uint32) {
	// requestID is incremented with atomic.AddUint32 elsewhere, so a plain
	// write here would be a data race; store it atomically.
	atomic.StoreUint32(&x.requestID, reqID&0x7fffffff)
}

// SetMsgID sets the base ID value for future messages. The value is masked
// to 31 bits; SnmpEncodePacket increments the counter before use for SNMPv3
// packets, so the next one carries msgID+1 (also masked).
func (x *GoSNMP) SetMsgID(msgID uint32) {
	// msgID is incremented with atomic.AddUint32 elsewhere; store it atomically.
	atomic.StoreUint32(&x.msgID, msgID&0x7fffffff)
}
//
// SNMP Walk functions - Analogous to net-snmp's snmpwalk commands
//
// WalkFunc is the type of the function called for each data unit visited
// by Walk and BulkWalk. If an error is returned processing stops.
type WalkFunc func(dataUnit SnmpPDU) error
// BulkWalk retrieves a subtree of values using GETBULK. As the tree is
// walked walkFn is called for each new value. The function immediately returns
// an error if either there is an underlying SNMP error (e.g. GetBulk fails),
// or if walkFn returns an error.
func (x *GoSNMP) BulkWalk(rootOid string, walkFn WalkFunc) error {
return x.walk(GetBulkRequest, rootOid, walkFn)
}
// BulkWalkAll is similar to BulkWalk but returns a filled slice of all values
// rather than using a callback function to stream results. Caution: if you
// have set x.AppOpts["c"] = true (skip the increasing-OID check), BulkWalkAll
// may loop indefinitely and cause an Out Of Memory - use BulkWalk instead.
func (x *GoSNMP) BulkWalkAll(rootOid string) (results []SnmpPDU, err error) {
return x.walkAll(GetBulkRequest, rootOid)
}
// Walk retrieves a subtree of values using GETNEXT - a request is made for each
// value, unlike BulkWalk which does this operation in batches. As the tree is
// walked walkFn is called for each new value. The function immediately returns
// an error if either there is an underlying SNMP error (e.g. GetNext fails),
// or if walkFn returns an error.
func (x *GoSNMP) Walk(rootOid string, walkFn WalkFunc) error {
return x.walk(GetNextRequest, rootOid, walkFn)
}
// WalkAll is similar to Walk but returns a filled slice of all values rather
// than using a callback function to stream results. Caution: if you have set
// x.AppOpts["c"] = true (skip the increasing-OID check), WalkAll may loop
// indefinitely and cause an Out Of Memory - use Walk instead.
func (x *GoSNMP) WalkAll(rootOid string) (results []SnmpPDU, err error) {
return x.walkAll(GetNextRequest, rootOid)
}
//
// Public Functions (helpers) - in alphabetical order
//
// Partition reports whether currentPosition is the last index of a
// partition when a slice of sliceLength items is divided into partitions of
// partitionSize items. The last partition may be smaller than partitionSize.
// This is useful when you have a large array of OIDs to run Get() on. See the
// tests for example usage.
//
// For example for a slice of 8 items to be broken into partitions of
// length 3, Partition returns true for the currentPosition having
// the following values:
//
//	0  1  2  3  4  5  6  7
//	      T        T     T
//
// A currentPosition outside [0, sliceLength) always yields false. A
// partitionSize of zero or less is treated as a single partition covering the
// whole slice, so only the last position yields true. Without this guard a
// zero size would cause an integer divide-by-zero panic.
func Partition(currentPosition, partitionSize, sliceLength int) bool {
	if currentPosition < 0 || currentPosition >= sliceLength {
		return false
	}
	isLast := currentPosition == sliceLength-1
	if partitionSize <= 0 {
		return isLast
	}
	// partitionSize == 1 needs no special case: every position%1 == 0 == size-1.
	return currentPosition%partitionSize == partitionSize-1 || isLast
}
// ToBigInt converts an SnmpPDU.Value to a *big.Int.
//
// All Go signed and unsigned integer types are converted exactly. uint and
// uint64 values above math.MaxInt64 are preserved rather than wrapping
// negative. Strings are parsed as base-10 integers of any size, with an
// optional sign. A string that does not parse, and any other type (floats,
// []byte, nil, ...), yields zero.
//
// This is a convenience function to make working with SnmpPDU's easier - it
// reduces the need for type assertions. A big.Int is convenient, as SNMP can
// return int32, uint32, and uint64.
func ToBigInt(value any) *big.Int {
	switch v := value.(type) {
	case int:
		return big.NewInt(int64(v))
	case int8:
		return big.NewInt(int64(v))
	case int16:
		return big.NewInt(int64(v))
	case int32:
		return big.NewInt(int64(v))
	case int64:
		return big.NewInt(v)
	case uint:
		// int64(v) would overflow for v > MaxInt64 on 64-bit platforms.
		return new(big.Int).SetUint64(uint64(v))
	case uint8:
		return new(big.Int).SetUint64(uint64(v))
	case uint16:
		return new(big.Int).SetUint64(uint64(v))
	case uint32:
		return new(big.Int).SetUint64(uint64(v))
	case uint64:
		return new(big.Int).SetUint64(v)
	case string:
		// for testing and other apps - numbers may appear as strings; parse
		// with big.Int so Counter64-sized values are not lost.
		if n, ok := new(big.Int).SetString(v, 10); ok {
			return n
		}
	}
	return new(big.Int)
}
// Code generated by MockGen. DO NOT EDIT.
// Source: interface.go
// Package gosnmp is a generated GoMock package.
package gosnmp
import (
reflect "reflect"
time "time"
gomock "github.com/golang/mock/gomock"
)
// MockHandler is a mock of Handler interface.
// It is generated by MockGen from interface.go (see the file header);
// regenerate it rather than editing by hand.
type MockHandler struct {
// ctrl receives every mocked call (Call) and every recorded expectation.
ctrl *gomock.Controller
// recorder is what EXPECT returns for registering expected calls.
recorder *MockHandlerMockRecorder
}
// MockHandlerMockRecorder is the mock recorder for MockHandler.
type MockHandlerMockRecorder struct {
mock *MockHandler
}
// NewMockHandler creates a new mock instance bound to ctrl, with its recorder
// wired back to the mock.
func NewMockHandler(ctrl *gomock.Controller) *MockHandler {
mock := &MockHandler{ctrl: ctrl}
mock.recorder = &MockHandlerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockHandler) EXPECT() *MockHandlerMockRecorder {
return m.recorder
}
// BulkWalk mocks base method.
func (m *MockHandler) BulkWalk(rootOid string, walkFn WalkFunc) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BulkWalk", rootOid, walkFn)
ret0, _ := ret[0].(error)
return ret0
}
// BulkWalk indicates an expected call of BulkWalk.
func (mr *MockHandlerMockRecorder) BulkWalk(rootOid, walkFn interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BulkWalk", reflect.TypeOf((*MockHandler)(nil).BulkWalk), rootOid, walkFn)
}
// BulkWalkAll mocks base method.
func (m *MockHandler) BulkWalkAll(rootOid string) ([]SnmpPDU, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BulkWalkAll", rootOid)
ret0, _ := ret[0].([]SnmpPDU)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// BulkWalkAll indicates an expected call of BulkWalkAll.
func (mr *MockHandlerMockRecorder) BulkWalkAll(rootOid interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BulkWalkAll", reflect.TypeOf((*MockHandler)(nil).BulkWalkAll), rootOid)
}
// Check mocks base method.
func (m *MockHandler) Check(err error) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Check", err)
}
// Check indicates an expected call of Check.
func (mr *MockHandlerMockRecorder) Check(err interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Check", reflect.TypeOf((*MockHandler)(nil).Check), err)
}
// Close mocks base method.
func (m *MockHandler) Close() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Close")
ret0, _ := ret[0].(error)
return ret0
}
// Close indicates an expected call of Close.
func (mr *MockHandlerMockRecorder) Close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockHandler)(nil).Close))
}
// Community mocks base method.
func (m *MockHandler) Community() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Community")
ret0, _ := ret[0].(string)
return ret0
}
// Community indicates an expected call of Community.
func (mr *MockHandlerMockRecorder) Community() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Community", reflect.TypeOf((*MockHandler)(nil).Community))
}
// Connect mocks base method.
func (m *MockHandler) Connect() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Connect")
ret0, _ := ret[0].(error)
return ret0
}
// Connect indicates an expected call of Connect.
func (mr *MockHandlerMockRecorder) Connect() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockHandler)(nil).Connect))
}
// ConnectIPv4 mocks base method.
func (m *MockHandler) ConnectIPv4() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ConnectIPv4")
ret0, _ := ret[0].(error)
return ret0
}
// ConnectIPv4 indicates an expected call of ConnectIPv4.
func (mr *MockHandlerMockRecorder) ConnectIPv4() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectIPv4", reflect.TypeOf((*MockHandler)(nil).ConnectIPv4))
}
// ConnectIPv6 mocks base method.
func (m *MockHandler) ConnectIPv6() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ConnectIPv6")
ret0, _ := ret[0].(error)
return ret0
}
// ConnectIPv6 indicates an expected call of ConnectIPv6.
func (mr *MockHandlerMockRecorder) ConnectIPv6() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectIPv6", reflect.TypeOf((*MockHandler)(nil).ConnectIPv6))
}
// ContextEngineID mocks base method.
func (m *MockHandler) ContextEngineID() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ContextEngineID")
ret0, _ := ret[0].(string)
return ret0
}
// ContextEngineID indicates an expected call of ContextEngineID.
func (mr *MockHandlerMockRecorder) ContextEngineID() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ContextEngineID", reflect.TypeOf((*MockHandler)(nil).ContextEngineID))
}
// ContextName mocks base method.
func (m *MockHandler) ContextName() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ContextName")
ret0, _ := ret[0].(string)
return ret0
}
// ContextName indicates an expected call of ContextName.
func (mr *MockHandlerMockRecorder) ContextName() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ContextName", reflect.TypeOf((*MockHandler)(nil).ContextName))
}
// Get mocks base method.
func (m *MockHandler) Get(oids []string) (*SnmpPacket, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Get", oids)
ret0, _ := ret[0].(*SnmpPacket)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Get indicates an expected call of Get.
func (mr *MockHandlerMockRecorder) Get(oids interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockHandler)(nil).Get), oids)
}
// GetBulk mocks base method.
func (m *MockHandler) GetBulk(oids []string, nonRepeaters uint8, maxRepetitions uint32) (*SnmpPacket, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetBulk", oids, nonRepeaters, maxRepetitions)
ret0, _ := ret[0].(*SnmpPacket)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetBulk indicates an expected call of GetBulk.
func (mr *MockHandlerMockRecorder) GetBulk(oids, nonRepeaters, maxRepetitions interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBulk", reflect.TypeOf((*MockHandler)(nil).GetBulk), oids, nonRepeaters, maxRepetitions)
}
// GetExponentialTimeout mocks base method.
func (m *MockHandler) GetExponentialTimeout() bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetExponentialTimeout")
ret0, _ := ret[0].(bool)
return ret0
}
// GetExponentialTimeout indicates an expected call of GetExponentialTimeout.
func (mr *MockHandlerMockRecorder) GetExponentialTimeout() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetExponentialTimeout", reflect.TypeOf((*MockHandler)(nil).GetExponentialTimeout))
}
// GetNext mocks base method.
func (m *MockHandler) GetNext(oids []string) (*SnmpPacket, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetNext", oids)
ret0, _ := ret[0].(*SnmpPacket)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetNext indicates an expected call of GetNext.
func (mr *MockHandlerMockRecorder) GetNext(oids interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetNext", reflect.TypeOf((*MockHandler)(nil).GetNext), oids)
}
// Logger mocks base method.
func (m *MockHandler) Logger() Logger {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Logger")
ret0, _ := ret[0].(Logger)
return ret0
}
// Logger indicates an expected call of Logger.
func (mr *MockHandlerMockRecorder) Logger() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Logger", reflect.TypeOf((*MockHandler)(nil).Logger))
}
// MaxOids mocks base method.
func (m *MockHandler) MaxOids() int {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "MaxOids")
ret0, _ := ret[0].(int)
return ret0
}
// MaxOids indicates an expected call of MaxOids.
func (mr *MockHandlerMockRecorder) MaxOids() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaxOids", reflect.TypeOf((*MockHandler)(nil).MaxOids))
}
// MaxRepetitions mocks base method.
func (m *MockHandler) MaxRepetitions() uint32 {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "MaxRepetitions")
ret0, _ := ret[0].(uint32)
return ret0
}
// MaxRepetitions indicates an expected call of MaxRepetitions.
func (mr *MockHandlerMockRecorder) MaxRepetitions() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaxRepetitions", reflect.TypeOf((*MockHandler)(nil).MaxRepetitions))
}
// MsgFlags mocks base method; a nil or mistyped result decays to the zero value.
func (m *MockHandler) MsgFlags() SnmpV3MsgFlags {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "MsgFlags")
	flags, _ := results[0].(SnmpV3MsgFlags)
	return flags
}

// MsgFlags indicates an expected call of MsgFlags.
func (mr *MockHandlerMockRecorder) MsgFlags() *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockHandler)(nil).MsgFlags)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MsgFlags", method)
}
// Port mocks base method; a nil or mistyped result decays to 0.
func (m *MockHandler) Port() uint16 {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "Port")
	port, _ := results[0].(uint16)
	return port
}

// Port indicates an expected call of Port.
func (mr *MockHandlerMockRecorder) Port() *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockHandler)(nil).Port)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Port", method)
}
// Retries mocks base method; a nil or mistyped result decays to 0.
func (m *MockHandler) Retries() int {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "Retries")
	retries, _ := results[0].(int)
	return retries
}

// Retries indicates an expected call of Retries.
func (mr *MockHandlerMockRecorder) Retries() *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockHandler)(nil).Retries)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Retries", method)
}
// SecurityModel mocks base method; a nil or mistyped result decays to the zero value.
func (m *MockHandler) SecurityModel() SnmpV3SecurityModel {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "SecurityModel")
	model, _ := results[0].(SnmpV3SecurityModel)
	return model
}

// SecurityModel indicates an expected call of SecurityModel.
func (mr *MockHandlerMockRecorder) SecurityModel() *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockHandler)(nil).SecurityModel)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SecurityModel", method)
}
// SecurityParameters mocks base method; a nil or mistyped result decays to nil.
func (m *MockHandler) SecurityParameters() SnmpV3SecurityParameters {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "SecurityParameters")
	params, _ := results[0].(SnmpV3SecurityParameters)
	return params
}

// SecurityParameters indicates an expected call of SecurityParameters.
func (mr *MockHandlerMockRecorder) SecurityParameters() *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockHandler)(nil).SecurityParameters)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SecurityParameters", method)
}
// SendTrap mocks base method: it reports the call (with trap) to the gomock
// controller and returns its results; nil or mistyped results decay to nil.
func (m *MockHandler) SendTrap(trap SnmpTrap) (*SnmpPacket, error) {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "SendTrap", trap)
	packet, _ := results[0].(*SnmpPacket)
	err, _ := results[1].(error)
	return packet, err
}

// SendTrap indicates an expected call of SendTrap.
func (mr *MockHandlerMockRecorder) SendTrap(trap any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockHandler)(nil).SendTrap)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendTrap", method, trap)
}
// Set mocks base method: it reports the call (with pdus) to the gomock
// controller and returns its results; nil or mistyped results decay to nil.
func (m *MockHandler) Set(pdus []SnmpPDU) (*SnmpPacket, error) {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "Set", pdus)
	packet, _ := results[0].(*SnmpPacket)
	err, _ := results[1].(error)
	return packet, err
}

// Set indicates an expected call of Set.
func (mr *MockHandlerMockRecorder) Set(pdus any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockHandler)(nil).Set)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Set", method, pdus)
}
// SetCommunity mocks base method; the method has no results, so the call is
// only reported to the gomock controller.
func (m *MockHandler) SetCommunity(community string) {
	m.ctrl.T.Helper()
	_ = m.ctrl.Call(m, "SetCommunity", community)
}

// SetCommunity indicates an expected call of SetCommunity.
func (mr *MockHandlerMockRecorder) SetCommunity(community any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockHandler)(nil).SetCommunity)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetCommunity", method, community)
}
// SetContextEngineID mocks base method; the call is only reported to the
// gomock controller.
func (m *MockHandler) SetContextEngineID(contextEngineID string) {
	m.ctrl.T.Helper()
	_ = m.ctrl.Call(m, "SetContextEngineID", contextEngineID)
}

// SetContextEngineID indicates an expected call of SetContextEngineID.
func (mr *MockHandlerMockRecorder) SetContextEngineID(contextEngineID any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockHandler)(nil).SetContextEngineID)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetContextEngineID", method, contextEngineID)
}
// SetContextName mocks base method; the call is only reported to the gomock
// controller.
func (m *MockHandler) SetContextName(contextName string) {
	m.ctrl.T.Helper()
	_ = m.ctrl.Call(m, "SetContextName", contextName)
}

// SetContextName indicates an expected call of SetContextName.
func (mr *MockHandlerMockRecorder) SetContextName(contextName any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockHandler)(nil).SetContextName)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetContextName", method, contextName)
}
// SetExponentialTimeout mocks base method; the call is only reported to the
// gomock controller.
func (m *MockHandler) SetExponentialTimeout(value bool) {
	m.ctrl.T.Helper()
	_ = m.ctrl.Call(m, "SetExponentialTimeout", value)
}

// SetExponentialTimeout indicates an expected call of SetExponentialTimeout.
func (mr *MockHandlerMockRecorder) SetExponentialTimeout(value any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockHandler)(nil).SetExponentialTimeout)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetExponentialTimeout", method, value)
}
// SetLogger mocks base method; the call is only reported to the gomock
// controller.
func (m *MockHandler) SetLogger(logger Logger) {
	m.ctrl.T.Helper()
	_ = m.ctrl.Call(m, "SetLogger", logger)
}

// SetLogger indicates an expected call of SetLogger.
func (mr *MockHandlerMockRecorder) SetLogger(logger any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockHandler)(nil).SetLogger)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLogger", method, logger)
}
// SetMaxOids mocks base method; the call is only reported to the gomock
// controller.
func (m *MockHandler) SetMaxOids(maxOids int) {
	m.ctrl.T.Helper()
	_ = m.ctrl.Call(m, "SetMaxOids", maxOids)
}

// SetMaxOids indicates an expected call of SetMaxOids.
func (mr *MockHandlerMockRecorder) SetMaxOids(maxOids any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockHandler)(nil).SetMaxOids)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetMaxOids", method, maxOids)
}
// SetMaxRepetitions mocks base method; the call is only reported to the
// gomock controller.
func (m *MockHandler) SetMaxRepetitions(maxRepetitions uint32) {
	m.ctrl.T.Helper()
	_ = m.ctrl.Call(m, "SetMaxRepetitions", maxRepetitions)
}

// SetMaxRepetitions indicates an expected call of SetMaxRepetitions.
func (mr *MockHandlerMockRecorder) SetMaxRepetitions(maxRepetitions any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockHandler)(nil).SetMaxRepetitions)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetMaxRepetitions", method, maxRepetitions)
}
// SetMsgFlags mocks base method; the call is only reported to the gomock
// controller.
func (m *MockHandler) SetMsgFlags(msgFlags SnmpV3MsgFlags) {
	m.ctrl.T.Helper()
	_ = m.ctrl.Call(m, "SetMsgFlags", msgFlags)
}

// SetMsgFlags indicates an expected call of SetMsgFlags.
func (mr *MockHandlerMockRecorder) SetMsgFlags(msgFlags any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockHandler)(nil).SetMsgFlags)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetMsgFlags", method, msgFlags)
}
// SetPort mocks base method; the call is only reported to the gomock
// controller.
func (m *MockHandler) SetPort(port uint16) {
	m.ctrl.T.Helper()
	_ = m.ctrl.Call(m, "SetPort", port)
}

// SetPort indicates an expected call of SetPort.
func (mr *MockHandlerMockRecorder) SetPort(port any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockHandler)(nil).SetPort)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetPort", method, port)
}
// SetRetries mocks base method; the call is only reported to the gomock
// controller.
func (m *MockHandler) SetRetries(retries int) {
	m.ctrl.T.Helper()
	_ = m.ctrl.Call(m, "SetRetries", retries)
}

// SetRetries indicates an expected call of SetRetries.
func (mr *MockHandlerMockRecorder) SetRetries(retries any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockHandler)(nil).SetRetries)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetRetries", method, retries)
}
// SetSecurityModel mocks base method; the call is only reported to the
// gomock controller.
func (m *MockHandler) SetSecurityModel(securityModel SnmpV3SecurityModel) {
	m.ctrl.T.Helper()
	_ = m.ctrl.Call(m, "SetSecurityModel", securityModel)
}

// SetSecurityModel indicates an expected call of SetSecurityModel.
func (mr *MockHandlerMockRecorder) SetSecurityModel(securityModel any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockHandler)(nil).SetSecurityModel)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetSecurityModel", method, securityModel)
}
// SetSecurityParameters mocks base method; the call is only reported to the
// gomock controller.
func (m *MockHandler) SetSecurityParameters(securityParameters SnmpV3SecurityParameters) {
	m.ctrl.T.Helper()
	_ = m.ctrl.Call(m, "SetSecurityParameters", securityParameters)
}

// SetSecurityParameters indicates an expected call of SetSecurityParameters.
func (mr *MockHandlerMockRecorder) SetSecurityParameters(securityParameters any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockHandler)(nil).SetSecurityParameters)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetSecurityParameters", method, securityParameters)
}
// SetTarget mocks base method; the call is only reported to the gomock
// controller.
func (m *MockHandler) SetTarget(target string) {
	m.ctrl.T.Helper()
	_ = m.ctrl.Call(m, "SetTarget", target)
}

// SetTarget indicates an expected call of SetTarget.
func (mr *MockHandlerMockRecorder) SetTarget(target any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockHandler)(nil).SetTarget)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetTarget", method, target)
}
// SetTimeout mocks base method; the call is only reported to the gomock
// controller.
func (m *MockHandler) SetTimeout(timeout time.Duration) {
	m.ctrl.T.Helper()
	_ = m.ctrl.Call(m, "SetTimeout", timeout)
}

// SetTimeout indicates an expected call of SetTimeout.
func (mr *MockHandlerMockRecorder) SetTimeout(timeout any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockHandler)(nil).SetTimeout)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetTimeout", method, timeout)
}
// SetVersion mocks base method; the call is only reported to the gomock
// controller.
func (m *MockHandler) SetVersion(version SnmpVersion) {
	m.ctrl.T.Helper()
	_ = m.ctrl.Call(m, "SetVersion", version)
}

// SetVersion indicates an expected call of SetVersion.
func (mr *MockHandlerMockRecorder) SetVersion(version any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockHandler)(nil).SetVersion)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetVersion", method, version)
}
// Target mocks base method; a nil or mistyped result decays to "".
func (m *MockHandler) Target() string {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "Target")
	target, _ := results[0].(string)
	return target
}

// Target indicates an expected call of Target.
func (mr *MockHandlerMockRecorder) Target() *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockHandler)(nil).Target)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Target", method)
}
// Timeout mocks base method; a nil or mistyped result decays to 0.
func (m *MockHandler) Timeout() time.Duration {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "Timeout")
	timeout, _ := results[0].(time.Duration)
	return timeout
}

// Timeout indicates an expected call of Timeout.
func (mr *MockHandlerMockRecorder) Timeout() *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockHandler)(nil).Timeout)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Timeout", method)
}
// UnmarshalTrap mocks base method: it reports the call (with both arguments)
// to the gomock controller and returns its results; nil or mistyped results
// decay to nil.
func (m *MockHandler) UnmarshalTrap(trap []byte, useResponseSecurityParameters bool) (*SnmpPacket, error) {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "UnmarshalTrap", trap, useResponseSecurityParameters)
	packet, _ := results[0].(*SnmpPacket)
	err, _ := results[1].(error)
	return packet, err
}

// UnmarshalTrap indicates an expected call of UnmarshalTrap.
func (mr *MockHandlerMockRecorder) UnmarshalTrap(trap, useResponseSecurityParameters any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockHandler)(nil).UnmarshalTrap)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnmarshalTrap", method, trap, useResponseSecurityParameters)
}
// Version mocks base method; a nil or mistyped result decays to the zero value.
func (m *MockHandler) Version() SnmpVersion {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "Version")
	version, _ := results[0].(SnmpVersion)
	return version
}

// Version indicates an expected call of Version.
func (mr *MockHandlerMockRecorder) Version() *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockHandler)(nil).Version)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Version", method)
}
// Walk mocks base method; walkFn is only passed to the gomock controller,
// never invoked here. A nil or mistyped result decays to nil.
func (m *MockHandler) Walk(rootOid string, walkFn WalkFunc) error {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "Walk", rootOid, walkFn)
	err, _ := results[0].(error)
	return err
}

// Walk indicates an expected call of Walk.
func (mr *MockHandlerMockRecorder) Walk(rootOid, walkFn any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockHandler)(nil).Walk)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Walk", method, rootOid, walkFn)
}
// WalkAll mocks base method: it reports the call (with rootOid) to the gomock
// controller and returns its results; nil or mistyped results decay to nil.
func (m *MockHandler) WalkAll(rootOid string) ([]SnmpPDU, error) {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "WalkAll", rootOid)
	pdus, _ := results[0].([]SnmpPDU)
	err, _ := results[1].(error)
	return pdus, err
}

// WalkAll indicates an expected call of WalkAll.
func (mr *MockHandlerMockRecorder) WalkAll(rootOid any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockHandler)(nil).WalkAll)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WalkAll", method, rootOid)
}
// Copyright 2012 The GoSNMP Authors. All rights reserved. Use of this
// source code is governed by a BSD-style license that can be found in the
// LICENSE file.
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gosnmp
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"log"
"math"
"net"
"os"
"strconv"
)
// variable struct is used by decodeValue()
//
// It holds one decoded BER value together with its ASN.1 tag. decodeValue()
// and parseOpaque() fill it in: Type receives the tag that was recognised and
// Value the decoded Go value.
type variable struct {
// Value is the decoded payload. Its dynamic type depends on Type, e.g. int
// for Integer, uint32 for Uinteger32, []byte for OctetString and Opaque,
// string for ObjectIdentifier and IPAddress, and nil for Null and the
// NoSuchObject/NoSuchInstance/EndOfMibView exceptions.
Value any
// Type is the BER tag of Value (UnknownType for unrecognised tags).
Type Asn1BER
}
// Sentinel errors returned by the BER helpers in this file. Callers should
// match them with errors.Is, because some callers wrap them with context.
var (
	// Buffer- and length-level failures.
	ErrZeroByteBuffer      = errors.New("zero byte buffer")      // empty input to decodeValue / parseOpaque
	ErrInvalidPacketLength = errors.New("invalid packet length") // malformed or overflowing length octets

	// Base-128 (OID subidentifier) failures from parseBase128Uint32.
	ErrBase128IntegerTooLarge  = errors.New("base 128 integer too large") // value does not fit in 32 bits
	ErrBase128IntegerTruncated = errors.New("base 128 integer truncated") // input ended mid-subidentifier

	// INTEGER failures from parseInt64 / parseInt.
	ErrIntegerTooLarge = errors.New("integer too large")   // more octets than the target type holds
	ErrZeroLenInteger  = errors.New("zero length integer") // X.690 requires at least one contents octet

	// OBJECT IDENTIFIER failures from parseObjectIdentifier.
	ErrInvalidOidLength = errors.New("invalid OID length")

	// Float payload failures.
	ErrFloatBufferTooShort = errors.New("float buffer too short")
	ErrFloatTooLarge       = errors.New("float too large")
)
// -- helper functions (mostly) in alphabetical order --------------------------

// Check is a blunt error guard: a nil err is ignored; otherwise err is
// logged through x.Logger and the process exits with status 1.
func (x *GoSNMP) Check(err error) {
	if err == nil {
		return
	}
	x.Logger.Printf("Check: %v\n", err)
	os.Exit(1)
}
// Check is a blunt error guard: a nil err is ignored; otherwise err is
// logged through packet.Logger and the process exits with status 1.
func (packet *SnmpPacket) Check(err error) {
	if err == nil {
		return
	}
	packet.Logger.Printf("Check: %v\n", err)
	os.Exit(1)
}
// Check is a blunt error guard for callers without a Logger: a nil err is
// ignored; otherwise log.Fatalf reports it and terminates the process.
func Check(err error) {
	if err == nil {
		return
	}
	log.Fatalf("Check: %v\n", err)
}
// decodeValue decodes a single BER-encoded value from data into retVal.
//
// data[0] must be the value's tag byte; the length octets and contents follow
// and are located with parseLength. On success retVal.Type is set to the tag
// that was recognised and retVal.Value to the decoded Go value:
//   - Integer -> int; Uinteger32 -> uint32 (both parsed as a signed INTEGER)
//   - OctetString -> []byte (a subslice of data, not a copy)
//   - ObjectIdentifier -> dotted string
//   - IPAddress -> string, or nil for a zero-length address
//   - Counter32 / Gauge32 / TimeTicks / Counter64 -> result of parseUint /
//     parseUint / parseUint32 / parseUint64 respectively
//   - Opaque -> whatever parseOpaque produces
//   - Null, NoSuchObject, NoSuchInstance, EndOfMibView -> nil
//   - any other tag -> Type UnknownType, Value nil (not an error)
//
// It returns ErrZeroByteBuffer for empty input and an error for malformed or
// truncated encodings.
//
// NOTE(review): for Counter32, Gauge32, TimeTicks and Counter64, an error from
// the unsigned parser is only logged: retVal is left untouched and nil is
// returned, unlike the Integer path which returns the error. Confirm this
// leniency is intended.
func (x *GoSNMP) decodeValue(data []byte, retVal *variable) error {
if len(data) == 0 {
return ErrZeroByteBuffer
}
switch Asn1BER(data[0]) {
case Integer, Uinteger32:
// 0x02. signed
// Uinteger32 (0x47) shares this path: it is parsed as a signed INTEGER and
// converted to uint32 below.
x.Logger.Printf("decodeValue: type is %s", Asn1BER(data[0]).String())
length, cursor, err := parseLength(data)
if err != nil {
return err
}
// check for truncated packets
if length > len(data) {
return fmt.Errorf("bytes: % x err: truncated (data %d length %d)", data, len(data), length)
}
var ret int
if ret, err = parseInt(data[cursor:length]); err != nil {
x.Logger.Printf("%v:", err)
return fmt.Errorf("bytes: % x err: %w", data, err)
}
retVal.Type = Asn1BER(data[0])
switch Asn1BER(data[0]) {
case Uinteger32:
retVal.Value = uint32(ret) //nolint:gosec
default:
retVal.Value = ret
}
case OctetString:
// 0x04
x.Logger.Print("decodeValue: type is OctetString")
length, cursor, err := parseLength(data)
if err != nil {
return err
}
// check for truncated packet and throw an error
if length > len(data) {
return fmt.Errorf("bytes: % x err: truncated (data %d length %d)", data, len(data), length)
}
retVal.Type = OctetString
// Aliases the caller's buffer; mutating data later also changes Value.
retVal.Value = data[cursor:length]
case Null:
// 0x05
x.Logger.Print("decodeValue: type is Null")
retVal.Type = Null
retVal.Value = nil
case ObjectIdentifier:
// 0x06
x.Logger.Print("decodeValue: type is ObjectIdentifier")
rawOid, _, err := parseRawField(x.Logger, data, "OID")
if err != nil {
return fmt.Errorf("error parsing OID Value: %w", err)
}
oid, ok := rawOid.(string)
if !ok {
return fmt.Errorf("unable to type assert rawOid |%v| to string", rawOid)
}
retVal.Type = ObjectIdentifier
retVal.Value = oid
case IPAddress:
// 0x40
x.Logger.Print("decodeValue: type is IPAddress")
// Type is set before the length is validated, so it stays IPAddress even
// when one of the errors below is returned.
retVal.Type = IPAddress
length, cursor, err := parseLength(data)
if err != nil {
return err
}
// length includes header bytes, ipLen is just the address bytes
ipLen := length - cursor
switch ipLen {
case 0: // real life, buggy devices returning bad data
retVal.Value = nil
return nil
case 4: // IPv4
if len(data) < cursor+4 {
return fmt.Errorf("not enough data for ipv4 address: %x", data)
}
retVal.Value = net.IP(data[cursor : cursor+4]).String()
case 16: // IPv6
if len(data) < cursor+16 {
return fmt.Errorf("not enough data for ipv6 address: %x", data)
}
d := make(net.IP, 16)
copy(d, data[cursor:cursor+16])
retVal.Value = d.String()
default:
return fmt.Errorf("got ipaddress len %d, expected 4 or 16", ipLen)
}
case Counter32:
// 0x41. unsigned
x.Logger.Print("decodeValue: type is Counter32")
length, cursor, err := parseLength(data)
if err != nil {
return err
}
if length > len(data) {
return fmt.Errorf("not enough data for Counter32 %x (data %d length %d)", data, len(data), length)
}
ret, err := parseUint(data[cursor:length])
if err != nil {
x.Logger.Printf("decodeValue: err is %v", err)
// NOTE(review): error swallowed; retVal keeps its previous contents.
break
}
retVal.Type = Counter32
retVal.Value = ret
case Gauge32:
// 0x42. unsigned
x.Logger.Print("decodeValue: type is Gauge32")
length, cursor, err := parseLength(data)
if err != nil {
return err
}
if length > len(data) {
return fmt.Errorf("not enough data for Gauge32 %x (data %d length %d)", data, len(data), length)
}
ret, err := parseUint(data[cursor:length])
if err != nil {
x.Logger.Printf("decodeValue: err is %v", err)
// NOTE(review): error swallowed; retVal keeps its previous contents.
break
}
retVal.Type = Gauge32
retVal.Value = ret
case TimeTicks:
// 0x43
x.Logger.Print("decodeValue: type is TimeTicks")
length, cursor, err := parseLength(data)
if err != nil {
return err
}
if length > len(data) {
return fmt.Errorf("not enough data for TimeTicks %x (data %d length %d)", data, len(data), length)
}
ret, err := parseUint32(data[cursor:length])
if err != nil {
x.Logger.Printf("decodeValue: err is %v", err)
// NOTE(review): error swallowed; retVal keeps its previous contents.
break
}
retVal.Type = TimeTicks
retVal.Value = ret
case Opaque:
// 0x44
x.Logger.Print("decodeValue: type is Opaque")
length, cursor, err := parseLength(data)
if err != nil {
return err
}
if length > len(data) {
return fmt.Errorf("not enough data for Opaque %x (data %d length %d)", data, len(data), length)
}
// Only the contents octets (tag and length stripped) go to parseOpaque,
// which sets retVal itself; the trailing "value is" log line is skipped.
return parseOpaque(x.Logger, data[cursor:length], retVal)
case Counter64:
// 0x46
x.Logger.Print("decodeValue: type is Counter64")
length, cursor, err := parseLength(data)
if err != nil {
return err
}
if length > len(data) {
return fmt.Errorf("not enough data for Counter64 %x (data %d length %d)", data, len(data), length)
}
ret, err := parseUint64(data[cursor:length])
if err != nil {
x.Logger.Printf("decodeValue: err is %v", err)
// NOTE(review): error swallowed; retVal keeps its previous contents.
break
}
retVal.Type = Counter64
retVal.Value = ret
case NoSuchObject:
// 0x80
x.Logger.Print("decodeValue: type is NoSuchObject")
retVal.Type = NoSuchObject
retVal.Value = nil
case NoSuchInstance:
// 0x81
x.Logger.Print("decodeValue: type is NoSuchInstance")
retVal.Type = NoSuchInstance
retVal.Value = nil
case EndOfMibView:
// 0x82
x.Logger.Print("decodeValue: type is EndOfMibView")
retVal.Type = EndOfMibView
retVal.Value = nil
default:
// Unrecognised tags are tolerated rather than treated as errors.
x.Logger.Printf("decodeValue: type %x isn't implemented", data[0])
retVal.Type = UnknownType
retVal.Value = nil
}
x.Logger.Printf("decodeValue: value is %#v", retVal.Value)
return nil
}
// marshalBase128Int writes n to out as an X.690 base-128 subidentifier:
// big-endian groups of 7 bits, with the high (continuation) bit set on every
// octet except the last. Zero is written as a single 0x00 octet.
//
// Negative values have no base-128 representation and are rejected with an
// error, so no octets are written for them. Errors from out are returned
// as-is.
func marshalBase128Int(out io.ByteWriter, n int64) error {
	if n < 0 {
		return fmt.Errorf("unable to marshal base 128 integer: %d is negative", n)
	}
	if n == 0 {
		return out.WriteByte(0)
	}
	// Count the 7-bit groups needed to hold n.
	groups := 0
	for v := n; v > 0; v >>= 7 {
		groups++
	}
	for g := groups - 1; g >= 0; g-- {
		o := byte(n>>uint(g*7)) & 0x7f //nolint:gosec
		if g != 0 {
			o |= 0x80 // more groups follow
		}
		if err := out.WriteByte(o); err != nil {
			return err
		}
	}
	return nil
}
/*
snmp Integer32 and INTEGER:
-2^31 and 2^31-1 inclusive (-2147483648 to 2147483647 decimal)
(FYI https://groups.google.com/forum/#!topic/comp.protocols.snmp/1xaAMzCe_hE)
versus:
snmp Counter32, Gauge32, TimeTicks, Unsigned32: (below)
non-negative integer, maximum value of 2^32-1 (4294967295 decimal)
*/
// marshalInt32 encodes value as the contents octets of a BER INTEGER in
// big-endian two's complement, using the fewest octets (1-4) that still
// sign-extend back to value. value must lie within the int32 range
// (-2147483648 .. 2147483647); anything else is an error.
//
// ITU-T Rec. X.690 (2002) 8.3.2: when an encoding has more than one contents
// octet, the bits of the first octet and bit 8 of the second octet shall not
// all be ones and shall not all be zero. That is exactly the rule "use the
// minimum number of octets", which the width search below implements.
func marshalInt32(value int) ([]byte, error) {
	if value > math.MaxInt32 || value < math.MinInt32 {
		return nil, fmt.Errorf("unable to marshal: %d overflows int32", value)
	}
	// Smallest width n whose signed range [-2^(8n-1), 2^(8n-1)) contains value.
	width := 4
	for n := 1; n < 4; n++ {
		limit := 1 << (8*n - 1)
		if value >= -limit && value < limit {
			width = n
			break
		}
	}
	bits := uint32(value) //nolint:gosec
	out := make([]byte, width)
	for k := range out {
		out[k] = byte(bits >> (8 * (width - 1 - k)))
	}
	return out, nil
}
// marshalUint64 encodes v, which must hold a uint64, as the contents octets of
// a non-negative BER INTEGER (used for SNMP Counter64).
//
// Per X.690 §8.3.2 the encoding is minimal: leading 0x00 octets are dropped
// (zero is still one octet), and a single 0x00 is prepended when the first
// remaining octet has its top bit set, so it does not read as negative. The
// result is therefore 1 to 9 bytes long. A v of any other type is an error.
func marshalUint64(v any) ([]byte, error) {
	n, ok := v.(uint64)
	if !ok {
		return nil, fmt.Errorf("marshalUint64: input is not a uint64")
	}
	// raw[0] is spare room for the sign-padding octet; the value lives in raw[1:9].
	var raw [9]byte
	binary.BigEndian.PutUint64(raw[1:], n)
	start := 1
	for start < 8 && raw[start] == 0 {
		start++
	}
	if raw[start]&0x80 != 0 {
		start--
	}
	return raw[start:], nil
}
// marshalUint32 encodes v as the minimal contents octets of a non-negative
// BER INTEGER, for Counter32, Gauge32, TimeTicks, Unsigned32 and SNMPError
// values.
//
// Accepted dynamic types are uint32, uint, uint8 and SNMPError; a uint wider
// than 32 bits is truncated. Coercing from any other type is deliberately
// refused as dangerous (even uint could be 64 bits, though in practice nothing
// we work with is). The output is 1-5 bytes: leading zero octets are dropped,
// and a 0x00 is prepended when the top bit would otherwise mark the value as
// negative.
func marshalUint32(v any) ([]byte, error) {
	var n uint32
	switch t := v.(type) {
	case uint32:
		n = t
	case uint:
		n = uint32(t) //nolint:gosec
	case uint8:
		n = uint32(t)
	case SNMPError:
		n = uint32(t)
	default:
		return nil, fmt.Errorf("unable to marshal %T to uint32", v)
	}
	// raw[0] is spare room for the sign-padding octet; the value lives in raw[1:5].
	var raw [5]byte
	binary.BigEndian.PutUint32(raw[1:], n)
	start := 1
	for start < 4 && raw[start] == 0 {
		start++
	}
	if raw[start]&0x80 != 0 {
		start--
	}
	return raw[start:], nil
}
// marshalFloat32 encodes v, which must hold a float32, as its 4-byte IEEE 754
// binary32 bit pattern in big-endian order. A v of any other type is reported
// as an error instead of panicking on the type assertion.
func marshalFloat32(v any) ([]byte, error) {
	source, ok := v.(float32)
	if !ok {
		return nil, fmt.Errorf("marshalFloat32: input is not a float32")
	}
	out := make([]byte, 4)
	binary.BigEndian.PutUint32(out, math.Float32bits(source))
	return out, nil
}
// marshalFloat64 encodes v, which must hold a float64, as its 8-byte IEEE 754
// binary64 bit pattern in big-endian order. A v of any other type is reported
// as an error instead of panicking on the type assertion.
func marshalFloat64(v any) ([]byte, error) {
	source, ok := v.(float64)
	if !ok {
		return nil, fmt.Errorf("marshalFloat64: input is not a float64")
	}
	out := make([]byte, 8)
	binary.BigEndian.PutUint64(out, math.Float64bits(source))
	return out, nil
}
// marshalLength builds the BER length octets for a contents length.
//
// http://luca.ntop.org/Teaching/Appunti/asn1.html
//
// Length octets. There are two forms: short (for lengths between 0 and 127),
// and long definite (for lengths between 0 and 2^1008 -1).
//
//   - Short form. One octet. Bit 8 has value "0" and bits 7-1 give the length.
//   - Long form. Two to 127 octets. Bit 8 of first octet has value "1" and bits
//     7-1 give the number of additional length octets. Second and following
//     octets give the length, base 256, most significant digit first.
//
// Lengths up to 127 use the short form; larger ones use the long form with no
// leading zero digits. A negative length is an error.
func marshalLength(length int) ([]byte, error) {
	switch {
	case length < 0:
		return nil, fmt.Errorf("length must be >= 0")
	case length <= 127:
		return []byte{byte(length)}, nil
	}
	// Long form: peel off base-256 digits, least significant first, prepending
	// each one so the result ends up most-significant first.
	var digits []byte
	for rest := uint64(length); rest > 0; rest >>= 8 {
		digits = append([]byte{byte(rest)}, digits...)
	}
	return append([]byte{0x80 | byte(len(digits))}, digits...), nil
}
// marshalTLV appends one BER type-length-value triple to buf: the tag octet,
// the length octets from marshalLength (short or long form as needed, so
// values longer than 127 bytes are handled), then value itself. The only error
// it can return is one from marshalLength.
func marshalTLV(buf *bytes.Buffer, tag byte, value []byte) error {
	header, err := marshalLength(len(value))
	if err != nil {
		return err
	}
	// bytes.Buffer writes always return a nil error, so their results are discarded.
	_ = buf.WriteByte(tag)
	_, _ = buf.Write(header)
	_, _ = buf.Write(value)
	return nil
}
func marshalObjectIdentifier(oid string) ([]byte, error) {
out := new(bytes.Buffer)
oidLength := len(oid)
oidBase := 0
var err error
i := 0
for j := 0; j < oidLength; {
if oid[j] == '.' {
j++
continue
}
var val int64
for j < oidLength && oid[j] != '.' {
ch := int64(oid[j] - '0')
if ch > 9 {
return []byte{}, fmt.Errorf("unable to marshal OID: Invalid object identifier")
}
val *= 10
val += ch
j++
}
switch i {
case 0:
if val > 6 {
return []byte{}, fmt.Errorf("unable to marshal OID: Invalid object identifier")
}
oidBase = int(val * 40)
case 1:
if val >= 40 {
return []byte{}, fmt.Errorf("unable to marshal OID: Invalid object identifier")
}
oidBase += int(val)
err = out.WriteByte(byte(oidBase))
if err != nil {
return []byte{}, fmt.Errorf("unable to marshal OID: Invalid object identifier")
}
default:
if val > MaxObjectSubIdentifierValue {
return []byte{}, fmt.Errorf("unable to marshal OID: Value out of range")
}
err = marshalBase128Int(out, val)
if err != nil {
return []byte{}, fmt.Errorf("unable to marshal OID: Invalid object identifier")
}
}
i++
}
if i < 2 || i > 128 {
return []byte{}, fmt.Errorf("unable to marshal OID: Invalid object identifier")
}
return out.Bytes(), nil
}
// ipv4toBytes returns the 4-byte form of an IPv4 address. It accepts both
// the 4-byte and the 16-byte (IPv4-mapped) representations of net.IP. For
// any other 16-byte address it keeps the legacy behavior of returning the
// last four bytes. Inputs of any other length yield nil instead of panicking
// on an out-of-range slice. The result may share memory with ip.
func ipv4toBytes(ip net.IP) []byte {
	if v4 := ip.To4(); v4 != nil {
		return []byte(v4)
	}
	if len(ip) == net.IPv6len {
		return []byte(ip)[12:]
	}
	return nil
}
// parseOpaque decodes the contents of an Opaque value into retVal. data holds
// only the Opaque contents: the caller has already stripped the outer tag and
// length.
//
// When data is longer than two bytes and starts with AsnExtensionTag, the
// following byte selects a wrapped type:
//   - OpaqueDouble (0x79): decoded with parseFloat64
//   - OpaqueFloat  (0x78): decoded with parseFloat32
//
// Anything else is stored verbatim as an Opaque []byte aliasing data. Empty
// input yields ErrZeroByteBuffer.
//
// TODO: add OpaqueCounter64 (0x76), OpaqueInteger64 (0x80), OpaqueUinteger64 (0x81)
func parseOpaque(logger Logger, data []byte, retVal *variable) error {
	if len(data) == 0 {
		return ErrZeroByteBuffer
	}
	// Untagged payloads keep kind == Opaque and fall into the raw branch.
	kind := Asn1BER(Opaque)
	if len(data) > 2 && data[0] == AsnExtensionTag {
		kind = Asn1BER(data[1])
	}
	switch kind {
	case OpaqueDouble:
		// Drop the extension tag so inner[0] is the type byte and inner[1]
		// the length, as parseLength expects.
		inner := data[1:]
		logger.Print("decodeValue: type is OpaqueDouble")
		length, cursor, err := parseLength(inner)
		if err != nil {
			return err
		}
		if length > len(inner) {
			return fmt.Errorf("not enough data for OpaqueDouble %x (data %d length %d)", inner, len(inner), length)
		}
		retVal.Type = OpaqueDouble
		retVal.Value, err = parseFloat64(inner[cursor:length])
		return err
	case OpaqueFloat:
		inner := data[1:]
		logger.Print("decodeValue: type is OpaqueFloat")
		length, cursor, err := parseLength(inner)
		if err != nil {
			return err
		}
		if length > len(inner) {
			return fmt.Errorf("not enough data for OpaqueFloat %x (data %d length %d)", inner, len(inner), length)
		}
		if cursor > length {
			return fmt.Errorf("invalid cursor position for OpaqueFloat %x (data %d length %d cursor %d)", inner, len(inner), length, cursor)
		}
		retVal.Type = OpaqueFloat
		retVal.Value, err = parseFloat32(inner[cursor:length])
		return err
	default:
		logger.Print("decodeValue: type is Opaque")
		retVal.Type = Opaque
		retVal.Value = data[0:]
		return nil
	}
}
// parseBase128Uint32 decodes one base-128 subidentifier that starts at
// bytes[initOffset]. Each octet contributes its low 7 bits; a clear high bit
// marks the last octet. It returns the value and the offset just past the
// subidentifier. It fails with ErrBase128IntegerTooLarge when the value
// exceeds 32 bits, and with ErrBase128IntegerTruncated when the input ends
// before a terminating octet.
func parseBase128Uint32(bytes []byte, initOffset int) (uint32, int, error) {
	var acc uint64
	for pos := initOffset; pos < len(bytes); {
		octet := bytes[pos]
		pos++
		acc = acc<<7 | uint64(octet&0x7f)
		switch {
		case acc > math.MaxUint32:
			return 0, 0, ErrBase128IntegerTooLarge
		case octet&0x80 == 0:
			return uint32(acc), pos, nil
		}
	}
	return 0, 0, ErrBase128IntegerTruncated
}
// parseInt64 decodes bytes as a big-endian two's-complement integer of 1-8
// octets and returns it sign-extended to int64.
//
// X.690 8.3.1: the encoding of an integer value shall be primitive and its
// contents octets shall consist of one or more octets, so empty input yields
// ErrZeroLenInteger. More than 8 octets cannot fit in an int64 and yields
// ErrIntegerTooLarge.
func parseInt64(bytes []byte) (int64, error) {
	n := len(bytes)
	if n == 0 {
		return 0, ErrZeroLenInteger
	}
	if n > 8 {
		return 0, ErrIntegerTooLarge
	}
	var acc int64
	for _, b := range bytes {
		acc = acc<<8 | int64(b)
	}
	// Move the top content bit up to bit 63, then shift back arithmetically
	// so it is copied into all the higher bits (sign extension).
	pad := uint(64 - 8*n) //nolint:gosec
	return acc << pad >> pad, nil
}
// parseInt decodes bytes as a big-endian two's-complement integer (see
// parseInt64) and narrows the result to int. Errors from parseInt64 are
// passed through. ErrIntegerTooLarge is also returned when the value does not
// fit in int, which can only happen on platforms where int is 32 bits.
func parseInt(bytes []byte) (int, error) {
	wide, err := parseInt64(bytes)
	switch {
	case err != nil:
		return 0, err
	case int64(int(wide)) != wide:
		return 0, ErrIntegerTooLarge
	}
	return int(wide), nil
}
// parseLength parses and calculates an snmp packet length
// and returns an error when invalid data is detected
//
// http://luca.ntop.org/Teaching/Appunti/asn1.html
//
// Length octets. There are two forms: short (for lengths between 0 and 127),
// and long definite (for lengths between 0 and 2^1008 -1).
//
// - Short form. One octet. Bit 8 has value "0" and bits 7-1 give the length.
// - Long form. Two to 127 octets. Bit 8 of first octet has value "1" and bits
// 7-1 give the number of additional length octets. Second and following
// octets give the length, base 256, most significant digit first.
func parseLength(bytes []byte) (int, int, error) {
var cursor, length int
switch {
case len(bytes) < 2:
// handle null octet strings ie "0x04 0x00"
cursor = len(bytes)
length = len(bytes)
case int(bytes[1]) <= 127:
length = int(bytes[1])
length += 2
cursor += 2
case bytes[1] == 0x80:
// Indefinite length encoding (0x80) is prohibited in SNMP per RFC 3417 Section 8:
// "When encoding the length field, only the definite form is used;
// use of the indefinite form encoding is prohibited."
return 0, 0, fmt.Errorf("indefinite length encoding (0x80) is not permitted in SNMP")
default:
numOctets := int(bytes[1]) & 127
for i := range numOctets {
length <<= 8
if len(bytes) < 2+i+1 {
// Invalid data detected, return an error
return 0, 0, ErrInvalidPacketLength
}
length += int(bytes[2+i])
if length < 0 {
// Invalid length due to overflow, return an error
return 0, 0, ErrInvalidPacketLength
}
}
length += 2 + numOctets
cursor += 2 + numOctets
}
if length < 0 {
// Invalid data detected, return an error
return 0, 0, ErrInvalidPacketLength
}
return length, cursor, nil
}
// parseObjectIdentifier decodes the contents octets of a BER OBJECT
// IDENTIFIER into its dotted form with a leading dot (e.g. ".1.3.6.1.2.1").
//
// Every subidentifier, including the first, is base-128 encoded (see
// parseBase128Uint32). Per X.690 §8.19.4 the first subidentifier packs the
// first two arcs: values below 80 split as value/40 and value%40, and values
// of 80 and above belong to arc 2 with second arc value-80 (so 2.999, encoded
// 0x88 0x37, decodes correctly). Empty input yields ErrInvalidOidLength.
// Errors from parseBase128Uint32 are returned unchanged.
func parseObjectIdentifier(src []byte) (string, error) {
	if len(src) == 0 {
		return "", ErrInvalidOidLength
	}
	first, offset, err := parseBase128Uint32(src, 0)
	if err != nil {
		return "", err
	}
	arc1, arc2 := first/40, first%40
	if first >= 80 {
		arc1, arc2 = 2, first-80
	}
	// Capacity is only a sizing hint: most subidentifiers are one octet that
	// renders as at most ".127".
	out := make([]byte, 0, len(src)*4+1)
	out = append(out, '.')
	out = strconv.AppendUint(out, uint64(arc1), 10)
	out = append(out, '.')
	out = strconv.AppendUint(out, uint64(arc2), 10)
	for offset < len(src) {
		var v uint32
		v, offset, err = parseBase128Uint32(src, offset)
		if err != nil {
			return "", err
		}
		out = append(out, '.')
		out = strconv.AppendUint(out, uint64(v), 10)
	}
	return string(out), nil
}
// parseRawField decodes the single BER field at the start of data. It returns
// the decoded Go value and the total number of bytes the field occupies
// (header plus contents).
//
// Supported tags are Integer, OctetString, ObjectIdentifier, IPAddress and
// TimeTicks. Any other tag, or malformed/truncated input, yields an error with
// a nil value and a length of 0. msg is only used in a debug log line.
func parseRawField(logger Logger, data []byte, msg string) (any, int, error) {
	if len(data) == 0 {
		return nil, 0, fmt.Errorf("empty data passed to parseRawField")
	}
	logger.Printf("parseRawField: %s", msg)

	// contentBounds parses the length header and checks that the whole field
	// it announces is present in data. label only appears in error messages.
	contentBounds := func(label string) (cursor, length int, err error) {
		length, cursor, err = parseLength(data)
		if err != nil {
			return 0, 0, err
		}
		if length > len(data) {
			return 0, 0, fmt.Errorf("not enough data for %s (%d vs %d): %x", label, length, len(data), data)
		}
		if cursor > length {
			return 0, 0, fmt.Errorf("invalid cursor position for %s %x (data %d length %d cursor %d)", label, data, len(data), length, cursor)
		}
		return cursor, length, nil
	}

	switch Asn1BER(data[0]) {
	case Integer:
		cursor, length, err := contentBounds("Integer")
		if err != nil {
			return nil, 0, err
		}
		i, err := parseInt(data[cursor:length])
		if err != nil {
			return nil, 0, fmt.Errorf("unable to parse raw INTEGER: %x err: %w", data, err)
		}
		return i, length, nil
	case OctetString:
		cursor, length, err := contentBounds("OctetString")
		if err != nil {
			return nil, 0, err
		}
		return string(data[cursor:length]), length, nil
	case ObjectIdentifier:
		cursor, length, err := contentBounds("OID")
		if err != nil {
			return nil, 0, err
		}
		oid, err := parseObjectIdentifier(data[cursor:length])
		if err != nil {
			// Match the other branches: no partial value or length on failure.
			return nil, 0, err
		}
		return oid, length, nil
	case IPAddress:
		length, cursor, err := parseLength(data)
		if err != nil {
			return nil, 0, err
		}
		// length includes header bytes, ipLen is just the address bytes
		ipLen := length - cursor
		switch ipLen {
		case 0: // real life, buggy devices returning bad data
			return nil, length, nil
		case 4: // IPv4
			if len(data) < cursor+4 {
				return nil, 0, fmt.Errorf("not enough data for ipv4 address: %x", data)
			}
			return net.IP(data[cursor : cursor+4]).String(), length, nil
		default:
			return nil, 0, fmt.Errorf("got ipaddress len %d, expected 4", ipLen)
		}
	case TimeTicks:
		cursor, length, err := contentBounds("TimeTicks")
		if err != nil {
			return nil, 0, err
		}
		ret, err := parseUint(data[cursor:length])
		if err != nil {
			return nil, 0, fmt.Errorf("error in parseUint: %w", err)
		}
		return ret, length, nil
	}
	return nil, 0, fmt.Errorf("unknown field type: %x", data[0])
}
// parseUint64 decodes bytes as a big-endian unsigned integer.
//
// Up to 8 bytes are always accepted. A 9-byte input is accepted only when its
// first byte is 0x00 (the pad BER uses so a value with the high bit set is
// not read as negative). Anything longer cannot fit in 64 bits and yields
// ErrIntegerTooLarge. An empty input decodes to 0.
func parseUint64(bytes []byte) (uint64, error) {
	n := len(bytes)
	if n > 9 || (n == 9 && bytes[0] != 0) {
		return 0, ErrIntegerTooLarge
	}
	var acc uint64
	for _, b := range bytes {
		acc = acc<<8 | uint64(b)
	}
	return acc, nil
}
// parseUint32 decodes bytes as a big-endian unsigned integer (via parseUint)
// and returns it as a uint32.
//
// Values that fit in a uint but not in 32 bits are not rejected; they are
// truncated to their low 32 bits. Errors from parseUint are passed through.
func parseUint32(bytes []byte) (uint32, error) {
	wide, err := parseUint(bytes)
	if err == nil {
		return uint32(wide), nil //nolint:gosec
	}
	return 0, err
}
// parseUint decodes bytes as a big-endian unsigned integer (via parseUint64)
// and returns it as a uint. It returns ErrIntegerTooLarge when the value does
// not fit in this platform's uint.
func parseUint(bytes []byte) (uint, error) {
	v, err := parseUint64(bytes)
	switch {
	case err != nil:
		return 0, err
	case uint64(uint(v)) != v:
		// Only reachable where uint is 32 bits wide.
		return 0, ErrIntegerTooLarge
	}
	return uint(v), nil
}
// parseFloat32 interprets exactly 4 big-endian bytes as an IEEE-754 float32.
// Longer input yields ErrFloatTooLarge. Shorter input yields
// ErrFloatBufferTooShort (binary.BigEndian.Uint32 would otherwise panic).
func parseFloat32(bytes []byte) (float32, error) {
	switch n := len(bytes); {
	case n > 4:
		return 0, ErrFloatTooLarge
	case n < 4:
		return 0, ErrFloatBufferTooShort
	}
	bits := binary.BigEndian.Uint32(bytes)
	return math.Float32frombits(bits), nil
}
// parseFloat64 interprets exactly 8 big-endian bytes as an IEEE-754 float64.
// Longer input yields ErrFloatTooLarge. Shorter input yields
// ErrFloatBufferTooShort (binary.BigEndian.Uint64 would otherwise panic).
func parseFloat64(bytes []byte) (float64, error) {
	switch n := len(bytes); {
	case n > 8:
		return 0, ErrFloatTooLarge
	case n < 8:
		return 0, ErrFloatBufferTooShort
	}
	bits := binary.BigEndian.Uint64(bytes)
	return math.Float64frombits(bits), nil
}
// -- Bit String ---------------------------------------------------------------
// BitStringValue is the structure to use when you want an ASN.1 BIT STRING type. A
// bit string is padded up to the nearest byte in memory and the number of
// valid bits is recorded. Padding bits will be zero.
type BitStringValue struct {
	Bytes     []byte // bits packed into bytes.
	BitLength int    // length in bits.
}

// At returns the bit (0 or 1) at index i. Bits are counted from the most
// significant bit of Bytes[0]. It returns 0 when i is outside [0, BitLength).
// It also returns 0 when BitLength claims more bits than Bytes holds.
func (b BitStringValue) At(i int) int {
	if i < 0 || i >= b.BitLength {
		return 0
	}
	x := i / 8
	if x >= len(b.Bytes) {
		// BitLength and Bytes disagree; treat the missing bits as 0 rather
		// than panicking with an index out of range.
		return 0
	}
	y := 7 - uint(i%8) //nolint:gosec
	return int(b.Bytes[x]>>y) & 1
}

// RightAlign returns the bits moved so that the padding is at the beginning
// (the most significant end) instead of the end. When BitLength is a multiple
// of 8, or Bytes is empty, b.Bytes itself is returned. Otherwise a new slice
// of the same length is allocated.
func (b BitStringValue) RightAlign() []byte {
	shift := uint(8 - (b.BitLength % 8)) //nolint:gosec
	if shift == 8 || len(b.Bytes) == 0 {
		return b.Bytes
	}
	a := make([]byte, len(b.Bytes))
	a[0] = b.Bytes[0] >> shift
	for i := 1; i < len(b.Bytes); i++ {
		a[i] = b.Bytes[i-1] << (8 - shift)
		a[i] |= b.Bytes[i] >> shift
	}
	return a
}
// -- SnmpVersion --------------------------------------------------------------
// String returns the conventional short name of the SNMP version: "1", "2c"
// or "3". A value matching none of the defined versions is rendered as
// "SnmpVersion(N)" (the stringer convention) instead of being mislabelled "3".
func (s SnmpVersion) String() string {
	switch s {
	case Version1:
		return "1"
	case Version2c:
		return "2c"
	case Version3:
		return "3"
	default:
		return "SnmpVersion(" + strconv.FormatUint(uint64(s), 10) + ")"
	}
}
// Copyright 2012 The GoSNMP Authors. All rights reserved. Use of this
// source code is governed by a BSD-style license that can be found in the
// LICENSE file.
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gosnmp
import (
"time"
)
//go:generate mockgen --destination gosnmp_mock.go --package=gosnmp --source interface.go
// Handler is the interface satisfied by the value NewHandler returns: the
// GoSNMP request methods plus getter/setter pairs for its configuration
// fields.
//
// Handler is provided to assist with testing using mocks.
type Handler interface {
	// Connect creates and opens a socket. Because UDP is a connectionless
	// protocol, you won't know if the remote host is responding until you send
	// packets. And if the host is regularly disappearing and reappearing, you won't
	// know if you've only done a Connect().
	//
	// For historical reasons (ie this is part of the public API), the method won't
	// be renamed.
	Connect() error
	// ConnectIPv4 connects using IPv4
	ConnectIPv4() error
	// ConnectIPv6 connects using IPv6
	ConnectIPv6() error
	// Get sends an SNMP GET request
	Get(oids []string) (result *SnmpPacket, err error)
	// GetBulk sends an SNMP GETBULK request
	GetBulk(oids []string, nonRepeaters uint8, maxRepetitions uint32) (result *SnmpPacket, err error)
	// GetNext sends an SNMP GETNEXT request
	GetNext(oids []string) (result *SnmpPacket, err error)
	// Walk retrieves a subtree of values using GETNEXT. A request is made for
	// each value, unlike BulkWalk, which does this operation in batches. As the
	// tree is walked, walkFn is called for each new value. The function returns
	// immediately with an error if there is an underlying SNMP error (e.g.
	// GetNext fails) or if walkFn returns an error.
	Walk(rootOid string, walkFn WalkFunc) error
	// WalkAll is similar to Walk but returns a filled array of all values rather
	// than using a callback function to stream results.
	WalkAll(rootOid string) (results []SnmpPDU, err error)
	// BulkWalk retrieves a subtree of values using GETBULK. As the tree is
	// walked, walkFn is called for each new value. The function returns
	// immediately with an error if there is an underlying SNMP error (e.g.
	// GetBulk fails) or if walkFn returns an error.
	BulkWalk(rootOid string, walkFn WalkFunc) error
	// BulkWalkAll is similar to BulkWalk but returns a filled array of all values
	// rather than using a callback function to stream results.
	BulkWalkAll(rootOid string) (results []SnmpPDU, err error)
	// SendTrap sends a SNMP Trap (v2c/v3 only)
	//
	// The first variable (pdus[0], i.e. trap.Variables[0]) can be a PDU of
	// type TimeTicks carrying the desired uint32 epoch time. Otherwise a
	// TimeTicks PDU is prepended, with its time set to now. This mirrors the
	// behaviour of the Net-SNMP command-line tools.
	//
	// SendTrap doesn't wait for a return packet from the NMS (Network
	// Management Station).
	//
	// See also Listen() and examples for creating an NMS.
	SendTrap(trap SnmpTrap) (result *SnmpPacket, err error)
	// UnmarshalTrap unpacks the SNMP Trap.
	UnmarshalTrap(trap []byte, useResponseSecurityParameters bool) (result *SnmpPacket, err error)
	// Set sends an SNMP SET request
	Set(pdus []SnmpPDU) (result *SnmpPacket, err error)
	// Check is a convenience helper for handling an error with minimal code
	// at the call site.
	Check(err error)
	// Close closes the connection
	Close() error
	// Target gets the Target
	Target() string
	// SetTarget sets the Target
	SetTarget(target string)
	// Port gets the Port
	Port() uint16
	// SetPort sets the Port
	SetPort(port uint16)
	// Community gets the Community
	Community() string
	// SetCommunity sets the Community
	SetCommunity(community string)
	// Version gets the Version
	Version() SnmpVersion
	// SetVersion sets the Version
	SetVersion(version SnmpVersion)
	// Timeout gets the Timeout
	Timeout() time.Duration
	// SetTimeout sets the Timeout
	SetTimeout(timeout time.Duration)
	// Retries gets the Retries
	Retries() int
	// SetRetries sets the Retries
	SetRetries(retries int)
	// GetExponentialTimeout gets the ExponentialTimeout
	GetExponentialTimeout() bool
	// SetExponentialTimeout sets the ExponentialTimeout
	SetExponentialTimeout(value bool)
	// Logger gets the Logger
	Logger() Logger
	// SetLogger sets the Logger
	SetLogger(logger Logger)
	// MaxOids gets the MaxOids
	MaxOids() int
	// SetMaxOids sets the MaxOids
	SetMaxOids(maxOids int)
	// MaxRepetitions gets the maxRepetitions
	MaxRepetitions() uint32
	// SetMaxRepetitions sets the maxRepetitions
	SetMaxRepetitions(maxRepetitions uint32)
	// MsgFlags gets the MsgFlags
	MsgFlags() SnmpV3MsgFlags
	// SetMsgFlags sets the MsgFlags
	SetMsgFlags(msgFlags SnmpV3MsgFlags)
	// SecurityModel gets the SecurityModel
	SecurityModel() SnmpV3SecurityModel
	// SetSecurityModel sets the SecurityModel
	SetSecurityModel(securityModel SnmpV3SecurityModel)
	// SecurityParameters gets the SecurityParameters
	SecurityParameters() SnmpV3SecurityParameters
	// SetSecurityParameters sets the SecurityParameters
	SetSecurityParameters(securityParameters SnmpV3SecurityParameters)
	// ContextEngineID gets the ContextEngineID
	ContextEngineID() string
	// SetContextEngineID sets the ContextEngineID
	SetContextEngineID(contextEngineID string)
	// ContextName gets the ContextName
	ContextName() string
	// SetContextName sets the ContextName
	SetContextName(contextName string)
}
// snmpHandler is a wrapper around gosnmp. It embeds GoSNMP, so GoSNMP's
// methods are promoted. It adds the accessor methods Handler requires whose
// names would otherwise collide with GoSNMP's exported fields (see Target).
type snmpHandler struct {
	GoSNMP
}
// NewHandler returns a Handler backed by a GoSNMP. Port, Community, Version,
// Timeout, Retries and MaxOids are copied from Default; every other field
// keeps its zero value.
func NewHandler() Handler {
	return &snmpHandler{
		GoSNMP: GoSNMP{
			Port:      Default.Port,
			Community: Default.Community,
			Version:   Default.Version,
			Timeout:   Default.Timeout,
			Retries:   Default.Retries,
			MaxOids:   Default.MaxOids,
		},
	}
}
// Target returns GoSNMP.Target.
func (x *snmpHandler) Target() string {
	// The promoted x.Target resolves to this method, not the field, so the
	// embedded struct is named explicitly here and in the other accessors.
	return x.GoSNMP.Target
}

// SetTarget sets GoSNMP.Target.
func (x *snmpHandler) SetTarget(target string) {
	x.GoSNMP.Target = target
}

// Port returns GoSNMP.Port.
func (x *snmpHandler) Port() uint16 {
	return x.GoSNMP.Port
}

// SetPort sets GoSNMP.Port.
func (x *snmpHandler) SetPort(port uint16) {
	x.GoSNMP.Port = port
}

// Community returns GoSNMP.Community.
func (x *snmpHandler) Community() string {
	return x.GoSNMP.Community
}

// SetCommunity sets GoSNMP.Community.
func (x *snmpHandler) SetCommunity(community string) {
	x.GoSNMP.Community = community
}

// Version returns GoSNMP.Version.
func (x *snmpHandler) Version() SnmpVersion {
	return x.GoSNMP.Version
}

// SetVersion sets GoSNMP.Version.
func (x *snmpHandler) SetVersion(version SnmpVersion) {
	x.GoSNMP.Version = version
}

// Timeout returns GoSNMP.Timeout.
func (x *snmpHandler) Timeout() time.Duration {
	return x.GoSNMP.Timeout
}

// SetTimeout sets GoSNMP.Timeout.
func (x *snmpHandler) SetTimeout(timeout time.Duration) {
	x.GoSNMP.Timeout = timeout
}

// Retries returns GoSNMP.Retries.
func (x *snmpHandler) Retries() int {
	return x.GoSNMP.Retries
}

// SetRetries sets GoSNMP.Retries.
func (x *snmpHandler) SetRetries(retries int) {
	x.GoSNMP.Retries = retries
}
// GetExponentialTimeout returns GoSNMP.ExponentialTimeout.
func (x *snmpHandler) GetExponentialTimeout() bool {
	return x.GoSNMP.ExponentialTimeout
}

// SetExponentialTimeout sets GoSNMP.ExponentialTimeout.
func (x *snmpHandler) SetExponentialTimeout(value bool) {
	x.GoSNMP.ExponentialTimeout = value
}

// Logger returns GoSNMP.Logger.
func (x *snmpHandler) Logger() Logger {
	return x.GoSNMP.Logger
}

// SetLogger sets GoSNMP.Logger.
func (x *snmpHandler) SetLogger(logger Logger) {
	x.GoSNMP.Logger = logger
}

// MaxOids returns GoSNMP.MaxOids.
func (x *snmpHandler) MaxOids() int {
	return x.GoSNMP.MaxOids
}

// SetMaxOids sets GoSNMP.MaxOids.
func (x *snmpHandler) SetMaxOids(maxOids int) {
	x.GoSNMP.MaxOids = maxOids
}

// MaxRepetitions returns GoSNMP.MaxRepetitions reduced to its low 31 bits,
// so the result never exceeds math.MaxInt32.
func (x *snmpHandler) MaxRepetitions() uint32 {
	const low31 = 1<<31 - 1
	return x.GoSNMP.MaxRepetitions & low31
}

// SetMaxRepetitions stores maxRepetitions masked to its low 31 bits. Values
// up to math.MaxInt32 are kept as-is; larger ones wrap around, so
// MaxInt32+1 becomes 0.
func (x *snmpHandler) SetMaxRepetitions(maxRepetitions uint32) {
	const low31 = 1<<31 - 1
	x.GoSNMP.MaxRepetitions = maxRepetitions & low31
}
// MsgFlags returns GoSNMP.MsgFlags (SNMPv3).
func (x *snmpHandler) MsgFlags() SnmpV3MsgFlags {
	return x.GoSNMP.MsgFlags
}

// SetMsgFlags sets GoSNMP.MsgFlags (SNMPv3).
func (x *snmpHandler) SetMsgFlags(msgFlags SnmpV3MsgFlags) {
	x.GoSNMP.MsgFlags = msgFlags
}

// SecurityModel returns GoSNMP.SecurityModel (SNMPv3).
func (x *snmpHandler) SecurityModel() SnmpV3SecurityModel {
	return x.GoSNMP.SecurityModel
}

// SetSecurityModel sets GoSNMP.SecurityModel (SNMPv3).
func (x *snmpHandler) SetSecurityModel(securityModel SnmpV3SecurityModel) {
	x.GoSNMP.SecurityModel = securityModel
}

// SecurityParameters returns GoSNMP.SecurityParameters (SNMPv3).
func (x *snmpHandler) SecurityParameters() SnmpV3SecurityParameters {
	return x.GoSNMP.SecurityParameters
}

// SetSecurityParameters sets GoSNMP.SecurityParameters (SNMPv3).
func (x *snmpHandler) SetSecurityParameters(securityParameters SnmpV3SecurityParameters) {
	x.GoSNMP.SecurityParameters = securityParameters
}

// ContextEngineID returns GoSNMP.ContextEngineID (SNMPv3).
func (x *snmpHandler) ContextEngineID() string {
	return x.GoSNMP.ContextEngineID
}

// SetContextEngineID sets GoSNMP.ContextEngineID (SNMPv3).
func (x *snmpHandler) SetContextEngineID(contextEngineID string) {
	x.GoSNMP.ContextEngineID = contextEngineID
}

// ContextName returns GoSNMP.ContextName (SNMPv3).
func (x *snmpHandler) ContextName() string {
	return x.GoSNMP.ContextName
}

// SetContextName sets GoSNMP.ContextName (SNMPv3).
func (x *snmpHandler) SetContextName(contextName string) {
	x.GoSNMP.ContextName = contextName
}
// Close closes the underlying connection and returns its error. Closing a
// handler whose connection was never opened (Conn is nil) is a no-op that
// returns nil, rather than a nil-interface panic.
func (x *snmpHandler) Close() error {
	// Use x.GoSNMP.Conn rather than the promoted x.Conn, for consistency with
	// the other accessors on this type.
	if x.GoSNMP.Conn == nil {
		return nil
	}
	return x.GoSNMP.Conn.Close()
}
// Copyright 2021 The GoSNMP Authors. All rights reserved. Use of this
// source code is governed by a BSD-style license that can be found in the
// LICENSE file.
//go:build !gosnmp_nodebug
package gosnmp
// Print forwards v to the wrapped LoggerInterface. When no logger is wrapped,
// the call is dropped.
func (l *Logger) Print(v ...any) {
	inner := l.logger
	if inner == nil {
		return
	}
	inner.Print(v...)
}
// Printf forwards format and v to the wrapped LoggerInterface. When no logger
// is wrapped, the call is dropped.
func (l *Logger) Printf(format string, v ...any) {
	inner := l.logger
	if inner == nil {
		return
	}
	inner.Printf(format, v...)
}
// Copyright 2012 The GoSNMP Authors. All rights reserved. Use of this
// source code is governed by a BSD-style license that can be found in the
// LICENSE file.
package gosnmp
import (
"bytes"
"context"
"encoding/asn1"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"runtime"
"strings"
"sync/atomic"
"time"
)
//
// Remaining globals and definitions located here.
// See http://www.rane.com/note161.html for a succinct description of the SNMP
// protocol.
//
// SnmpVersion identifies the SNMP protocol version of a message. Versions 1,
// 2c and 3 are implemented.
type SnmpVersion uint8

// SnmpVersion values for the supported protocol versions. These are the
// values written into the message's version field (marshalMsg emits
// byte(packet.Version) directly), so they must not be renumbered.
const (
	Version1  SnmpVersion = 0x0
	Version2c SnmpVersion = 0x1
	Version3  SnmpVersion = 0x3
)
// SnmpPacket struct represents the entire SNMP Message or Sequence at the
// application layer.
//
// For SNMPv1/v2c, marshalMsg encodes Version, Community and the PDU. For
// SNMPv3, it hands the packet to marshalV3 instead, which is where the v3
// header fields (MsgFlags, SecurityModel, SecurityParameters, ...) are
// expected to be used.
type SnmpPacket struct {
	Version            SnmpVersion
	MsgFlags           SnmpV3MsgFlags
	SecurityModel      SnmpV3SecurityModel
	SecurityParameters SnmpV3SecurityParameters // interface; may be nil
	ContextEngineID    string
	ContextName        string
	Community          string
	PDUType            PDUType
	MsgID              uint32 // SNMPv3 message ID; assigned per attempt by sendOneRequest
	RequestID          uint32 // assigned per attempt by sendOneRequest
	MsgMaxSize         uint32
	Error              SNMPError // error-status field of the PDU
	ErrorIndex         uint8
	NonRepeaters       uint8  // only marshalled for GetBulkRequest
	MaxRepetitions     uint32 // only marshalled for GetBulkRequest
	Variables          []SnmpPDU
	Logger             Logger
	// v1 traps have a very different format from v2c and v3 traps.
	//
	// These fields are set via the SnmpTrap parameter to SendTrap().
	SnmpTrap
}
// SnmpTrap is used to define a SNMP trap, and is passed into SendTrap
type SnmpTrap struct {
	Variables []SnmpPDU
	// If true, the trap is an InformRequest, not a trap. This has no effect on
	// v1 traps, as Inform is not part of the v1 protocol.
	IsInform bool
	// These fields are required for SNMPV1 Trap Headers (see
	// marshalSNMPV1TrapHeader): Enterprise is encoded as an OBJECT IDENTIFIER,
	// AgentAddress is parsed with net.ParseIP and encoded as an IpAddress,
	// GenericTrap and SpecificTrap are encoded as INTEGERs, and Timestamp is
	// encoded as TimeTicks.
	Enterprise   string
	AgentAddress string
	GenericTrap  int
	SpecificTrap int
	Timestamp    uint
}
// VarBind struct represents an SNMP Varbind: an object name paired with its
// value, using the encoding/asn1 representations of both.
type VarBind struct {
	Name  asn1.ObjectIdentifier
	Value asn1.RawValue
}
// PDUType describes which SNMP Protocol Data Unit is being sent.
type PDUType byte

// The currently supported PDUType values. Sequence (0x30) is the universal
// ASN.1 SEQUENCE tag rather than a PDU; marshalMsg and marshalVBL use it to
// wrap the whole message and the varbind list.
const (
	Sequence       PDUType = 0x30
	GetRequest     PDUType = 0xa0
	GetNextRequest PDUType = 0xa1
	GetResponse    PDUType = 0xa2
	SetRequest     PDUType = 0xa3
	Trap           PDUType = 0xa4 // v1
	GetBulkRequest PDUType = 0xa5
	InformRequest  PDUType = 0xa6
	SNMPv2Trap     PDUType = 0xa7 // v2c, v3
	Report         PDUType = 0xa8 // v3
)
//go:generate stringer -type=PDUType
// SNMPv3: User-based Security Model Report PDUs and
// error types as per https://tools.ietf.org/html/rfc3414
//
// sendOneRequest matches the single varbind of an incoming Report PDU against
// these OIDs. usmStatsNotInTimeWindows and usmStatsUnknownEngineIDs are
// treated as recoverable (send resynchronises and retransmits); the others
// are mapped to the Err* sentinels below.
const (
	usmStatsUnsupportedSecLevels = ".1.3.6.1.6.3.15.1.1.1.0" // -> ErrUnknownSecurityLevel
	usmStatsNotInTimeWindows     = ".1.3.6.1.6.3.15.1.1.2.0" // recoverable
	usmStatsUnknownUserNames     = ".1.3.6.1.6.3.15.1.1.3.0" // -> ErrUnknownUsername
	usmStatsUnknownEngineIDs     = ".1.3.6.1.6.3.15.1.1.4.0" // recoverable
	usmStatsWrongDigests         = ".1.3.6.1.6.3.15.1.1.5.0" // -> ErrWrongDigest
	usmStatsDecryptionErrors     = ".1.3.6.1.6.3.15.1.1.6.0" // -> ErrDecryption
	snmpUnknownSecurityModels    = ".1.3.6.1.6.3.11.2.1.1.0" // -> ErrUnknownSecurityModels
	snmpInvalidMsgs              = ".1.3.6.1.6.3.11.2.1.2.0" // -> ErrInvalidMsgs
	snmpUnknownPDUHandlers       = ".1.3.6.1.6.3.11.2.1.3.0" // -> ErrUnknownPDUHandlers
)
// Sentinel errors for SNMPv3 failures reported by the agent. Compare against
// them with errors.Is. In this file they are returned by sendOneRequest for
// the matching Report PDU, except ErrNotInTimeWindow and ErrUnknownEngineID,
// which send returns when the retransmission after such a report fails.
var (
	ErrDecryption            = errors.New("decryption error")         // usmStatsDecryptionErrors report
	ErrInvalidMsgs           = errors.New("invalid messages")         // snmpInvalidMsgs report
	ErrNotInTimeWindow       = errors.New("not in time window")       // retransmit after usmStatsNotInTimeWindows failed
	ErrUnknownEngineID       = errors.New("unknown engine id")        // retransmit after usmStatsUnknownEngineIDs failed
	ErrUnknownPDUHandlers    = errors.New("unknown pdu handlers")     // snmpUnknownPDUHandlers report
	ErrUnknownReportPDU      = errors.New("unknown report pdu")       // Report with an unrecognised OID
	ErrUnknownSecurityLevel  = errors.New("unknown security level")   // usmStatsUnsupportedSecLevels report
	ErrUnknownSecurityModels = errors.New("unknown security models")  // snmpUnknownSecurityModels report
	ErrUnknownUsername       = errors.New("unknown username")         // usmStatsUnknownUserNames report
	ErrWrongDigest           = errors.New("wrong digest")             // usmStatsWrongDigests report
)
// rxBufSize is the receive buffer size. 65535 is the largest IPv4 packet and
// the largest non-jumbogram IPv6 payload.
const rxBufSize = 65535 // max size of IPv4 & IPv6 packet
// LoggerInterface is the interface used for debugging. Both Print and
// Printf have the same signatures as their counterparts in the standard
// library's log package, so a *log.Logger satisfies it. The interface is kept
// small to give you flexibility in how you do your debugging.
//
// For verbose logging to stdout, wrap a logger with NewLogger and assign it
// to GoSNMP.Logger:
//
//	g.Logger = gosnmp.NewLogger(log.New(os.Stdout, "", 0))
type LoggerInterface interface {
	Print(v ...any)
	Printf(format string, v ...any)
}

// Logger wraps an optional LoggerInterface. In the default build (without
// the gosnmp_nodebug tag), its Print/Printf methods forward to the wrapped
// logger and do nothing when it is nil, so the zero Logger is silent.
type Logger struct {
	logger LoggerInterface
}
// NewLogger wraps logger in a Logger suitable for GoSNMP.Logger. A nil logger
// produces the same silent Logger as the zero value.
func NewLogger(logger LoggerInterface) Logger {
	var wrapped Logger
	wrapped.logger = logger
	return wrapped
}
// SafeString renders the packet's header fields and variables on one line for
// logging. Security parameters are rendered through their own SafeString
// method (empty when unset) rather than formatted directly.
//
// NOTE(review): Community, the v1/v2c shared secret, is included verbatim;
// confirm that is acceptable wherever this output is logged.
func (packet *SnmpPacket) SafeString() string {
	var secParams string
	if params := packet.SecurityParameters; params != nil {
		secParams = params.SafeString()
	}
	return fmt.Sprintf("Version:%s, MsgFlags:%s, SecurityModel:%s, SecurityParameters:%s, ContextEngineID:%s, ContextName:%s, Community:%s, PDUType:%s, MsgID:%d, RequestID:%d, MsgMaxSize:%d, Error:%s, ErrorIndex:%d, NonRepeaters:%d, MaxRepetitions:%d, Variables:%v",
		packet.Version,
		packet.MsgFlags,
		packet.SecurityModel,
		secParams,
		packet.ContextEngineID,
		packet.ContextName,
		packet.Community,
		packet.PDUType,
		packet.MsgID,
		packet.RequestID,
		packet.MsgMaxSize,
		packet.Error,
		packet.ErrorIndex,
		packet.NonRepeaters,
		packet.MaxRepetitions,
		packet.Variables,
	)
}
// sendOneRequest marshals packetOut and writes it to x.Conn. When wait is
// true, it then reads responses until one passes validation, retrying the
// whole exchange up to x.Retries times.
//
// Each attempt assigns a fresh RequestID. For SNMPv3 it also assigns a fresh
// MsgID and calls initPacket. The per-attempt deadline is x.Timeout, doubled
// on each retry when x.ExponentialTimeout is set, and clamped to x.Context's
// deadline when that is earlier. A response is accepted when its RequestID
// matches any ID sent during this call, or is 0.
//
// SNMPv3 Report PDUs for usmStatsNotInTimeWindows and usmStatsUnknownEngineIDs
// are returned with a nil error so the caller can resynchronise and resend.
// Other Report PDUs are returned together with the matching Err* sentinel.
// When wait is false, an empty packet is returned as soon as the write
// succeeds.
func (x *GoSNMP) sendOneRequest(packetOut *SnmpPacket,
	wait bool) (result *SnmpPacket, err error) {
	allReqIDs := make([]uint32, 0, x.Retries+1)
	// allMsgIDs := make([]uint32, 0, x.Retries+1) // unused
	timeout := x.Timeout
	withContextDeadline := false
sendRetry:
	for retries := 0; ; retries++ {
		if retries > 0 {
			if x.OnRetry != nil {
				x.OnRetry(x)
			}
			x.Logger.Printf("Retry number %d. Last error was: %v", retries, err)
			// err is nil here when the previous attempt ended in a successful
			// TCP reconnect (continue sendRetry below), so guard before
			// calling err.Error().
			if withContextDeadline && err != nil && strings.Contains(err.Error(), "timeout") {
				err = context.DeadlineExceeded
				break
			}
			if retries > x.Retries {
				if err == nil {
					err = fmt.Errorf("max retries (%d) exceeded", x.Retries)
				}
				if strings.Contains(err.Error(), "timeout") {
					err = fmt.Errorf("request timeout (after %d retries)", retries-1)
				}
				break
			}
			if x.ExponentialTimeout {
				// https://www.webnms.com/snmp/help/snmpapi/snmpv3/v1/timeout.html
				timeout *= 2
			}
			withContextDeadline = false
		}
		err = nil
		if x.Context.Err() != nil {
			return nil, x.Context.Err()
		}
		reqDeadline := time.Now().Add(timeout)
		if contextDeadline, ok := x.Context.Deadline(); ok {
			if contextDeadline.Before(reqDeadline) {
				reqDeadline = contextDeadline
				withContextDeadline = true
			}
		}
		err = x.Conn.SetDeadline(reqDeadline)
		if err != nil {
			return nil, err
		}
		// Request ID is an atomic counter that wraps to 0 at max int32.
		reqID := (atomic.AddUint32(&(x.requestID), 1) & 0x7FFFFFFF)
		allReqIDs = append(allReqIDs, reqID)
		packetOut.RequestID = reqID
		if x.Version == Version3 {
			msgID := (atomic.AddUint32(&(x.msgID), 1) & 0x7FFFFFFF)
			// allMsgIDs = append(allMsgIDs, msgID) // unused
			packetOut.MsgID = msgID
			err = x.initPacket(packetOut)
			if err != nil {
				break
			}
		}
		if x.Version == Version3 {
			packetOut.SecurityParameters.Log()
		}
		var outBuf []byte
		outBuf, err = packetOut.marshalMsg()
		if err != nil {
			// Don't retry - not going to get any better!
			err = fmt.Errorf("marshal: %w", err)
			break
		}
		if x.PreSend != nil {
			x.PreSend(x)
		}
		x.Logger.Printf("SENDING PACKET: %s", packetOut.SafeString())
		// If using UDP and unconnected socket, send packet directly to stored address.
		if uconn, ok := x.Conn.(net.PacketConn); ok && x.uaddr != nil {
			_, err = uconn.WriteTo(outBuf, x.uaddr)
		} else {
			_, err = x.Conn.Write(outBuf)
		}
		if err != nil {
			continue
		}
		if x.OnSent != nil {
			x.OnSent(x)
		}
		// all sends wait for the return packet, except for SNMPv2Trap
		if !wait {
			return &SnmpPacket{}, nil
		}
	waitingResponse:
		for {
			x.Logger.Print("WAITING RESPONSE...")
			// Receive response and try receiving again on any decoding error.
			// Let the deadline abort us if we don't receive a valid response.
			var resp []byte
			resp, err = x.receive()
			if err == io.EOF && strings.HasPrefix(x.Transport, tcp) {
				x.Logger.Printf("ERROR: EOF. Performing reconnect")
				err = x.netConnect()
				if err != nil {
					return nil, err
				}
				// err is nil from here; the retry prologue must tolerate that.
				continue sendRetry
			} else if err != nil {
				// receive error. retrying won't help. abort
				break
			}
			if x.OnRecv != nil {
				x.OnRecv(x)
			}
			x.Logger.Printf("GET RESPONSE OK: %+v", resp)
			result = new(SnmpPacket)
			result.Logger = x.Logger
			result.MsgFlags = packetOut.MsgFlags
			if packetOut.SecurityParameters != nil {
				result.SecurityParameters = packetOut.SecurityParameters.Copy()
			}
			var cursor int
			cursor, err = x.unmarshalHeader(resp, result)
			if err != nil {
				x.Logger.Printf("ERROR on unmarshall header: %s", err)
				break
			}
			if x.Version == Version3 {
				useResponseSecurityParameters := false
				if usp, ok := x.SecurityParameters.(*UsmSecurityParameters); ok {
					if usp.AuthoritativeEngineID == "" {
						useResponseSecurityParameters = true
					}
				}
				err = x.testAuthentication(resp, result, useResponseSecurityParameters)
				if err != nil {
					x.Logger.Printf("ERROR on Test Authentication on v3: %s", err)
					break
				}
				resp, cursor, err = x.decryptPacket(resp, cursor, result)
				if err != nil {
					x.Logger.Printf("ERROR on decryptPacket on v3: %s", err)
					break
				}
			}
			err = x.unmarshalPayload(resp, cursor, result)
			if err != nil {
				x.Logger.Printf("ERROR on UnmarshalPayload on v3: %s", err)
				break
			}
			if result.Error == NoError && len(result.Variables) < 1 {
				// NOTE(review): err is nil here, so this "error" path ends up
				// returning result as a success below.
				x.Logger.Printf("ERROR on UnmarshalPayload on v3: Empty result")
				break
			}
			// While Report PDU was defined by RFC 1905 as part of SNMPv2, it was never
			// used until SNMPv3. Report PDU's allow a SNMP engine to tell another SNMP
			// engine that an error was detected while processing an SNMP message.
			//
			// The format for a Report PDU is
			// -----------------------------------
			// | 0xA8 | reqid | 0 | 0 | varbinds |
			// -----------------------------------
			// where:
			// - PDU type 0xA8 indicates a Report PDU.
			// - reqid is either:
			// The request identifier of the message that triggered the report
			// or zero if the request identifier cannot be extracted.
			// - The variable bindings will contain a single object identifier and its value
			//
			// usmStatsNotInTimeWindows and usmStatsUnknownEngineIDs are recoverable errors
			// and will be retransmitted, for others we return the result with an error.
			if result.Version == Version3 && result.PDUType == Report && len(result.Variables) == 1 {
				switch result.Variables[0].Name {
				case usmStatsUnsupportedSecLevels:
					return result, ErrUnknownSecurityLevel
				case usmStatsNotInTimeWindows:
					break waitingResponse
				case usmStatsUnknownUserNames:
					return result, ErrUnknownUsername
				case usmStatsUnknownEngineIDs:
					break waitingResponse
				case usmStatsWrongDigests:
					return result, ErrWrongDigest
				case usmStatsDecryptionErrors:
					return result, ErrDecryption
				case snmpUnknownSecurityModels:
					return result, ErrUnknownSecurityModels
				case snmpInvalidMsgs:
					return result, ErrInvalidMsgs
				case snmpUnknownPDUHandlers:
					return result, ErrUnknownPDUHandlers
				default:
					return result, ErrUnknownReportPDU
				}
			}
			validID := false
			for _, id := range allReqIDs {
				if id == result.RequestID {
					validID = true
				}
			}
			if result.RequestID == 0 {
				validID = true
			}
			if !validID {
				x.Logger.Print("ERROR out of order")
				continue
			}
			break
		}
		if err != nil {
			continue
		}
		if x.OnFinish != nil {
			x.OnFinish(x)
		}
		// Success!
		return result, nil
	}
	// Return last error
	return nil, err
}
// send is the generic sender for every SNMP version. For SNMPv3 it first
// negotiates the initial security parameters. It then performs the request
// via sendOneRequest and, for SNMPv3, stores the agent's security parameters
// from the result. After a usmStatsNotInTimeWindows or
// usmStatsUnknownEngineIDs Report, it updates packetOut and retransmits once.
//
// all sends wait for the return packet, except for SNMPv2Trap
//
// A panic anywhere below is recovered and returned as an error that includes
// a stack dump.
func (x *GoSNMP) send(packetOut *SnmpPacket, wait bool) (result *SnmpPacket, err error) {
	defer func() {
		if e := recover(); e != nil {
			var buf = make([]byte, 8192)
			// runtime.Stack reports how many bytes it wrote; only that prefix
			// is meaningful (the remainder of buf is zero padding).
			n := runtime.Stack(buf, true)
			err = fmt.Errorf("recover: %v Stack:%v", e, string(buf[:n]))
		}
	}()
	if x.Conn == nil {
		return nil, fmt.Errorf("&GoSNMP.Conn is missing. Provide a connection or use Connect()")
	}
	if x.Retries < 0 {
		x.Retries = 0
	}
	x.Logger.Print("SEND INIT")
	if packetOut.Version == Version3 {
		x.Logger.Print("SEND INIT NEGOTIATE SECURITY PARAMS")
		if err = x.negotiateInitialSecurityParameters(packetOut); err != nil {
			return &SnmpPacket{}, err
		}
		x.Logger.Print("SEND END NEGOTIATE SECURITY PARAMS")
	}
	// perform request
	result, err = x.sendOneRequest(packetOut, wait)
	if err != nil {
		x.Logger.Printf("SEND Error on the first Request Error: %s", err)
		return result, err
	}
	if result.Version == Version3 {
		x.Logger.Printf("SEND STORE SECURITY PARAMS from result: %s", result.SecurityParameters.SafeString())
		err = x.storeSecurityParameters(result)
		if result.PDUType == Report && len(result.Variables) == 1 {
			switch result.Variables[0].Name {
			case usmStatsNotInTimeWindows:
				x.Logger.Print("WARNING detected out-of-time-window ERROR")
				if err = x.updatePktSecurityParameters(packetOut); err != nil {
					x.Logger.Printf("ERROR updatePktSecurityParameters error: %s", err)
					return nil, err
				}
				// retransmit with updated auth engine params
				result, err = x.sendOneRequest(packetOut, wait)
				if err != nil {
					x.Logger.Printf("ERROR out-of-time-window retransmit error: %s", err)
					return result, ErrNotInTimeWindow
				}
			case usmStatsUnknownEngineIDs:
				x.Logger.Print("WARNING detected unknown engine id ERROR")
				if err = x.updatePktSecurityParameters(packetOut); err != nil {
					x.Logger.Printf("ERROR updatePktSecurityParameters error: %s", err)
					return nil, err
				}
				// retransmit with updated engine id
				result, err = x.sendOneRequest(packetOut, wait)
				if err != nil {
					x.Logger.Printf("ERROR unknown engine id retransmit error: %s", err)
					return result, ErrUnknownEngineID
				}
			}
		}
	}
	return result, err
}
// -- Marshalling Logic --------------------------------------------------------
// MarshalMsg marshals the packet into the bytes sent across the wire. It is
// the exported entry point to marshalMsg, including that method's final
// authenticate step. On error it returns nil bytes.
func (packet *SnmpPacket) MarshalMsg() ([]byte, error) {
	encoded, err := packet.marshalMsg()
	if err != nil {
		return nil, err
	}
	return encoded, nil
}
// marshalMsg encodes a complete SNMP message. It builds a SEQUENCE containing
// the version INTEGER followed by either the SNMPv3 body (marshalV3) or, for
// v1/v2c, the community OCTET STRING and the PDU. The encoded message is then
// passed through packet.authenticate and that result is returned. On any
// error it returns nil bytes.
func (packet *SnmpPacket) marshalMsg() ([]byte, error) {
	var err error
	buf := new(bytes.Buffer)
	// version: INTEGER tag (2), length 1, value
	buf.Write([]byte{2, 1, byte(packet.Version)})
	if packet.Version == Version3 {
		buf, err = packet.marshalV3(buf)
		if err != nil {
			return nil, err
		}
	} else {
		// community: OCTET STRING. The length must be BER-encoded; a single raw
		// length byte is only valid below 128 and silently wraps above 255.
		communityLength, err2 := marshalLength(len(packet.Community))
		if err2 != nil {
			return nil, err2
		}
		buf.WriteByte(byte(OctetString))
		buf.Write(communityLength)
		buf.WriteString(packet.Community)
		// pdu
		pdu, err2 := packet.marshalPDU()
		if err2 != nil {
			return nil, err2
		}
		buf.Write(pdu)
	}
	// build up resulting msg - sequence, length then the tail (buf)
	msg := new(bytes.Buffer)
	msg.WriteByte(byte(Sequence))
	bufLengthBytes, err2 := marshalLength(buf.Len())
	if err2 != nil {
		return nil, err2
	}
	msg.Write(bufLengthBytes)
	_, err = buf.WriteTo(msg)
	if err != nil {
		return nil, err
	}
	authenticatedMessage, err := packet.authenticate(msg.Bytes())
	if err != nil {
		return nil, err
	}
	return authenticatedMessage, nil
}
// marshalSNMPV1TrapHeader encodes the SNMPv1 Trap-PDU header fields, in wire
// order: enterprise (OBJECT IDENTIFIER), agent-addr (IpAddress), generic-trap
// and specific-trap (INTEGER), and time-stamp (TimeTicks).
func (packet *SnmpPacket) marshalSNMPV1TrapHeader() ([]byte, error) {
	buf := new(bytes.Buffer)
	// marshal OID
	oidBytes, err := marshalObjectIdentifier(packet.Enterprise)
	if err != nil {
		return nil, fmt.Errorf("unable to marshal OID: %w", err)
	}
	if err = marshalTLV(buf, byte(ObjectIdentifier), oidBytes); err != nil {
		return nil, err
	}
	// marshal AgentAddress (ip address)
	// NOTE(review): net.ParseIP returns nil for an invalid AgentAddress; the
	// result then depends on how ipv4toBytes handles nil - confirm.
	ip := net.ParseIP(packet.AgentAddress)
	ipAddressBytes := ipv4toBytes(ip)
	buf.Write([]byte{byte(IPAddress), byte(len(ipAddressBytes))})
	buf.Write(ipAddressBytes)
	// marshal GenericTrap. Could just cast GenericTrap to a single byte as IDs greater than 6 are unknown,
	// but do it properly. See issue 182.
	var genericTrapBytes []byte
	genericTrapBytes, err = marshalInt32(packet.GenericTrap)
	if err != nil {
		return nil, fmt.Errorf("unable to marshal SNMPv1 GenericTrap: %w", err)
	}
	buf.Write([]byte{byte(Integer), byte(len(genericTrapBytes))})
	buf.Write(genericTrapBytes)
	// marshal SpecificTrap
	var specificTrapBytes []byte
	specificTrapBytes, err = marshalInt32(packet.SpecificTrap)
	if err != nil {
		return nil, fmt.Errorf("unable to marshal SNMPv1 SpecificTrap: %w", err)
	}
	buf.Write([]byte{byte(Integer), byte(len(specificTrapBytes))})
	buf.Write(specificTrapBytes)
	// marshal timeTicks
	timeTickBytes, err := marshalUint32(packet.Timestamp)
	if err != nil {
		return nil, fmt.Errorf("unable to marshal SNMPv1 Timestamp: %w", err)
	}
	buf.Write([]byte{byte(TimeTicks), byte(len(timeTickBytes))})
	buf.Write(timeTickBytes)
	return buf.Bytes(), nil
}
// marshalPDU encodes the packet's PDU as the PDUType tag, a BER length, and a
// body whose layout depends on the PDU type:
//
//   - GetBulkRequest: request-id, non-repeaters, max-repetitions
//   - Trap (SNMPv1): the header produced by marshalSNMPV1TrapHeader
//   - everything else: request-id, error-status, error-index
//
// In every case the varbind list from marshalVBL follows. The statements
// below emit bytes in wire order, so they must not be reordered.
func (packet *SnmpPacket) marshalPDU() ([]byte, error) {
	buf := new(bytes.Buffer)
	switch packet.PDUType {
	case GetBulkRequest:
		// requestid
		err := shrinkAndWriteUint(buf, int(packet.RequestID))
		if err != nil {
			return nil, err
		}
		// non repeaters: INTEGER tag (2), single-byte length, then the value
		// bytes (binary.Write of a []byte writes them unchanged).
		nonRepeaters, err := marshalUint32(packet.NonRepeaters)
		if err != nil {
			return nil, fmt.Errorf("marshalPDU: unable to marshal NonRepeaters to uint32: %w", err)
		}
		buf.Write([]byte{2, byte(len(nonRepeaters))})
		if err = binary.Write(buf, binary.BigEndian, nonRepeaters); err != nil {
			return nil, fmt.Errorf("marshalPDU: unable to marshal NonRepeaters: %w", err)
		}
		// max repetitions: same INTEGER layout as above.
		maxRepetitions, err := marshalUint32(packet.MaxRepetitions)
		if err != nil {
			return nil, fmt.Errorf("marshalPDU: unable to marshal maxRepetitions to uint32: %w", err)
		}
		buf.Write([]byte{2, byte(len(maxRepetitions))})
		if err = binary.Write(buf, binary.BigEndian, maxRepetitions); err != nil {
			return nil, fmt.Errorf("marshalPDU: unable to marshal maxRepetitions: %w", err)
		}
	case Trap:
		// write SNMP V1 Trap Header fields (a v1 Trap-PDU has no request-id)
		snmpV1TrapHeader, err := packet.marshalSNMPV1TrapHeader()
		if err != nil {
			return nil, err
		}
		buf.Write(snmpV1TrapHeader)
	default:
		// requestid
		err := shrinkAndWriteUint(buf, int(packet.RequestID))
		if err != nil {
			return nil, err
		}
		// error status: INTEGER tag (2), single-byte length, value bytes.
		errorStatus, err := marshalUint32(packet.Error)
		if err != nil {
			return nil, fmt.Errorf("marshalPDU: unable to marshal errorStatus to uint32: %w", err)
		}
		buf.Write([]byte{2, byte(len(errorStatus))})
		if err = binary.Write(buf, binary.BigEndian, errorStatus); err != nil {
			return nil, fmt.Errorf("marshalPDU: unable to marshal errorStatus: %w", err)
		}
		// error index: same INTEGER layout.
		errorIndex, err := marshalUint32(packet.ErrorIndex)
		if err != nil {
			return nil, fmt.Errorf("marshalPDU: unable to marshal errorIndex to uint32: %w", err)
		}
		buf.Write([]byte{2, byte(len(errorIndex))})
		if err = binary.Write(buf, binary.BigEndian, errorIndex); err != nil {
			return nil, fmt.Errorf("marshalPDU: unable to marshal errorIndex: %w", err)
		}
	}
	// build varbind list
	vbl, err := packet.marshalVBL()
	if err != nil {
		return nil, fmt.Errorf("marshalPDU: unable to marshal varbind list: %w", err)
	}
	buf.Write(vbl)
	// build up resulting pdu
	pdu := new(bytes.Buffer)
	// calculate pdu length
	bufLengthBytes, err := marshalLength(buf.Len())
	if err != nil {
		return nil, fmt.Errorf("marshalPDU: unable to marshal pdu length: %w", err)
	}
	// write request type
	pdu.WriteByte(byte(packet.PDUType))
	// write pdu length
	pdu.Write(bufLengthBytes)
	// write the tail (buf)
	if _, err = buf.WriteTo(pdu); err != nil {
		return nil, fmt.Errorf("marshalPDU: unable to marshal pdu: %w", err)
	}
	return pdu.Bytes(), nil
}
// marshalVBL encodes packet.Variables as a BER varbind list: one Sequence
// TLV whose contents are the marshalVarbind encodings of each variable, in
// order. An empty Variables slice yields an empty Sequence.
func (packet *SnmpPacket) marshalVBL() ([]byte, error) {
	var body bytes.Buffer
	for i := range packet.Variables {
		encoded, err := marshalVarbind(&packet.Variables[i])
		if err != nil {
			return nil, err
		}
		body.Write(encoded)
	}

	lengthBytes, err := marshalLength(body.Len())
	if err != nil {
		return nil, err
	}

	// Sequence tag, then length, then the concatenated varbinds.
	out := make([]byte, 0, 1+len(lengthBytes)+body.Len())
	out = append(out, byte(Sequence))
	out = append(out, lengthBytes...)
	return append(out, body.Bytes()...), nil
}
// marshalVarbind encodes an SNMP variable binding (varbind) as BER.
// Returns a Sequence TLV containing the OID and its associated value:
//
//	Sequence {
//		ObjectIdentifier (pdu.Name)
//		<Value TLV>      (pdu.Type + pdu.Value)
//	}
//
// The Go type of pdu.Value must match pdu.Type:
//   - Integer: int or byte
//   - Counter32, Gauge32, TimeTicks, Uinteger32: uint32 or uint
//   - OctetString, BitString, Opaque: []byte or string
//   - ObjectIdentifier: string OID
//   - IPAddress: []byte or string
//   - OpaqueFloat, OpaqueDouble, Counter64: whatever marshalFloat32,
//     marshalFloat64 and marshalUint64 accept
//   - Null, NoSuchInstance, NoSuchObject, EndOfMibView: value is ignored
//
// A mismatched value type or an unsupported pdu.Type is reported as an
// error instead of panicking.
func marshalVarbind(pdu *SnmpPDU) ([]byte, error) {
	oid, err := marshalObjectIdentifier(pdu.Name)
	if err != nil {
		return nil, err
	}

	// tmpBuf collects the varbind contents (name TLV followed by the value
	// encoding); it is wrapped in a Sequence TLV at the end.
	tmpBuf := new(bytes.Buffer)
	if err = marshalTLV(tmpBuf, byte(ObjectIdentifier), oid); err != nil {
		return nil, err
	}

	// Marshal the PDU type into the appropriate BER
	switch pdu.Type {
	case Null, NoSuchInstance, NoSuchObject, EndOfMibView:
		// Valueless types: the tag followed by a zero length octet.
		tmpBuf.WriteByte(byte(pdu.Type))
		tmpBuf.WriteByte(byte(EndOfContents))

	case Integer:
		var intBytes []byte
		switch value := pdu.Value.(type) {
		case byte:
			// Encode through marshalInt32 like the int case rather than as a
			// raw octet: a lone octet >= 0x80 is a negative BER INTEGER.
			if intBytes, err = marshalInt32(int(value)); err != nil {
				return nil, fmt.Errorf("error mashalling PDU Integer: %w", err)
			}
		case int:
			if intBytes, err = marshalInt32(value); err != nil {
				return nil, fmt.Errorf("error mashalling PDU Integer: %w", err)
			}
		default:
			return nil, fmt.Errorf("unable to marshal PDU Integer; not byte or int")
		}
		if err = marshalTLV(tmpBuf, byte(pdu.Type), intBytes); err != nil {
			return nil, err
		}

	case Counter32, Gauge32, TimeTicks, Uinteger32:
		var intBytes []byte
		switch value := pdu.Value.(type) {
		case uint32:
			if intBytes, err = marshalUint32(value); err != nil {
				return nil, fmt.Errorf("error marshalling PDU Uinteger32 type from uint32: %w", err)
			}
		case uint:
			if intBytes, err = marshalUint32(value); err != nil {
				return nil, fmt.Errorf("error marshalling PDU Uinteger32 type from uint: %w", err)
			}
		default:
			return nil, fmt.Errorf("unable to marshal pdu.Type %v; unknown pdu.Value %v[type=%T]", pdu.Type, pdu.Value, pdu.Value)
		}
		if err = marshalTLV(tmpBuf, byte(pdu.Type), intBytes); err != nil {
			return nil, err
		}

	case OctetString, BitString, Opaque:
		var octetStringBytes []byte
		switch value := pdu.Value.(type) {
		case []byte:
			octetStringBytes = value
		case string:
			octetStringBytes = []byte(value)
		default:
			return nil, fmt.Errorf("unable to marshal PDU OctetString; not []byte or string")
		}
		if err = marshalTLV(tmpBuf, byte(pdu.Type), octetStringBytes); err != nil {
			return nil, err
		}

	case ObjectIdentifier:
		// Checked assertion: a non-string value must be an error, not a panic.
		value, ok := pdu.Value.(string)
		if !ok {
			return nil, fmt.Errorf("unable to marshal PDU ObjectIdentifier; not string (got %T)", pdu.Value)
		}
		oidBytes, encErr := marshalObjectIdentifier(value)
		if encErr != nil {
			return nil, fmt.Errorf("error marshalling ObjectIdentifier: %w", encErr)
		}
		if err = marshalTLV(tmpBuf, byte(pdu.Type), oidBytes); err != nil {
			return nil, err
		}

	case IPAddress:
		var ipAddressBytes []byte
		switch value := pdu.Value.(type) {
		case []byte:
			ipAddressBytes = value
		case string:
			ipAddressBytes = ipv4toBytes(net.ParseIP(value))
		default:
			return nil, fmt.Errorf("unable to marshal PDU IPAddress; not []byte or string")
		}
		if err = marshalTLV(tmpBuf, byte(pdu.Type), ipAddressBytes); err != nil {
			return nil, err
		}

	case OpaqueFloat, OpaqueDouble:
		// Floating point values are nested inside an Opaque value as
		// AsnExtensionTag, type, length, value.
		var floatBytes []byte
		var encErr error
		if pdu.Type == OpaqueFloat {
			floatBytes, encErr = marshalFloat32(pdu.Value)
		} else {
			floatBytes, encErr = marshalFloat64(pdu.Value)
		}
		if encErr != nil {
			return nil, fmt.Errorf("error converting PDU value type %v to %v: %w", pdu.Value, pdu.Type, encErr)
		}
		floatLength, encErr := marshalLength(len(floatBytes))
		if encErr != nil {
			return nil, fmt.Errorf("error marshalling Float type length: %w", encErr)
		}
		inner := new(bytes.Buffer)
		inner.WriteByte(byte(AsnExtensionTag))
		inner.WriteByte(byte(pdu.Type))
		inner.Write(floatLength)
		inner.Write(floatBytes)

		opaqueLength, encErr := marshalLength(inner.Len())
		if encErr != nil {
			return nil, fmt.Errorf("error marshalling Opaque length: %w", encErr)
		}
		tmpBuf.WriteByte(byte(Opaque))
		tmpBuf.Write(opaqueLength)
		tmpBuf.Write(inner.Bytes())

	case Counter64:
		intBytes, encErr := marshalUint64(pdu.Value)
		if encErr != nil {
			return nil, fmt.Errorf("error marshalling Counter64: %w", encErr)
		}
		if err = marshalTLV(tmpBuf, byte(pdu.Type), intBytes); err != nil {
			return nil, err
		}

	default:
		return nil, fmt.Errorf("unable to marshal PDU: unknown BER type %q", pdu.Type)
	}

	pduBuf := new(bytes.Buffer)
	if err = marshalTLV(pduBuf, byte(Sequence), tmpBuf.Bytes()); err != nil {
		return nil, err
	}
	return pduBuf.Bytes(), nil
}
// -- Unmarshalling Logic ------------------------------------------------------
// unmarshalVersionFromHeader validates the outer SEQUENCE of an SNMP message
// and decodes the version field that follows it.
//
// It resets response.Variables, checks that the SEQUENCE length reported by
// parseLength equals len(packet), and parses the version INTEGER. On success
// it returns the version and the cursor positioned just past the version
// field. A version field that does not decode to an integer is reported as
// an error instead of silently yielding version 0 (SNMPv1).
func (x *GoSNMP) unmarshalVersionFromHeader(packet []byte, response *SnmpPacket) (SnmpVersion, int, error) {
	if len(packet) < 2 {
		return 0, 0, fmt.Errorf("cannot unmarshal empty packet")
	}
	if response == nil {
		return 0, 0, fmt.Errorf("cannot unmarshal response into nil packet reference")
	}
	response.Variables = make([]SnmpPDU, 0, 5)

	// First byte should be 0x30 (SEQUENCE).
	if PDUType(packet[0]) != Sequence {
		return 0, 0, fmt.Errorf("invalid packet header")
	}
	length, cursor, err := parseLength(packet)
	if err != nil {
		return 0, 0, err
	}
	if len(packet) != length {
		return 0, 0, fmt.Errorf("error verifying packet sanity: Got %d Expected: %d", len(packet), length)
	}
	x.Logger.Printf("Packet sanity verified, we got all the bytes (%d)", length)

	// Parse SNMP Version
	rawVersion, count, err := parseRawField(x.Logger, packet[cursor:], "version")
	if err != nil {
		return 0, 0, fmt.Errorf("error parsing SNMP packet version: %w", err)
	}
	cursor += count
	if cursor >= len(packet) {
		return 0, 0, fmt.Errorf("error parsing SNMP packet, packet length %d cursor %d", len(packet), cursor)
	}

	version, ok := rawVersion.(int)
	if !ok {
		return 0, 0, fmt.Errorf("error parsing SNMP packet version: expected integer, got %T", rawVersion)
	}
	x.Logger.Printf("Parsed version %d", version)
	return SnmpVersion(version), cursor, nil //nolint:gosec
}
// unmarshalHeader parses the SNMP message header into response: the outer
// SEQUENCE and version (via unmarshalVersionFromHeader), followed by either
// the SNMPv3 header or, for other versions, the community string. It returns
// the cursor positioned at the first byte after the header.
func (x *GoSNMP) unmarshalHeader(packet []byte, response *SnmpPacket) (int, error) {
	version, cursor, err := x.unmarshalVersionFromHeader(packet, response)
	if err != nil {
		return 0, err
	}
	response.Version = version

	if response.Version == Version3 {
		headerStart := cursor
		if cursor, err = x.unmarshalV3Header(packet, cursor, response); err != nil {
			return 0, err
		}
		x.Logger.Printf("UnmarshalV3Header done. [with SecurityParameters]. Header Size %d. Last 4 Bytes=[%v]", cursor-headerStart, packet[cursor-4:cursor])
		return cursor, nil
	}

	// v1 / v2c: the version is followed by the community string.
	rawCommunity, count, err := parseRawField(x.Logger, packet[cursor:], "community")
	if err != nil {
		return 0, fmt.Errorf("error parsing community string: %w", err)
	}
	cursor += count
	if cursor > len(packet) {
		return 0, fmt.Errorf("error parsing SNMP packet, packet length %d cursor %d", len(packet), cursor)
	}
	if community, isString := rawCommunity.(string); isString {
		response.Community = community
		x.Logger.Printf("Parsed community %s", community)
	}
	return cursor, nil
}
// unmarshalPayload decodes the PDU that starts at packet[cursor] into
// response. The PDU tag selects the decoder: Trap (SNMPv1) goes to
// unmarshalTrapV1 and the other supported PDU types go to unmarshalResponse.
// For the latter, response.IsInform is set to whether the PDU was an
// InformRequest once decoding succeeds.
func (x *GoSNMP) unmarshalPayload(packet []byte, cursor int, response *SnmpPacket) error {
	if len(packet) == 0 {
		return errors.New("cannot unmarshal nil or empty payload packet")
	}
	if cursor >= len(packet) {
		return fmt.Errorf("cannot unmarshal payload, packet length %d cursor %d", len(packet), cursor)
	}
	if response == nil {
		return errors.New("cannot unmarshal payload response into nil packet reference")
	}

	// Parse SNMP packet type. The raw tag byte is what gets logged/reported:
	// PDUType implements fmt.Stringer, so formatting it with %#x would
	// hex-encode its *name* rather than print the tag value.
	tag := packet[cursor]
	requestType := PDUType(tag)
	x.Logger.Printf("UnmarshalPayload Meet PDUType %#x. Offset %v", tag, cursor)

	switch requestType {
	// known, supported types
	case GetResponse, GetNextRequest, GetBulkRequest, Report, SNMPv2Trap, GetRequest, SetRequest, InformRequest:
		response.PDUType = requestType
		if err := x.unmarshalResponse(packet[cursor:], response); err != nil {
			return fmt.Errorf("error in unmarshalResponse: %w", err)
		}
		// If it's an InformRequest, mark the trap.
		response.IsInform = (requestType == InformRequest)
	case Trap:
		response.PDUType = requestType
		if err := x.unmarshalTrapV1(packet[cursor:], response); err != nil {
			return fmt.Errorf("error in unmarshalTrapV1: %w", err)
		}
	default:
		x.Logger.Printf("UnmarshalPayload Meet Unknown PDUType %#x. Offset %v", tag, cursor)
		return fmt.Errorf("unknown PDUType %#x", tag)
	}
	return nil
}
// unmarshalResponse decodes the body of a non-Trap PDU (packet starts at the
// PDU tag). It reads request-id, then non-repeaters/max-repetitions for a
// GetBulkRequest or error-status/error-index for every other PDU type, and
// finally the varbind list.
func (x *GoSNMP) unmarshalResponse(packet []byte, response *SnmpPacket) error {
	pduLength, cursor, err := parseLength(packet)
	if err != nil {
		return err
	}
	if len(packet) != pduLength {
		return fmt.Errorf("error verifying Response sanity: Got %d Expected: %d", len(packet), pduLength)
	}
	x.Logger.Printf("getResponseLength: %d", pduLength)

	// readField parses the field at cursor (label is passed to parseRawField
	// for logging), advances cursor past it and rejects a cursor that ran
	// beyond the packet. desc names the field in the parse error message.
	readField := func(label, desc string) (any, error) {
		raw, n, parseErr := parseRawField(x.Logger, packet[cursor:], label)
		if parseErr != nil {
			return nil, fmt.Errorf("error parsing SNMP packet %s: %w", desc, parseErr)
		}
		cursor += n
		if cursor > len(packet) {
			return nil, fmt.Errorf("error parsing SNMP packet, packet length %d cursor %d", len(packet), cursor)
		}
		return raw, nil
	}

	rawRequestID, err := readField("request id", "request ID")
	if err != nil {
		return err
	}
	if requestID, isInt := rawRequestID.(int); isInt {
		response.RequestID = uint32(requestID) //nolint:gosec
		x.Logger.Printf("requestID: %d", response.RequestID)
	}

	if response.PDUType != GetBulkRequest {
		rawError, err := readField("error-status", "error")
		if err != nil {
			return err
		}
		if errorStatus, isInt := rawError.(int); isInt {
			response.Error = SNMPError(errorStatus)                //nolint:gosec
			x.Logger.Printf("errorStatus: %d", uint8(errorStatus)) //nolint:gosec
		}

		rawErrorIndex, err := readField("error index", "error index")
		if err != nil {
			return err
		}
		if errorIndex, isInt := rawErrorIndex.(int); isInt {
			response.ErrorIndex = uint8(errorIndex)                //nolint:gosec
			x.Logger.Printf("error-index: %d", uint8(errorIndex)) //nolint:gosec
		}
	} else {
		rawNonRepeaters, err := readField("non repeaters", "non repeaters")
		if err != nil {
			return err
		}
		if nonRepeaters, isInt := rawNonRepeaters.(int); isInt {
			response.NonRepeaters = uint8(nonRepeaters) //nolint:gosec
		}

		rawMaxRepetitions, err := readField("max repetitions", "max repetitions")
		if err != nil {
			return err
		}
		if maxRepetitions, isInt := rawMaxRepetitions.(int); isInt {
			response.MaxRepetitions = uint32(maxRepetitions & 0x7FFFFFFF) //nolint:gosec
		}
	}

	return x.unmarshalVBL(packet[cursor:], response)
}
// unmarshalTrapV1 decodes the body of an SNMPv1 Trap PDU (packet starts at
// the PDU tag) into response. It reads enterprise, agent-address,
// generic-trap, specific-trap and time-stamp in that order, then the varbind
// list. Each parse failure names the field that could not be decoded.
func (x *GoSNMP) unmarshalTrapV1(packet []byte, response *SnmpPacket) error {
	getResponseLength, cursor, err := parseLength(packet)
	if err != nil {
		return err
	}
	if len(packet) != getResponseLength {
		return fmt.Errorf("error verifying Response sanity: Got %d Expected: %d", len(packet), getResponseLength)
	}
	x.Logger.Printf("getResponseLength: %d", getResponseLength)

	// nextField parses the field at cursor, advances cursor past it and
	// rejects a cursor beyond the end of the packet. name is used both for
	// parseRawField's logging and in the error message.
	nextField := func(name string) (any, error) {
		raw, count, parseErr := parseRawField(x.Logger, packet[cursor:], name)
		if parseErr != nil {
			return nil, fmt.Errorf("error parsing SNMP packet %s: %w", name, parseErr)
		}
		cursor += count
		if cursor > len(packet) {
			return nil, fmt.Errorf("error parsing SNMP packet, packet length %d cursor %d", len(packet), cursor)
		}
		return raw, nil
	}

	rawEnterprise, err := nextField("enterprise")
	if err != nil {
		return err
	}
	if enterprise, ok := rawEnterprise.(string); ok {
		response.Enterprise = enterprise
		x.Logger.Printf("Enterprise: %+v", enterprise)
	}

	rawAgentAddress, err := nextField("agent-address")
	if err != nil {
		return err
	}
	if agentAddress, ok := rawAgentAddress.(string); ok {
		response.AgentAddress = agentAddress
		x.Logger.Printf("AgentAddress: %s", agentAddress)
	}

	rawGenericTrap, err := nextField("generic-trap")
	if err != nil {
		return err
	}
	if genericTrap, ok := rawGenericTrap.(int); ok {
		response.GenericTrap = genericTrap
		x.Logger.Printf("GenericTrap: %d", genericTrap)
	}

	rawSpecificTrap, err := nextField("specific-trap")
	if err != nil {
		return err
	}
	if specificTrap, ok := rawSpecificTrap.(int); ok {
		response.SpecificTrap = specificTrap
		x.Logger.Printf("SpecificTrap: %d", specificTrap)
	}

	rawTimestamp, err := nextField("time-stamp")
	if err != nil {
		return err
	}
	if timestamp, ok := rawTimestamp.(uint); ok {
		response.Timestamp = timestamp
		x.Logger.Printf("Timestamp: %d", timestamp)
	}

	return x.unmarshalVBL(packet[cursor:], response)
}
// unmarshalVBL decodes a BER varbind list (a SEQUENCE of varbind SEQUENCEs)
// and appends each decoded variable to response.Variables.
//
// packet must contain exactly the varbind list: the outer length reported by
// parseLength has to equal len(packet). A list whose encoding is 0x30 0x00
// is accepted as empty.
func (x *GoSNMP) unmarshalVBL(packet []byte, response *SnmpPacket) error {
	if len(packet) == 0 {
		return fmt.Errorf("truncated packet when unmarshalling a VBL, got length %d cursor %d", len(packet), 0)
	}
	if packet[0] != 0x30 {
		return fmt.Errorf("expected a sequence when unmarshalling a VBL, got %x", packet[0])
	}

	// total is the full TLV length, pos the size of its header.
	total, pos, err := parseLength(packet)
	if err != nil {
		return err
	}
	switch {
	case total == 0 || total > len(packet):
		return fmt.Errorf("truncated packet when unmarshalling a VBL, packet length %d cursor %d", len(packet), pos)
	case total != len(packet):
		return fmt.Errorf("error verifying: packet length %d vbl length %d", len(packet), total)
	}
	x.Logger.Printf("vblLength: %d", total)

	// check for an empty response
	if total == 2 && packet[1] == 0x00 {
		return nil
	}

	for pos < total {
		// Each varbind is its own SEQUENCE; skip over its header.
		if packet[pos] != 0x30 {
			return fmt.Errorf("expected a sequence when unmarshalling a VB, got %x", packet[pos])
		}
		_, headerLen, err := parseLength(packet[pos:])
		if err != nil {
			return err
		}
		pos += headerLen
		if pos > len(packet) {
			return fmt.Errorf("error parsing OID Value: packet %d cursor %d", len(packet), pos)
		}

		// Name (OID)
		rawOid, oidLen, err := parseRawField(x.Logger, packet[pos:], "OID")
		if err != nil {
			return fmt.Errorf("error parsing OID Value: %w", err)
		}
		pos += oidLen
		if pos > len(packet) {
			return fmt.Errorf("error parsing OID Value: truncated, packet length %d cursor %d", len(packet), pos)
		}
		name, isString := rawOid.(string)
		if !isString {
			return fmt.Errorf("unable to type assert rawOid |%v| to string", rawOid)
		}
		x.Logger.Printf("OID: %s", name)

		// Value
		var decoded variable
		if err = x.decodeValue(packet[pos:], &decoded); err != nil {
			return fmt.Errorf("error decoding value: %w", err)
		}
		valueLen, _, err := parseLength(packet[pos:])
		if err != nil {
			return err
		}
		pos += valueLen
		if pos > len(packet) {
			return fmt.Errorf("error decoding OID Value: truncated, packet length %d cursor %d", len(packet), pos)
		}

		response.Variables = append(response.Variables, SnmpPDU{Name: name, Type: decoded.Type, Value: decoded.Value})
	}
	return nil
}
// receive reads a single message from x.Conn into x.rxBuf and returns a
// freshly allocated copy of the bytes read.
//
// io.EOF is returned unwrapped; other read errors are wrapped. A read that
// fills the whole receive buffer is rejected because the message may have
// been truncated.
func (x *GoSNMP) receive() ([]byte, error) {
	var (
		n   int
		err error
	)
	// Unconnected UDP sockets implement net.PacketConn; the sender's
	// address is not needed and is discarded.
	if pc, isPacketConn := x.Conn.(net.PacketConn); isPacketConn {
		n, _, err = pc.ReadFrom(x.rxBuf[:])
	} else {
		n, err = x.Conn.Read(x.rxBuf[:])
	}

	switch {
	case err == io.EOF:
		return nil, err
	case err != nil:
		return nil, fmt.Errorf("error reading from socket: %w", err)
	case n == rxBufSize:
		// This should never happen unless we're using something like a unix domain socket.
		return nil, fmt.Errorf("response buffer too small")
	}

	out := make([]byte, n)
	copy(out, x.rxBuf[:n])
	return out, nil
}
// shrinkAndWriteUint writes in to buf as a complete DER INTEGER TLV using
// encoding/asn1, so the value occupies the minimal number of octets.
// It returns the marshalling error or the write error, whichever occurs.
func shrinkAndWriteUint(buf io.Writer, in int) error {
	encoded, err := asn1.Marshal(in)
	if err == nil {
		_, err = buf.Write(encoded)
	}
	return err
}
// Code generated by "stringer -type=PDUType"; DO NOT EDIT.
package gosnmp
import "strconv"
// _ is a compile-time guard emitted by stringer: it fails to compile if any
// PDUType constant no longer has the value the tables below were built for.
func _() {
	// An "invalid array index" compiler error signifies that the constant values have changed.
	// Re-run the stringer command to generate them again.
	var x [1]struct{}
	_ = x[Sequence-48]
	_ = x[GetRequest-160]
	_ = x[GetNextRequest-161]
	_ = x[GetResponse-162]
	_ = x[SetRequest-163]
	_ = x[Trap-164]
	_ = x[GetBulkRequest-165]
	_ = x[InformRequest-166]
	_ = x[SNMPv2Trap-167]
	_ = x[Report-168]
}

const (
	_PDUType_name_0 = "Sequence"
	_PDUType_name_1 = "GetRequestGetNextRequestGetResponseSetRequestTrapGetBulkRequestInformRequestSNMPv2TrapReport"
)

var (
	_PDUType_index_1 = [...]uint8{0, 10, 24, 35, 45, 49, 63, 76, 86, 92}
)

// String returns the constant name for i ("Sequence" for 48, the matching
// name for 160-168) and "PDUType(<n>)" for any other value.
func (i PDUType) String() string {
	switch {
	case i == 48:
		return _PDUType_name_0
	case 160 <= i && i <= 168:
		i -= 160
		return _PDUType_name_1[_PDUType_index_1[i]:_PDUType_index_1[i+1]]
	default:
		return "PDUType(" + strconv.FormatInt(int64(i), 10) + ")"
	}
}
// Code generated by "stringer -type SNMPError"; DO NOT EDIT.
package gosnmp
import "strconv"
// _ is a compile-time guard emitted by stringer: it fails to compile if any
// SNMPError constant no longer has the value the tables below were built for.
func _() {
	// An "invalid array index" compiler error signifies that the constant values have changed.
	// Re-run the stringer command to generate them again.
	var x [1]struct{}
	_ = x[NoError-0]
	_ = x[TooBig-1]
	_ = x[NoSuchName-2]
	_ = x[BadValue-3]
	_ = x[ReadOnly-4]
	_ = x[GenErr-5]
	_ = x[NoAccess-6]
	_ = x[WrongType-7]
	_ = x[WrongLength-8]
	_ = x[WrongEncoding-9]
	_ = x[WrongValue-10]
	_ = x[NoCreation-11]
	_ = x[InconsistentValue-12]
	_ = x[ResourceUnavailable-13]
	_ = x[CommitFailed-14]
	_ = x[UndoFailed-15]
	_ = x[AuthorizationError-16]
	_ = x[NotWritable-17]
	_ = x[InconsistentName-18]
}

const _SNMPError_name = "NoErrorTooBigNoSuchNameBadValueReadOnlyGenErrNoAccessWrongTypeWrongLengthWrongEncodingWrongValueNoCreationInconsistentValueResourceUnavailableCommitFailedUndoFailedAuthorizationErrorNotWritableInconsistentName"

var _SNMPError_index = [...]uint8{0, 7, 13, 23, 31, 39, 45, 53, 62, 73, 86, 96, 106, 123, 142, 154, 164, 182, 193, 209}

// String returns the constant name for i (values 0-18) and
// "SNMPError(<n>)" for any value outside that range.
func (i SNMPError) String() string {
	idx := int(i) - 0
	if i < 0 || idx >= len(_SNMPError_index)-1 {
		return "SNMPError(" + strconv.FormatInt(int64(i), 10) + ")"
	}
	return _SNMPError_name[_SNMPError_index[idx]:_SNMPError_index[idx+1]]
}
// Code generated by "stringer -type=SnmpV3AuthProtocol"; DO NOT EDIT.
package gosnmp
import "strconv"
// _ is a compile-time guard emitted by stringer: it fails to compile if any
// SnmpV3AuthProtocol constant no longer has the value the tables below were
// built for.
func _() {
	// An "invalid array index" compiler error signifies that the constant values have changed.
	// Re-run the stringer command to generate them again.
	var x [1]struct{}
	_ = x[NoAuth-1]
	_ = x[MD5-2]
	_ = x[SHA-3]
	_ = x[SHA224-4]
	_ = x[SHA256-5]
	_ = x[SHA384-6]
	_ = x[SHA512-7]
}

const _SnmpV3AuthProtocol_name = "NoAuthMD5SHASHA224SHA256SHA384SHA512"

var _SnmpV3AuthProtocol_index = [...]uint8{0, 6, 9, 12, 18, 24, 30, 36}

// String returns the constant name for i (values 1-7) and
// "SnmpV3AuthProtocol(<n>)" for any value outside that range.
func (i SnmpV3AuthProtocol) String() string {
	idx := int(i) - 1
	if i < 1 || idx >= len(_SnmpV3AuthProtocol_index)-1 {
		return "SnmpV3AuthProtocol(" + strconv.FormatInt(int64(i), 10) + ")"
	}
	return _SnmpV3AuthProtocol_name[_SnmpV3AuthProtocol_index[idx]:_SnmpV3AuthProtocol_index[idx+1]]
}
// Code generated by "stringer -type=SnmpV3MsgFlags"; DO NOT EDIT.
package gosnmp
import "strconv"
// _ is a compile-time guard emitted by stringer: it fails to compile if any
// SnmpV3MsgFlags constant no longer has the value the tables below were
// built for.
func _() {
	// An "invalid array index" compiler error signifies that the constant values have changed.
	// Re-run the stringer command to generate them again.
	var x [1]struct{}
	_ = x[NoAuthNoPriv-0]
	_ = x[AuthNoPriv-1]
	_ = x[AuthPriv-3]
	_ = x[Reportable-4]
}

const (
	_SnmpV3MsgFlags_name_0 = "NoAuthNoPrivAuthNoPriv"
	_SnmpV3MsgFlags_name_1 = "AuthPrivReportable"
)

var (
	_SnmpV3MsgFlags_index_0 = [...]uint8{0, 12, 22}
	_SnmpV3MsgFlags_index_1 = [...]uint8{0, 8, 18}
)

// String returns the constant name for the single values 0, 1, 3 and 4.
// Any other value, including combined flags such as 5 or 7, is rendered as
// "SnmpV3MsgFlags(<n>)".
func (i SnmpV3MsgFlags) String() string {
	switch {
	case i <= 1:
		return _SnmpV3MsgFlags_name_0[_SnmpV3MsgFlags_index_0[i]:_SnmpV3MsgFlags_index_0[i+1]]
	case 3 <= i && i <= 4:
		i -= 3
		return _SnmpV3MsgFlags_name_1[_SnmpV3MsgFlags_index_1[i]:_SnmpV3MsgFlags_index_1[i+1]]
	default:
		return "SnmpV3MsgFlags(" + strconv.FormatInt(int64(i), 10) + ")"
	}
}
// Code generated by "stringer -type=SnmpV3PrivProtocol"; DO NOT EDIT.
package gosnmp
import "strconv"
// _ is a compile-time guard emitted by stringer: it fails to compile if any
// SnmpV3PrivProtocol constant no longer has the value the tables below were
// built for.
func _() {
	// An "invalid array index" compiler error signifies that the constant values have changed.
	// Re-run the stringer command to generate them again.
	var x [1]struct{}
	_ = x[NoPriv-1]
	_ = x[DES-2]
	_ = x[AES-3]
	_ = x[AES192-4]
	_ = x[AES256-5]
	_ = x[AES192C-6]
	_ = x[AES256C-7]
}

const _SnmpV3PrivProtocol_name = "NoPrivDESAESAES192AES256AES192CAES256C"

var _SnmpV3PrivProtocol_index = [...]uint8{0, 6, 9, 12, 18, 24, 31, 38}

// String returns the constant name for i (values 1-7) and
// "SnmpV3PrivProtocol(<n>)" for any value outside that range.
func (i SnmpV3PrivProtocol) String() string {
	idx := int(i) - 1
	if i < 1 || idx >= len(_SnmpV3PrivProtocol_index)-1 {
		return "SnmpV3PrivProtocol(" + strconv.FormatInt(int64(i), 10) + ")"
	}
	return _SnmpV3PrivProtocol_name[_SnmpV3PrivProtocol_index[idx]:_SnmpV3PrivProtocol_index[idx+1]]
}
// Code generated by "stringer -type=SnmpV3SecurityModel"; DO NOT EDIT.
package gosnmp
import "strconv"
// _ is a compile-time guard emitted by stringer: it fails to compile if the
// SnmpV3SecurityModel constant no longer has the value the tables below
// were built for.
func _() {
	// An "invalid array index" compiler error signifies that the constant values have changed.
	// Re-run the stringer command to generate them again.
	var x [1]struct{}
	_ = x[UserSecurityModel-3]
}

const _SnmpV3SecurityModel_name = "UserSecurityModel"

var _SnmpV3SecurityModel_index = [...]uint8{0, 17}

// String returns "UserSecurityModel" for 3 and "SnmpV3SecurityModel(<n>)"
// for any other value.
func (i SnmpV3SecurityModel) String() string {
	idx := int(i) - 3
	if i < 3 || idx >= len(_SnmpV3SecurityModel_index)-1 {
		return "SnmpV3SecurityModel(" + strconv.FormatInt(int64(i), 10) + ")"
	}
	return _SnmpV3SecurityModel_name[_SnmpV3SecurityModel_index[idx]:_SnmpV3SecurityModel_index[idx+1]]
}
// Copyright 2012 The GoSNMP Authors. All rights reserved. Use of this
// source code is governed by a BSD-style license that can be found in the
// LICENSE file.
package gosnmp
import (
"errors"
"fmt"
"net"
"strings"
"sync"
"sync/atomic"
"time"
)
//
// Sending Traps ie GoSNMP acting as an Agent
//
// SendTrap sends a SNMP Trap (or, when trap.IsInform is set, an
// InformRequest).
//
// For SNMPv2c/v3, trap.Variables must not be empty. Variables[0] can be a
// pdu of Type TimeTicks (with the desired uint32 epoch time). Otherwise a
// TimeTicks pdu will be prepended, with time set to now. This mirrors the
// behaviour of the Net-SNMP command-line tools. These versions also adjust
// x.MsgFlags in place: the Reportable flag is cleared for traps and set
// for informs.
//
// For SNMPv1, trap.Enterprise and trap.AgentAddress are required.
//
// Plain traps do not wait for a return packet from the NMS (Network
// Management Station); informs do.
//
// See also Listen() and examples for creating an NMS.
//
// NOTE: the trap code is currently unreliable when working with snmpv3 - pull requests welcome
func (x *GoSNMP) SendTrap(trap SnmpTrap) (result *SnmpPacket, err error) {
	var pdutype PDUType

	switch x.Version {
	case Version1:
		pdutype = Trap
		if len(trap.Enterprise) == 0 {
			return nil, fmt.Errorf("function SendTrap for SNMPV1 requires an Enterprise OID")
		}
		if len(trap.AgentAddress) == 0 {
			return nil, fmt.Errorf("function SendTrap for SNMPV1 requires an Agent Address")
		}

	case Version2c, Version3:
		if len(trap.Variables) == 0 {
			return nil, fmt.Errorf("function SendTrap requires at least 1 PDU")
		}
		firstIsTimeTicks := trap.Variables[0].Type == TimeTicks
		if firstIsTimeTicks {
			if _, isUint32 := trap.Variables[0].Value.(uint32); !isUint32 {
				return nil, fmt.Errorf("function SendTrap TimeTick must be uint32")
			}
		}

		// RFC 3412 section 6.4: the reportableFlag MUST be zero when the
		// message carries an Unconfirmed Class PDU such as SNMPv2-Trap, so
		// clear the Reportable bit inherited from validateParameters().
		switch x.MsgFlags {
		case 0x4, 0x5, 0x7:
			x.MsgFlags &^= Reportable
		}

		// Default to a v2 trap; an inform is a Confirmed Class PDU, so it
		// needs the Reportable flag (RFC 3412 section 6.4). That lets the
		// authoritative engine answer with a Report carrying engineBoots and
		// engineTime for time synchronization.
		pdutype = SNMPv2Trap
		if trap.IsInform {
			pdutype = InformRequest
			x.MsgFlags |= Reportable
		}

		if !firstIsTimeTicks {
			now := uint32(time.Now().Unix()) //nolint:gosec
			uptime := SnmpPDU{Name: "1.3.6.1.2.1.1.3.0", Type: TimeTicks, Value: now}
			trap.Variables = append([]SnmpPDU{uptime}, trap.Variables...)
		}

	default:
		return nil, fmt.Errorf("function SendTrap doesn't support %s", x.Version)
	}

	packetOut := x.mkSnmpPacket(pdutype, trap.Variables, 0, 0)
	if x.Version == Version1 {
		packetOut.Enterprise = trap.Enterprise
		packetOut.AgentAddress = trap.AgentAddress
		packetOut.GenericTrap = trap.GenericTrap
		packetOut.SpecificTrap = trap.SpecificTrap
		packetOut.Timestamp = trap.Timestamp
	}

	// Only informs wait for the return packet.
	return x.send(packetOut, trap.IsInform)
}
//
// Receiving Traps ie GoSNMP acting as an NMS (Network Management
// Station).
//
// Incoming SNMPv1 Trap, SNMPv2Trap and InformRequest PDUs are decoded.
//
// A TrapListener defines parameters for running a SNMP Trap receiver.
// nil values will be replaced by default values.
type TrapListener struct {
	// done is sent true by the listen loop once it observes finish == 1;
	// Close waits on it (bounded by CloseTimeout).
	done chan bool
	// listening is sent true once the socket is open; exposed to callers
	// through Listening().
	listening chan bool
	// The mutex is taken by Listening and Close.
	sync.Mutex
	// Params is a reference to the TrapListener's "parent" GoSNMP instance.
	Params *GoSNMP
	// OnNewTrap handles incoming Trap and Inform PDUs.
	OnNewTrap TrapHandlerFunc
	// CloseTimeout is the max wait time for the socket to gracefully signal its closure.
	CloseTimeout time.Duration
	// conn is the UDP socket opened by listenUDP and closed by Close.
	conn *net.UDPConn
	// proto is the network name passed to net.ResolveUDPAddr in listenUDP.
	proto string
	// Total number of packets received referencing an unknown snmpEngineID
	// (updated atomically).
	usmStatsUnknownEngineIDsCount uint32
	finish                        int32 // Atomic flag; set to 1 when closing connection
	buffSize                      uint  // SNMP message buffer size
}
// defaultCloseTimeout is the CloseTimeout assigned by NewTrapListener:
// three seconds.
const defaultCloseTimeout = 3000 * time.Millisecond
// TrapHandlerFunc is a callback function type which receives SNMP Trap and
// Inform packets when they are received. If this callback is nil, Trap and
// Inform PDUs will not be received (Inform responses will still be sent,
// however). This callback should not modify the contents of the SnmpPacket
// nor the UDPAddr passed to it, and it should copy out any values it wishes to
// use instead of retaining references in order to avoid memory fragmentation.
//
// The general effect of received Trap and Inform packets do not differ for the
// receiver, and the response is handled by the caller of the handler, so there
// is no need for the application to handle Informs any different than Traps.
// Nonetheless, the packet's PDUType field (or IsInform) can be examined to
// determine what type of event this is for e.g. statistics gathering functions, etc.
type TrapHandlerFunc func(s *SnmpPacket, u *net.UDPAddr)
// NewTrapListener returns an initialized TrapListener with a 4096-byte
// message buffer and CloseTimeout set to defaultCloseTimeout.
//
// NOTE: the trap code is currently unreliable when working with snmpv3 - pull requests welcome
func NewTrapListener() *TrapListener {
	tl := &TrapListener{
		finish:   0,
		buffSize: 4096,
		// done holds one value so the listen loop's final `t.done <- true`
		// never blocks. Close stops waiting after CloseTimeout, and it does
		// not wait at all if the socket was not open yet. With an
		// unbuffered channel that send would then block forever and leak
		// the listener goroutine.
		done:         make(chan bool, 1),
		listening:    make(chan bool, 1), // Buffered because one doesn't have to block on it.
		CloseTimeout: defaultCloseTimeout,
	}
	return tl
}
// WithBufferSize sets the size of the buffer used to read each incoming
// SNMP message and returns t for chaining.
//
// NOTE: The buffer size cannot be 0 bytes; a size of 0 is raised to 1.
// The default size is 4096 bytes.
func (t *TrapListener) WithBufferSize(i uint) *TrapListener {
	size := i
	if size == 0 {
		size = 1
	}
	t.buffSize = size
	return t
}
// Listening returns a sentinel channel on which one can block
// until the listener is ready to receive requests.
//
// NOTE: the trap code is currently unreliable when working with snmpv3 - pull requests welcome
func (t *TrapListener) Listening() <-chan bool {
	t.Lock()
	ch := t.listening
	t.Unlock()
	return ch
}
// Close terminates the listening on TrapListener socket.
//
// Only the first call has any effect. If a socket is open, Close closes it
// and then waits up to CloseTimeout for the listen loop to signal that it
// has stopped; close errors and timeouts are logged, not returned.
func (t *TrapListener) Close() {
	if !atomic.CompareAndSwapInt32(&t.finish, 0, 1) {
		return // already closing or closed
	}

	t.Lock()
	defer t.Unlock()

	if t.conn == nil {
		return
	}
	if err := t.conn.Close(); err != nil {
		t.Params.Logger.Printf("failed to Close() the TrapListener socket: %s", err)
	}

	// Bounded wait so Close can never block forever.
	select {
	case <-t.done:
	case <-time.After(t.CloseTimeout):
		t.Params.Logger.Printf("timeout while awaiting done signal on TrapListener Close()")
	}
}
// SendUDP marshals packet and sends it to addr over the listener's open
// connection. A short write is logged but not treated as an error.
func (t *TrapListener) SendUDP(packet *SnmpPacket, addr *net.UDPAddr) error {
	payload, err := packet.marshalMsg()
	if err != nil {
		return fmt.Errorf("error marshaling SnmpPacket: %w", err)
	}

	written, err := t.conn.WriteTo(payload, addr)
	switch {
	case err != nil:
		return fmt.Errorf("error sending SnmpPacket: %w", err)
	case written != len(payload):
		t.Params.Logger.Printf("Failed to send all bytes of SnmpPacket!\n")
	}
	return nil
}
// listenUDP binds a UDP socket on addr and serves incoming messages until
// Close sets the finish flag, at which point it signals on done and returns nil.
//
// For each datagram it:
//   - unmarshals the trap with t.Params;
//   - for SNMPv3/USM traps, compares the message's authoritative engine ID with
//     the listener's. When they differ and the message's ID does not have a
//     valid SnmpEngineID length (5..32 octets), it sends a Report
//     (RFC 3414 3.2.3b) and drops the trap;
//   - passes the trap to OnNewTrap;
//   - answers an InformRequest with a GetResponse.
func (t *TrapListener) listenUDP(addr string) error {
	// udp
	udpAddr, err := net.ResolveUDPAddr(t.proto, addr)
	if err != nil {
		return err
	}
	conn, err := net.ListenUDP(udp, udpAddr)
	if err != nil {
		return err
	}
	// Close reads t.conn while holding the listener lock, so publish the
	// connection under the same lock to avoid a data race.
	t.Lock()
	t.conn = conn
	t.Unlock()
	defer conn.Close()

	// Mark that we are listening now.
	t.listening <- true

	for {
		switch {
		case atomic.LoadInt32(&t.finish) == 1:
			t.done <- true
			return nil
		default:
			buf := make([]byte, t.buffSize)
			rlen, remote, err := conn.ReadFromUDP(buf)
			if err != nil {
				if atomic.LoadInt32(&t.finish) == 1 {
					// err most likely comes from reading from a closed connection
					continue
				}
				t.Params.Logger.Printf("TrapListener: error in read %s\n", err)
				continue
			}
			msg := buf[:rlen]
			trap, err := t.Params.UnmarshalTrap(msg, false)
			if err != nil {
				t.Params.Logger.Printf("TrapListener: error in UnmarshalTrap %s\n", err)
				continue
			}
			if trap.Version == Version3 && trap.SecurityModel == UserSecurityModel && t.Params.SecurityModel == UserSecurityModel {
				// Both sides must be USM parameters to compare engine IDs. If
				// either is not, drop the trap rather than dereference nil.
				securityParams, ok := t.Params.SecurityParameters.(*UsmSecurityParameters)
				if !ok || securityParams == nil {
					t.Params.Logger.Printf("TrapListener: Invalid SecurityParameters types")
					continue
				}
				packetSecurityParams, ok := trap.SecurityParameters.(*UsmSecurityParameters)
				if !ok || packetSecurityParams == nil {
					t.Params.Logger.Printf("TrapListener: Invalid SecurityParameters types")
					continue
				}
				snmpEngineID := securityParams.AuthoritativeEngineID
				msgAuthoritativeEngineID := packetSecurityParams.AuthoritativeEngineID
				if msgAuthoritativeEngineID != snmpEngineID {
					if len(msgAuthoritativeEngineID) < 5 || len(msgAuthoritativeEngineID) > 32 {
						// RFC3411 section 5. – SnmpEngineID definition.
						// SnmpEngineID is an OCTET STRING which size should be between 5 and 32
						// According to RFC3414 3.2.3b: stop processing and report
						// the listener authoritative engine ID
						atomic.AddUint32(&t.usmStatsUnknownEngineIDsCount, 1)
						err := t.reportAuthoritativeEngineID(trap, snmpEngineID, remote)
						if err != nil {
							t.Params.Logger.Printf("TrapListener: %s\n", err)
						}
						continue
					}
					// RFC3414 3.2.3a: Continue processing
				}
			}
			// Here we assume that t.OnNewTrap will not alter the contents
			// of the PDU (per documentation, because Go does not have
			// compile-time const checking). We don't pass a copy because
			// the SnmpPacket type is somewhat large, but we could without
			// violating any implicit or explicit spec.
			t.OnNewTrap(trap, remote)
			// If it was an Inform request, we need to send a response.
			if trap.PDUType == InformRequest { //nolint:whitespace
				// Reuse the packet, since we're supposed to send it back
				// with the exact same variables unless there's an error.
				// Change the PDUType to the response, though.
				trap.PDUType = GetResponse
				// If the response can be sent, the error-status is
				// supposed to be set to noError and the error-index set to
				// zero.
				trap.Error = NoError
				trap.ErrorIndex = 0
				// TODO: Check that the message marshalled is not too large
				// for the originator to accept and if so, send a tooBig
				// error PDU per RFC3416 section 4.2.7. This maximum size,
				// however, does not have a well-defined mechanism in the
				// RFC other than using the path MTU (which is difficult to
				// determine), so it's left to future implementations.
				err := t.SendUDP(trap, remote)
				if err != nil {
					t.Params.Logger.Printf("TrapListener: %s\n", err)
				}
			}
		}
	}
}
// reportAuthoritativeEngineID answers trap with a Report PDU sent to addr. The
// report carries the listener's engine ID (snmpEngineID) in a copy of the
// message's security parameters, plus the usmStatsUnknownEngineIDs counter.
// The trap packet itself is modified and reused as the report.
func (t *TrapListener) reportAuthoritativeEngineID(trap *SnmpPacket, snmpEngineID string, addr *net.UDPAddr) error {
	secParams, ok := trap.SecurityParameters.Copy().(*UsmSecurityParameters)
	if !ok {
		return errors.New("unable to cast SecurityParams to UsmSecurityParameters")
	}
	secParams.AuthoritativeEngineID = snmpEngineID

	unknownEngineIDs := int(atomic.LoadUint32(&t.usmStatsUnknownEngineIDsCount))

	trap.PDUType = Report
	// Keep only the authentication/privacy bits; this clears Reportable.
	trap.MsgFlags &= AuthPriv
	trap.SecurityParameters = secParams
	trap.Variables = []SnmpPDU{{
		Name:  usmStatsUnknownEngineIDs,
		Value: unknownEngineIDs,
		Type:  Integer,
	}}
	return t.SendUDP(trap, addr)
}
// handleTCPRequest reads one SNMP message (at most 4096 bytes, in a single
// Read) from conn, unmarshals it as a trap and hands it to OnNewTrap. The
// connection is closed on every path, including read and decode failures.
func (t *TrapListener) handleTCPRequest(conn net.Conn) {
	defer conn.Close()

	// Make a buffer to hold incoming data.
	buf := make([]byte, 4096)
	reqLen, err := conn.Read(buf)
	if err != nil {
		t.Params.Logger.Printf("TrapListener: error in read %s\n", err)
		return
	}
	trap, err := t.Params.UnmarshalTrap(buf[:reqLen], false)
	if err != nil {
		t.Params.Logger.Printf("TrapListener: error in UnmarshalTrap %s\n", err)
		return
	}
	// TODO: lying for backward compatibility reason - create UDP Address ... not nice
	// OnNewTrap expects a *net.UDPAddr, so the TCP peer address is
	// re-interpreted as UDP. Best effort: on failure the handler gets nil.
	remote, _ := net.ResolveUDPAddr("", conn.RemoteAddr().String())
	t.OnNewTrap(trap, remote)
}
// listenTCP accepts TCP connections on addr and handles each one in its own
// goroutine via handleTCPRequest. Before each Accept it checks the finish flag;
// when set, it signals on done and returns nil. An Accept error is returned.
//
// NOTE(review): Close only closes the UDP connection (t.conn), so a blocked
// Accept here is not interrupted by Close — confirm whether the TCP listener
// needs its own shutdown path.
func (t *TrapListener) listenTCP(addr string) error {
	tcpAddr, err := net.ResolveTCPAddr(t.proto, addr)
	if err != nil {
		return err
	}
	l, err := net.ListenTCP(tcp, tcpAddr)
	if err != nil {
		return err
	}
	defer l.Close()

	// Mark that we are listening now.
	t.listening <- true

	for {
		if atomic.LoadInt32(&t.finish) == 1 {
			t.done <- true
			return nil
		}
		// Listen for an incoming connection. Check the error before touching
		// conn, and report through the configured logger rather than stdout.
		conn, err := l.Accept()
		if err != nil {
			return err
		}
		t.Params.Logger.Printf("TrapListener: accepted TCP connection from %s\n", conn.RemoteAddr())
		// Handle connections in a new goroutine.
		go t.handleTCPRequest(conn)
	}
}
// Listen listens on addr and calls the OnNewTrap function of the TrapListener
// for every trap received. addr may carry a "tcp://" or "udp://" scheme
// prefix; without one, UDP is used. When Params or OnNewTrap are unset, Default
// and a logging handler are used respectively.
//
// NOTE: the trap code is currently unreliable when working with snmpv3 - pull requests welcome
func (t *TrapListener) Listen(addr string) error {
	if t.Params == nil {
		t.Params = Default
	}
	// TODO TODO returning an error cause the following to hang/break
	// TestSendTrapBasic
	// TestSendTrapWithoutWaitingOnListen
	// TestSendV1Trap
	_ = t.Params.validateParameters()
	if t.OnNewTrap == nil {
		t.OnNewTrap = t.debugTrapHandler
	}

	// Split an optional "<proto>://" prefix off the address.
	t.proto = udp
	if sep := strings.Index(addr, "://"); sep >= 0 {
		t.proto, addr = addr[:sep], addr[sep+len("://"):]
	}

	switch t.proto {
	case udp:
		return t.listenUDP(addr)
	case tcp:
		return t.listenTCP(addr)
	}
	return fmt.Errorf("not implemented network protocol: %s [use: tcp/udp]", t.proto)
}
// debugTrapHandler is the fallback OnNewTrap handler installed by Listen: it
// logs the sender address and the decoded packet through the Params logger.
func (t *TrapListener) debugTrapHandler(s *SnmpPacket, u *net.UDPAddr) {
	params := t.Params
	params.Logger.Printf("got trapdata from %+v: %+v\n", u, s)
}
// UnmarshalTrap unpacks the SNMP Trap.
//
// For SNMPv3 traps, when TrapSecurityParametersTable is set, the identifier
// (the user name for USM) is read from the message header. Every parameter set
// registered under it is then tried in turn, and the first that unmarshals
// successfully (authentication and decryption included) wins. If all fail, the
// last failure is wrapped in the returned error.
//
// In every other case the trap is decoded with the GoSNMP's own
// SecurityParameters. useResponseSecurityParameters selects whether
// authentication uses the message's flags and parameters or the connection's.
func (x *GoSNMP) UnmarshalTrap(trap []byte, useResponseSecurityParameters bool) (*SnmpPacket, error) {
	// Get only the version from the header of the trap
	version, _, err := x.unmarshalVersionFromHeader(trap, new(SnmpPacket))
	if err != nil {
		x.Logger.Printf("UnmarshalTrap version unmarshal: %s\n", err)
		return nil, err
	}
	if x.TrapSecurityParametersTable == nil || version != Version3 {
		return x.unmarshalTrapBase(trap, nil, useResponseSecurityParameters)
	}

	// Multiple users may be configured: find the credentials matching the
	// message's identifier that authenticate / decrypt it.
	identifier, err := x.getTrapIdentifier(trap)
	if err != nil {
		x.Logger.Printf("UnmarshalTrap V3 get trap identifier: %s\n", err)
		return nil, err
	}
	secParamsList, err := x.TrapSecurityParametersTable.Get(identifier)
	if err != nil {
		x.Logger.Printf("UnmarshalTrap V3 get security parameters from table: %s\n", err)
		return nil, err
	}
	if len(secParamsList) == 0 {
		return nil, fmt.Errorf("no security parameters registered for identifier %s", identifier)
	}

	var lastErr error
	for _, secParams := range secParamsList {
		// Each attempt works on its own copy of the bytes and of the
		// parameters, so a failed attempt cannot affect the next one.
		cpTrap := make([]byte, len(trap))
		copy(cpTrap, trap)
		result, attemptErr := x.unmarshalTrapBase(cpTrap, secParams.Copy(), true)
		if attemptErr == nil {
			return result, nil
		}
		lastErr = attemptErr
	}
	return nil, fmt.Errorf("no credentials successfully unmarshaled trap: %w", lastErr)
}
// getTrapIdentifier decodes only the header of trap, without any auth/priv
// settings, and returns the security-parameters identifier (the USM user
// name) used to look up credentials. A header error is tolerated as long as an
// identifier could still be read.
func (x *GoSNMP) getTrapIdentifier(trap []byte) (string, error) {
	// Initialize a packet with no auth/priv to unmarshal ID/key for security parameters to use
	packet := new(SnmpPacket)
	_, err := x.unmarshalHeader(trap, packet)
	// A malformed header can fail before any security parameters are
	// allocated; calling getIdentifier on the nil interface would panic.
	if packet.SecurityParameters == nil {
		if err == nil {
			err = errors.New("unable to parse security parameters from trap header")
		}
		return "", err
	}
	identifier := packet.SecurityParameters.getIdentifier()
	// Return err if no identifier was able to be parsed after unmarshaling
	if err != nil && identifier == "" {
		return "", err
	}
	return identifier, nil
}
// unmarshalTrapBase decodes trap into a new SnmpPacket.
//
// Security parameters: when sp is nil and the GoSNMP has SecurityParameters,
// their keys are initialised and a copy is used; otherwise sp is used as given.
// For SNMPv3 messages, the message is authenticated when it uses the User
// Security Model. The scoped PDU is then decrypted if needed, before the
// payload is decoded.
func (x *GoSNMP) unmarshalTrapBase(trap []byte, sp SnmpV3SecurityParameters, useResponseSecurityParameters bool) (*SnmpPacket, error) {
	pkt := new(SnmpPacket)
	pkt.SecurityParameters = sp
	if sp == nil && x.SecurityParameters != nil {
		if err := x.SecurityParameters.InitSecurityKeys(); err != nil {
			return nil, err
		}
		pkt.SecurityParameters = x.SecurityParameters.Copy()
	}

	cursor, err := x.unmarshalHeader(trap, pkt)
	if err != nil {
		x.Logger.Printf("UnmarshalTrap: %s\n", err)
		return nil, err
	}

	if pkt.Version == Version3 {
		if pkt.SecurityModel == UserSecurityModel {
			if err = x.testAuthentication(trap, pkt, useResponseSecurityParameters); err != nil {
				x.Logger.Printf("UnmarshalTrap v3 auth: %s\n", err)
				return nil, err
			}
		}
		if trap, cursor, err = x.decryptPacket(trap, cursor, pkt); err != nil {
			x.Logger.Printf("UnmarshalTrap v3 decrypt: %s\n", err)
			return nil, err
		}
	}

	if err = x.unmarshalPayload(trap, cursor, pkt); err != nil {
		x.Logger.Printf("UnmarshalTrap: %s\n", err)
		return nil, err
	}
	return pkt, nil
}
// Copyright 2012 The GoSNMP Authors. All rights reserved. Use of this
// source code is governed by a BSD-style license that can be found in the
// LICENSE file.
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gosnmp
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"runtime"
)
// SnmpV3MsgFlags contains various message flags to describe Authentication, Privacy, and whether a report PDU must be sent.
//
// The values are bit flags. 0x1 is the authentication bit and 0x4 the
// reportable bit. AuthPriv (0x3) adds the privacy bit 0x2 to authentication,
// so masking a value with AuthPriv isolates its security level.
type SnmpV3MsgFlags uint8
// Possible values of SnmpV3MsgFlags. They can be combined with a bitwise OR,
// e.g. AuthPriv | Reportable.
const (
NoAuthNoPriv SnmpV3MsgFlags = 0x0 // No authentication, and no privacy
AuthNoPriv SnmpV3MsgFlags = 0x1 // Authentication and no privacy
AuthPriv SnmpV3MsgFlags = 0x3 // Authentication and privacy
Reportable SnmpV3MsgFlags = 0x4 // Report PDU must be sent.
)
//go:generate stringer -type=SnmpV3MsgFlags
// SnmpV3SecurityModel describes the security model used by a SnmpV3 connection.
// It is encoded as a single byte in the msgSecurityModel field of the
// SNMPv3 message header.
type SnmpV3SecurityModel uint8
// UserSecurityModel is the only SnmpV3SecurityModel currently implemented.
const (
UserSecurityModel SnmpV3SecurityModel = 3
)
//go:generate stringer -type=SnmpV3SecurityModel
// SnmpV3SecurityParameters is a generic interface type to contain various implementations of SnmpV3SecurityParameters.
// UsmSecurityParameters (User Security Model) is the implementation provided in this package.
type SnmpV3SecurityParameters interface {
// Log writes a description of the parameters to the configured logger.
Log()
// Copy returns a copy of the parameters.
Copy() SnmpV3SecurityParameters
// Description returns a textual description of the parameters
// (the USM implementation includes the passphrases).
Description() string
// SafeString returns a description that is safe to log (no secrets).
SafeString() string
// InitPacket prepares an outgoing packet; it is called when privacy is requested.
InitPacket(packet *SnmpPacket) error
// InitSecurityKeys derives any keys the parameters need.
InitSecurityKeys() error
validate(flags SnmpV3MsgFlags) error
init(log Logger) error
// discoveryRequired returns an engine-discovery packet that must be sent
// first, or nil when no discovery is needed.
discoveryRequired() *SnmpPacket
// getDefaultContextEngineID returns the context engine ID to use when none is configured.
getDefaultContextEngineID() string
// setSecurityParameters copies values learned from a peer (for USM: engine ID, boots and time) from in.
setSecurityParameters(in SnmpV3SecurityParameters) error
// marshal and unmarshal encode/decode the msgSecurityParameters field;
// unmarshal returns the offset just past the parsed parameters.
marshal(flags SnmpV3MsgFlags) ([]byte, error)
unmarshal(flags SnmpV3MsgFlags, packet []byte, cursor int) (int, error)
// authenticate authenticates an outgoing marshalled message; isAuthentic
// verifies an incoming one.
authenticate(packet []byte) error
isAuthentic(packetBytes []byte, packet *SnmpPacket) (bool, error)
// encryptPacket and decryptPacket handle the scoped PDU when privacy is in use.
encryptPacket(scopedPdu []byte) ([]byte, error)
decryptPacket(packet []byte, cursor int) ([]byte, error)
// getIdentifier returns the key used for SnmpV3SecurityParametersTable lookups.
getIdentifier() string
getLogger() Logger
setLogger(log Logger)
}
// validateParametersV3 checks that the connection uses the User Security Model
// and has SecurityParameters. It then lets those parameters validate
// themselves against the connection's MsgFlags.
func (x *GoSNMP) validateParametersV3() error {
	switch {
	// update following code if you implement a new security model
	case x.SecurityModel != UserSecurityModel:
		return errors.New("the SNMPV3 User Security Model is the only SNMPV3 security model currently implemented")
	case x.SecurityParameters == nil:
		return errors.New("SNMPV3 SecurityParameters must be set")
	default:
		return x.SecurityParameters.validate(x.MsgFlags)
	}
}
// authenticate the marshalled result of a snmp version 3 packet.
//
// When the packet is SNMPv3 and its flags request authentication, msg is
// passed to the packet's SecurityParameters.authenticate. That call returns
// only an error, so any authentication data is presumably written into msg
// in place. msg is then returned. Non-v3 messages are returned untouched.
//
// A panic raised while authenticating is recovered. Its stack trace is printed
// to stdout and it is turned into an error, so callers never receive a nil
// message together with a nil error.
func (packet *SnmpPacket) authenticate(msg []byte) (authMsg []byte, err error) {
	defer func() {
		if e := recover(); e != nil {
			stack := make([]byte, 8192)
			n := runtime.Stack(stack, true)
			// Print only the bytes runtime.Stack actually filled in.
			fmt.Printf("[v3::authenticate]recover: %v. Stack=%v\n", e, string(stack[:n]))
			authMsg = nil
			err = fmt.Errorf("panic while authenticating SNMPv3 message: %v", e)
		}
	}()
	if packet.Version != Version3 {
		return msg, nil
	}
	if packet.MsgFlags&AuthNoPriv > 0 {
		if authErr := packet.SecurityParameters.authenticate(msg); authErr != nil {
			return nil, authErr
		}
	}
	return msg, nil
}
// testAuthentication verifies the authenticity of the received SNMPv3 message
// packet, whose header has already been decoded into result.
//
// With useResponseSecurityParameters the message's own flags and security
// parameters are used; otherwise the connection's MsgFlags and
// SecurityParameters are. A check is only performed when the selected flags
// include the authentication bit. It returns an error if the message fails it.
func (x *GoSNMP) testAuthentication(packet []byte, result *SnmpPacket, useResponseSecurityParameters bool) error {
	if x.Version != Version3 {
		return fmt.Errorf("testAuthentication called with non Version3 connection")
	}
	msgFlags := x.MsgFlags
	if useResponseSecurityParameters {
		msgFlags = result.MsgFlags
	}

	// Special case for Engine Discovery (RFC3414 section 4). A discovery
	// message is noAuthNoPriv, with an empty user name, an empty
	// authoritative engine ID and an empty variable-binding list, so it cannot
	// be authenticated and is let through.
	//
	// The message's own flags must be checked: masking with NoAuthNoPriv (0)
	// is always zero and would skip authentication for any message with an
	// empty user and engine ID, even one that claims to be authenticated.
	//
	// NOTE(review): unmarshalTrapBase calls this before the payload is
	// decoded, so Variables is presumably always empty at that point.
	if msgSecParams, ok := result.SecurityParameters.(*UsmSecurityParameters); ok && msgSecParams != nil &&
		result.MsgFlags&AuthPriv == NoAuthNoPriv && // NoAuthNoPriv method
		msgSecParams.UserName == "" && // empty username
		msgSecParams.AuthoritativeEngineID == "" && // empty authoritative engine ID
		len(result.Variables) == 0 { // empty variable binding list
		return nil
	}

	if msgFlags&AuthNoPriv == 0 {
		return nil
	}
	var authentic bool
	var err error
	if useResponseSecurityParameters {
		authentic, err = result.SecurityParameters.isAuthentic(packet, result)
	} else {
		authentic, err = x.SecurityParameters.isAuthentic(packet, result)
	}
	if err != nil {
		return err
	}
	if !authentic {
		return fmt.Errorf("incoming packet is not authentic, discarding")
	}
	return nil
}
// initPacket lets the connection's security parameters prepare packetOut (via
// InitPacket) when the connection's MsgFlags request privacy; otherwise it
// does nothing.
func (x *GoSNMP) initPacket(packetOut *SnmpPacket) error {
	privacyRequested := x.MsgFlags&AuthPriv > AuthNoPriv
	if !privacyRequested {
		return nil
	}
	return x.SecurityParameters.InitPacket(packetOut)
}
// negotiateInitialSecurityParameters prepares packetOut's security parameters
// before it is sent.
//
// If the parameters require engine discovery, a discovery request is sent
// first. Its result is stored on the connection and copied into packetOut.
// Otherwise packetOut's keys are initialised directly, and any error from
// doing so is returned.
//
// See http://tools.ietf.org/html/rfc2574#section-2.2.3. This code does not
// check whether the last message received was more than 150 seconds ago. The
// snmpds that this code was tested on emit an 'out of time window' error with
// the new time, and this code will retransmit when that is received.
func (x *GoSNMP) negotiateInitialSecurityParameters(packetOut *SnmpPacket) error {
	if x.Version != Version3 || packetOut.Version != Version3 {
		return fmt.Errorf("negotiateInitialSecurityParameters called with non Version3 connection or packet")
	}
	if x.SecurityModel != packetOut.SecurityModel {
		return fmt.Errorf("connection security model does not match security model defined in packet")
	}

	discoveryPacket := packetOut.SecurityParameters.discoveryRequired()
	if discoveryPacket == nil {
		// Surface key-derivation failures instead of silently discarding them.
		return packetOut.SecurityParameters.InitSecurityKeys()
	}

	discoveryPacket.ContextName = x.ContextName
	result, err := x.sendOneRequest(discoveryPacket, true)
	if err != nil {
		return err
	}
	if err := x.storeSecurityParameters(result); err != nil {
		return err
	}
	return x.updatePktSecurityParameters(packetOut)
}
// storeSecurityParameters saves on the connection the security parameters
// learned from the SNMPv3 response result. It sets the context engine ID, if
// not already set, and passes result's parameters to the connection's
// setSecurityParameters.
func (x *GoSNMP) storeSecurityParameters(result *SnmpPacket) error {
	switch {
	case x.Version != Version3, result.Version != Version3:
		return fmt.Errorf("storeParameters called with non Version3 connection or packet")
	case x.SecurityModel != result.SecurityModel:
		return fmt.Errorf("connection security model does not match security model extracted from packet")
	}
	if x.ContextEngineID == "" {
		x.ContextEngineID = result.SecurityParameters.getDefaultContextEngineID()
	}
	return x.SecurityParameters.setSecurityParameters(result.SecurityParameters)
}
// updatePktSecurityParameters copies the connection's security parameters into
// packetOut. When packetOut has no ContextEngineID, it also fills that in from
// the connection.
func (x *GoSNMP) updatePktSecurityParameters(packetOut *SnmpPacket) error {
	if x.Version != Version3 || packetOut.Version != Version3 {
		return fmt.Errorf("updatePktSecurityParameters called with non Version3 connection or packet")
	}
	if x.SecurityModel != packetOut.SecurityModel {
		return fmt.Errorf("connection security model does not match security model extracted from packet")
	}
	if err := packetOut.SecurityParameters.setSecurityParameters(x.SecurityParameters); err != nil {
		return err
	}
	if packetOut.ContextEngineID != "" {
		return nil
	}
	packetOut.ContextEngineID = x.ContextEngineID
	return nil
}
// marshalV3 writes the SNMPv3-specific part of packet into buf and returns it:
//   - the header SEQUENCE (from marshalV3Header);
//   - the msgSecurityParameters OCTET STRING;
//   - the (possibly encrypted) scoped PDU.
//
// On error an empty, non-nil buffer is returned together with the error.
func (packet *SnmpPacket) marshalV3(buf *bytes.Buffer) (*bytes.Buffer, error) {
	emptyBuffer := new(bytes.Buffer) // used when returning errors

	// tail returns at most the last four bytes of b. It is used only for debug
	// logging and, unlike b[len(b)-4:], never panics on short input.
	tail := func(b []byte) []byte {
		if len(b) > 4 {
			return b[len(b)-4:]
		}
		return b
	}

	header, err := packet.marshalV3Header()
	if err != nil {
		return emptyBuffer, err
	}
	// Encode the header length with marshalLength, like the other lengths, so
	// the encoding stays valid BER whatever the header size.
	headerLen, err := marshalLength(len(header))
	if err != nil {
		return emptyBuffer, err
	}
	buf.WriteByte(byte(Sequence))
	buf.Write(headerLen)
	packet.Logger.Printf("Marshal V3 Header len=%d. Eaten Last 4 Bytes=%v", len(header), tail(header))
	buf.Write(header)

	securityParameters, err := packet.SecurityParameters.marshal(packet.MsgFlags)
	if err != nil {
		return emptyBuffer, err
	}
	packet.Logger.Printf("Marshal V3 SecurityParameters len=%d. Eaten Last 4 Bytes=%v",
		len(securityParameters), tail(securityParameters))
	secParamLen, err := marshalLength(len(securityParameters))
	if err != nil {
		return emptyBuffer, err
	}
	buf.WriteByte(byte(OctetString))
	buf.Write(secParamLen)
	buf.Write(securityParameters)

	scopedPdu, err := packet.marshalV3ScopedPDU()
	if err != nil {
		return emptyBuffer, err
	}
	buf.Write(scopedPdu)
	return buf, nil
}
// marshalV3Header encodes the header fields of a SNMPv3 message: msgID,
// msgMaxSize, msgFlags and msgSecurityModel. The enclosing SEQUENCE tag and
// length are not included. msgMaxSize defaults to rxBufSize when
// packet.MsgMaxSize is zero. The size of each field is debug-logged.
func (packet *SnmpPacket) marshalV3Header() ([]byte, error) {
	out := new(bytes.Buffer)
	mark := 0
	logField := func(name string) {
		packet.Logger.Printf("MarshalV3Header %s len=%v", name, out.Len()-mark)
		mark = out.Len()
	}

	// msg id
	out.Write([]byte{byte(Integer), 4})
	if err := binary.Write(out, binary.BigEndian, packet.MsgID); err != nil {
		return nil, err
	}
	logField("msgID")

	// maximum response msg size
	var maxSize uint32 = rxBufSize
	if packet.MsgMaxSize != 0 {
		maxSize = packet.MsgMaxSize
	}
	encodedMax, err := marshalUint32(maxSize)
	if err != nil {
		return nil, err
	}
	out.Write([]byte{byte(Integer), byte(len(encodedMax))})
	out.Write(encodedMax)
	logField("maxmsgsize")

	// msg flags
	out.Write([]byte{byte(OctetString), 1, byte(packet.MsgFlags)})
	logField("msg flags")

	// msg security model
	out.Write([]byte{byte(Integer), 1, byte(packet.SecurityModel)})
	logField("msg security model")

	return out.Bytes(), nil
}
// marshalV3ScopedPDU wraps the plaintext scoped PDU from prepareV3ScopedPDU in
// a SEQUENCE. When MsgFlags request privacy, the result is encrypted with the
// packet's SecurityParameters.
func (packet *SnmpPacket) marshalV3ScopedPDU() ([]byte, error) {
	body, err := packet.prepareV3ScopedPDU()
	if err != nil {
		return nil, err
	}
	bodyLen, err := marshalLength(len(body))
	if err != nil {
		return nil, err
	}
	encoded := make([]byte, 0, 1+len(bodyLen)+len(body))
	encoded = append(encoded, byte(Sequence))
	encoded = append(encoded, bodyLen...)
	encoded = append(encoded, body...)

	if packet.MsgFlags&AuthPriv <= AuthNoPriv {
		return encoded, nil
	}
	encrypted, err := packet.SecurityParameters.encryptPacket(encoded)
	if err != nil {
		return nil, err
	}
	return encrypted, nil
}
// prepareV3ScopedPDU builds the plaintext contents of a SNMPv3 scoped PDU:
// contextEngineID and contextName as OCTET STRINGs, followed by the
// marshalled PDU. The enclosing SEQUENCE header is not included.
func (packet *SnmpPacket) prepareV3ScopedPDU() ([]byte, error) {
	var out bytes.Buffer
	for _, field := range []string{packet.ContextEngineID, packet.ContextName} {
		fieldLen, err := marshalLength(len(field))
		if err != nil {
			return nil, err
		}
		out.WriteByte(byte(OctetString))
		out.Write(fieldLen)
		out.WriteString(field)
	}
	pdu, err := packet.marshalPDU()
	if err != nil {
		return nil, err
	}
	out.Write(pdu)
	return out.Bytes(), nil
}
// unmarshalV3Header decodes the SNMPv3 header of packet, starting at cursor,
// into response. It reads msgID, msgMaxSize, msgFlags and msgSecurityModel,
// then the msgSecurityParameters, and returns the offset just past the
// security parameters.
//
// If response has no SecurityParameters yet, a UsmSecurityParameters using
// the connection's logger is created to parse them. Truncated input yields an
// error naming the field being parsed, never an index panic.
func (x *GoSNMP) unmarshalV3Header(packet []byte,
	cursor int,
	response *SnmpPacket) (int, error) {
	// Guard the tag read below against truncated input.
	if cursor < 0 || cursor >= len(packet) {
		return 0, errors.New("error parsing SNMPV3 header: truncated packet")
	}
	if PDUType(packet[cursor]) != Sequence {
		return 0, fmt.Errorf("invalid SNMPV3 Header")
	}
	_, cursorTmp, err := parseLength(packet[cursor:])
	if err != nil {
		return 0, err
	}
	cursor += cursorTmp
	if cursor > len(packet) {
		return 0, errors.New("error parsing SNMPV3 header: truncated packet")
	}

	rawMsgID, count, err := parseRawField(x.Logger, packet[cursor:], "msgID")
	if err != nil {
		return 0, fmt.Errorf("error parsing SNMPV3 message ID: %w", err)
	}
	cursor += count
	if cursor > len(packet) {
		return 0, errors.New("error parsing SNMPV3 message ID: truncated packet")
	}
	if msgID, ok := rawMsgID.(int); ok {
		response.MsgID = uint32(msgID) //nolint:gosec
		x.Logger.Printf("Parsed message ID %d", msgID)
	}

	rawMsgMaxSize, count, err := parseRawField(x.Logger, packet[cursor:], "msgMaxSize")
	if err != nil {
		return 0, fmt.Errorf("error parsing SNMPV3 msgMaxSize: %w", err)
	}
	cursor += count
	if cursor > len(packet) {
		return 0, errors.New("error parsing SNMPV3 msgMaxSize: truncated packet")
	}
	if msgMaxSize, ok := rawMsgMaxSize.(int); ok {
		response.MsgMaxSize = uint32(msgMaxSize) //nolint:gosec
		x.Logger.Printf("Parsed message max size %d", msgMaxSize)
	}

	rawMsgFlags, count, err := parseRawField(x.Logger, packet[cursor:], "msgFlags")
	if err != nil {
		return 0, fmt.Errorf("error parsing SNMPV3 msgFlags: %w", err)
	}
	cursor += count
	if cursor > len(packet) {
		return 0, errors.New("error parsing SNMPV3 msgFlags: truncated packet")
	}
	if msgFlags, ok := rawMsgFlags.(string); ok && len(msgFlags) > 0 {
		response.MsgFlags = SnmpV3MsgFlags(msgFlags[0])
		x.Logger.Printf("parsed msg flags %s", msgFlags)
	}

	rawSecModel, count, err := parseRawField(x.Logger, packet[cursor:], "msgSecurityModel")
	if err != nil {
		return 0, fmt.Errorf("error parsing SNMPV3 msgSecModel: %w", err)
	}
	cursor += count
	// >= because the security parameters tag is read at packet[cursor] next.
	if cursor >= len(packet) {
		return 0, errors.New("error parsing SNMPV3 msgSecurityModel: truncated packet")
	}
	if secModel, ok := rawSecModel.(int); ok {
		response.SecurityModel = SnmpV3SecurityModel(secModel) //nolint:gosec
		x.Logger.Printf("Parsed security model %d", secModel)
	}

	if PDUType(packet[cursor]) != PDUType(OctetString) {
		return 0, errors.New("invalid SNMPV3 Security Parameters")
	}
	_, cursorTmp, err = parseLength(packet[cursor:])
	if err != nil {
		return 0, err
	}
	cursor += cursorTmp
	if cursor > len(packet) {
		return 0, errors.New("error parsing SNMPV3 security parameters: truncated packet")
	}

	if response.SecurityParameters == nil {
		response.SecurityParameters = &UsmSecurityParameters{Logger: x.Logger}
	}
	cursor, err = response.SecurityParameters.unmarshal(response.MsgFlags, packet, cursor)
	if err != nil {
		return 0, err
	}
	x.Logger.Printf("Parsed Security Parameters. now offset=%v,", cursor)
	return cursor, nil
}
// decryptPacket prepares the scoped PDU of a SNMPv3 message for payload decoding.
//
// packet[cursor] must be either an OCTET STRING or a SEQUENCE:
//   - an OCTET STRING is an encrypted scoped PDU, decrypted with
//     response.SecurityParameters;
//   - a SEQUENCE is a plaintext scoped PDU.
//
// It parses contextEngineID and contextName into response. It returns the
// packet (decrypted, and truncated to the scoped PDU when it was decrypted)
// and the offset just past contextName. Any other tag, or truncated input,
// yields an error.
func (x *GoSNMP) decryptPacket(packet []byte, cursor int, response *SnmpPacket) ([]byte, int, error) {
var err error
var decrypted = false
if cursor >= len(packet) {
return nil, 0, errors.New("error parsing SNMPV3: truncated packet")
}
switch PDUType(packet[cursor]) {
case PDUType(OctetString):
// pdu is encrypted
// The plaintext case below re-reads the SEQUENCE at the same cursor, so the
// decrypted scoped PDU is expected to start at cursor in the returned slice.
packet, err = response.SecurityParameters.decryptPacket(packet, cursor)
if err != nil {
return nil, 0, err
}
decrypted = true
// Continue into the plaintext case to parse the decrypted scoped PDU.
fallthrough
case Sequence:
// pdu is plaintext or has been decrypted
tlength, cursorTmp, err := parseLength(packet[cursor:])
if err != nil {
return nil, 0, err
}
if decrypted {
// truncate padding that might have been included with
// the encrypted PDU
// NOTE(review): tlength is used as the full element length measured from
// cursor, i.e. parseLength presumably includes the tag/length header — verify.
if cursor+tlength > len(packet) {
return nil, 0, errors.New("error parsing SNMPV3: truncated packet")
}
packet = packet[:cursor+tlength]
}
// Step over the SEQUENCE tag and length to the first scoped-PDU field.
cursor += cursorTmp
if cursor > len(packet) {
return nil, 0, errors.New("error parsing SNMPV3: truncated packet")
}
rawContextEngineID, count, err := parseRawField(x.Logger, packet[cursor:], "contextEngineID")
if err != nil {
return nil, 0, fmt.Errorf("error parsing SNMPV3 contextEngineID: %w", err)
}
cursor += count
if cursor > len(packet) {
return nil, 0, errors.New("error parsing SNMPV3: truncated packet")
}
if contextEngineID, ok := rawContextEngineID.(string); ok {
response.ContextEngineID = contextEngineID
x.Logger.Printf("Parsed contextEngineID %s", contextEngineID)
}
rawContextName, count, err := parseRawField(x.Logger, packet[cursor:], "contextName")
if err != nil {
return nil, 0, fmt.Errorf("error parsing SNMPV3 contextName: %w", err)
}
cursor += count
if cursor > len(packet) {
return nil, 0, errors.New("error parsing SNMPV3: truncated packet")
}
if contextName, ok := rawContextName.(string); ok {
response.ContextName = contextName
x.Logger.Printf("Parsed contextName %s", contextName)
}
default:
return nil, 0, errors.New("error parsing SNMPV3 scoped PDU")
}
// cursor now points at the PDU that follows contextName.
return packet, cursor, nil
}
// Copyright 2023 The GoSNMP Authors. All rights reserved. Use of this
// source code is governed by a BSD-style license that can be found in the
// LICENSE file.
package gosnmp
import (
"fmt"
"sync"
)
// SnmpV3SecurityParametersTable is a mapping of identifiers to corresponding SNMP V3 Security Model parameters.
// An identifier (for USM, the user name) may map to several parameter sets;
// UnmarshalTrap tries each of them when decoding a v3 trap. Add and Get
// synchronise on mu, so the table can be shared between goroutines.
type SnmpV3SecurityParametersTable struct {
// table maps an identifier to every parameter set registered for it.
table map[string][]SnmpV3SecurityParameters
// Logger is used by the table and given to added parameters that have none.
Logger Logger
mu sync.RWMutex
}
// NewSnmpV3SecurityParametersTable returns an empty security parameters table
// that logs through logger.
func NewSnmpV3SecurityParametersTable(logger Logger) *SnmpV3SecurityParametersTable {
	tbl := new(SnmpV3SecurityParametersTable)
	tbl.table = make(map[string][]SnmpV3SecurityParameters)
	tbl.Logger = logger
	return tbl
}
// Add registers sp under key (for USM, the user name). Several parameter sets
// may share a key. The keys of sp are initialised before it is stored, and sp
// inherits the table's Logger when it has none of its own.
//
// It returns an error when sp is nil or its keys cannot be initialised.
func (spm *SnmpV3SecurityParametersTable) Add(key string, sp SnmpV3SecurityParameters) error {
	if sp == nil {
		return fmt.Errorf("nil security parameters for key %s", key)
	}
	// Key initialisation may involve a costly password-to-key derivation. It
	// only touches sp, so do it before taking the table lock to avoid blocking
	// concurrent Get calls.
	if err := sp.InitSecurityKeys(); err != nil {
		return err
	}
	// If no logger is set for the security params (empty struct), use the one from the table
	if (Logger{}) == sp.getLogger() {
		sp.setLogger(spm.Logger)
	}

	spm.mu.Lock()
	defer spm.mu.Unlock()
	// Allow a zero-value table to be used without NewSnmpV3SecurityParametersTable.
	if spm.table == nil {
		spm.table = make(map[string][]SnmpV3SecurityParameters)
	}
	spm.table[key] = append(spm.table[key], sp)
	spm.Logger.Printf("Added security parameters %s for key: %s", sp.SafeString(), key)
	return nil
}
// Get returns the security parameters registered under key, or an error when
// nothing is registered for it. The returned slice is the one stored in the
// table, so callers should not modify it.
func (spm *SnmpV3SecurityParametersTable) Get(key string) ([]SnmpV3SecurityParameters, error) {
	spm.mu.RLock()
	params, found := spm.table[key]
	spm.mu.RUnlock()
	if !found {
		return nil, fmt.Errorf("no security parameters found for the key %s", key)
	}
	return params, nil
}
// Copyright 2012 The GoSNMP Authors. All rights reserved. Use of this
// source code is governed by a BSD-style license that can be found in the
// LICENSE file.
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gosnmp
import (
"bytes"
"crypto"
"crypto/aes"
"crypto/cipher"
"crypto/des"
"crypto/hmac"
"crypto/md5"
crand "crypto/rand"
"crypto/sha1"
_ "crypto/sha256" // Register hash function #4 (SHA224), #5 (SHA256)
_ "crypto/sha512" // Register hash function #6 (SHA384), #7 (SHA512)
"crypto/subtle"
"encoding/binary"
"encoding/hex"
"errors"
"fmt"
"hash"
"strings"
"sync"
"sync/atomic"
)
// SnmpV3AuthProtocol describes the authentication protocol in use by an authenticated SnmpV3 connection.
type SnmpV3AuthProtocol uint8
// NoAuth, MD5, SHA (SHA-1) and the SHA-2 variants SHA224, SHA256, SHA384 and
// SHA512 are implemented. macVarbinds is laid out in the same numeric order.
const (
NoAuth SnmpV3AuthProtocol = 1
MD5 SnmpV3AuthProtocol = 2
SHA SnmpV3AuthProtocol = 3
SHA224 SnmpV3AuthProtocol = 4
SHA256 SnmpV3AuthProtocol = 5
SHA384 SnmpV3AuthProtocol = 6
SHA512 SnmpV3AuthProtocol = 7
)
//go:generate stringer -type=SnmpV3AuthProtocol
// HashType maps the AuthProtocol's hash type to an actual crypto.Hash object.
// SHA selects SHA-1, and SHA224/SHA256/SHA384/SHA512 select the matching
// SHA-2 function. Every other value, including NoAuth and MD5, yields MD5.
func (authProtocol SnmpV3AuthProtocol) HashType() crypto.Hash {
	switch authProtocol {
	case SHA:
		return crypto.SHA1
	case SHA224:
		return crypto.SHA224
	case SHA256:
		return crypto.SHA256
	case SHA384:
		return crypto.SHA384
	case SHA512:
		return crypto.SHA512
	}
	return crypto.MD5
}
// macVarbinds holds, per SnmpV3AuthProtocol value, a zero-filled OCTET STRING
// TLV whose length is that protocol's truncated MAC size: 12 bytes for MD5 and
// SHA, then 16, 24, 32 and 48 bytes for SHA224..SHA512. Index 0 is a dummy
// entry so the slice can be indexed directly by protocol value.
//nolint:gochecknoglobals
var macVarbinds = [][]byte{
{}, // dummy
{byte(OctetString), 0}, // NoAuth
{byte(OctetString), 12, // MD5
0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0},
{byte(OctetString), 12, // SHA
0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0},
{byte(OctetString), 16, // SHA224
0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0},
{byte(OctetString), 24, // SHA256
0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0},
{byte(OctetString), 32, // SHA384
0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0},
{byte(OctetString), 48, // SHA512
0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0}}
// SnmpV3PrivProtocol is the privacy protocol in use by a private SnmpV3 connection.
type SnmpV3PrivProtocol uint8
// NoPriv, DES, AES (128-bit) and the 192/256-bit AES variants are implemented.
// AES192/AES256 follow the Blumenthal draft and AES192C/AES256C the Reeder
// draft, which differ in how the privacy key is extended.
const (
NoPriv SnmpV3PrivProtocol = 1
DES SnmpV3PrivProtocol = 2
AES SnmpV3PrivProtocol = 3
AES192 SnmpV3PrivProtocol = 4 // Blumenthal-AES192
AES256 SnmpV3PrivProtocol = 5 // Blumenthal-AES256
AES192C SnmpV3PrivProtocol = 6 // Reeder-AES192
AES256C SnmpV3PrivProtocol = 7 // Reeder-AES256
)
//go:generate stringer -type=SnmpV3PrivProtocol
// UsmSecurityParameters is an implementation of SnmpV3SecurityParameters for the UserSecurityModel.
// It embeds a sync.Mutex, so values must not be copied; use pointers and Copy.
type UsmSecurityParameters struct {
// mu guards the fields in the methods that lock it
// (e.g. Copy, Log, InitSecurityKeys, setSecurityParameters).
mu sync.Mutex
// localAESSalt must be 64bit aligned to use with atomic operations.
localAESSalt uint64
// localDESSalt is salt state for DES privacy (presumably updated like
// localAESSalt — TODO confirm).
localDESSalt uint32
// AuthoritativeEngineID is the authoritative engine ID (hex-encoded by
// Description). The localized keys are derived from it, so
// setSecurityParameters regenerates them when it changes.
AuthoritativeEngineID string
AuthoritativeEngineBoots uint32
AuthoritativeEngineTime uint32
// UserName is the USM user; it is also the table lookup key (getIdentifier).
UserName string
// AuthenticationParameters and PrivacyParameters hold the per-message
// authentication and privacy parameters (presumably as read from / written
// to the wire — verify).
AuthenticationParameters string
PrivacyParameters []byte
AuthenticationProtocol SnmpV3AuthProtocol
PrivacyProtocol SnmpV3PrivProtocol
// The passphrases from which SecretKey and PrivacyKey are derived when
// those keys are empty.
AuthenticationPassphrase string
PrivacyPassphrase string
// SecretKey and PrivacyKey are the localized authentication and privacy keys.
SecretKey []byte
PrivacyKey []byte
Logger Logger
}
// getIdentifier returns the key under which these parameters are looked up in
// a SnmpV3SecurityParametersTable: the USM user name.
func (sp *UsmSecurityParameters) getIdentifier() string {
	user := sp.UserName
	return user
}

// getLogger returns the Logger these parameters write to.
func (sp *UsmSecurityParameters) getLogger() Logger {
	current := sp.Logger
	return current
}

// setLogger replaces the Logger these parameters write to.
func (sp *UsmSecurityParameters) setLogger(log Logger) {
	sp.Logger = log
}
// Description returns a human-readable summary of these parameters. It lists
// the user name, the hex-encoded authoritative engine ID, the authentication
// and privacy protocols, and the authentication and privacy passphrases in
// clear text. Because it contains secrets, prefer SafeString for logging.
func (sp *UsmSecurityParameters) Description() string {
	var authPart string
	switch sp.AuthenticationProtocol {
	case NoAuth:
		authPart = ",auth=noauth"
	case MD5:
		authPart = ",auth=md5"
	case SHA:
		authPart = ",auth=sha"
	case SHA224:
		authPart = ",auth=sha224"
	case SHA256:
		authPart = ",auth=sha256"
	case SHA384:
		authPart = ",auth=sha384"
	case SHA512:
		authPart = ",auth=sha512"
	}

	var privPart string
	switch sp.PrivacyProtocol {
	case NoPriv:
		privPart = ",priv=NoPriv"
	case DES:
		privPart = ",priv=DES"
	case AES:
		privPart = ",priv=AES"
	case AES192:
		privPart = ",priv=AES192"
	case AES256:
		privPart = ",priv=AES256"
	case AES192C:
		privPart = ",priv=AES192C"
	case AES256C:
		privPart = ",priv=AES256C"
	}

	var out strings.Builder
	for _, piece := range []string{
		"user=", sp.UserName,
		",engine=(", hex.EncodeToString([]byte(sp.AuthoritativeEngineID)), ")",
		authPart,
		",authPass=", sp.AuthenticationPassphrase,
		privPart,
		",privPass=", sp.PrivacyPassphrase,
	} {
		out.WriteString(piece)
	}
	return out.String()
}
// SafeString returns a logging safe (no secrets) string of the
// UsmSecurityParameters: passphrases and derived keys are left out.
func (sp *UsmSecurityParameters) SafeString() string {
	const format = "AuthoritativeEngineID:%s, AuthoritativeEngineBoots:%d, AuthoritativeEngineTimes:%d, " +
		"UserName:%s, AuthenticationParameters:%s, PrivacyParameters:%v, " +
		"AuthenticationProtocol:%s, PrivacyProtocol:%s"
	return fmt.Sprintf(format,
		sp.AuthoritativeEngineID, sp.AuthoritativeEngineBoots, sp.AuthoritativeEngineTime,
		sp.UserName, sp.AuthenticationParameters, sp.PrivacyParameters,
		sp.AuthenticationProtocol, sp.PrivacyProtocol)
}
// Log writes the secret-free SafeString form of these parameters to their
// Logger, holding the parameters' mutex while doing so.
func (sp *UsmSecurityParameters) Log() {
	sp.mu.Lock()
	defer sp.mu.Unlock()
	summary := sp.SafeString()
	sp.Logger.Printf("SECURITY PARAMETERS:%s", summary)
}
// Copy method for UsmSecurityParameters used to copy a SnmpV3SecurityParameters without knowing its implementation.
//
// The byte slices (PrivacyParameters, SecretKey, PrivacyKey) are duplicated, so
// the copy never shares a backing array with sp; nil slices stay nil. The copy
// gets its own, unlocked mutex.
func (sp *UsmSecurityParameters) Copy() SnmpV3SecurityParameters {
	sp.mu.Lock()
	defer sp.mu.Unlock()

	// cloneBytes returns an independent copy of b, preserving nil.
	cloneBytes := func(b []byte) []byte {
		if b == nil {
			return nil
		}
		return append(make([]byte, 0, len(b)), b...)
	}

	return &UsmSecurityParameters{
		AuthoritativeEngineID:    sp.AuthoritativeEngineID,
		AuthoritativeEngineBoots: sp.AuthoritativeEngineBoots,
		AuthoritativeEngineTime:  sp.AuthoritativeEngineTime,
		UserName:                 sp.UserName,
		AuthenticationParameters: sp.AuthenticationParameters,
		PrivacyParameters:        cloneBytes(sp.PrivacyParameters),
		AuthenticationProtocol:   sp.AuthenticationProtocol,
		PrivacyProtocol:          sp.PrivacyProtocol,
		AuthenticationPassphrase: sp.AuthenticationPassphrase,
		PrivacyPassphrase:        sp.PrivacyPassphrase,
		SecretKey:                cloneBytes(sp.SecretKey),
		PrivacyKey:               cloneBytes(sp.PrivacyKey),
		// The AES salt is documented as being used with atomic operations,
		// so read both salts atomically rather than racing with such updates.
		localDESSalt: atomic.LoadUint32(&sp.localDESSalt),
		localAESSalt: atomic.LoadUint64(&sp.localAESSalt),
		Logger:       sp.Logger,
	}
}
// getDefaultContextEngineID returns the authoritative engine ID. It is used as
// the context engine ID when none has been configured.
func (sp *UsmSecurityParameters) getDefaultContextEngineID() string {
	engineID := sp.AuthoritativeEngineID
	return engineID
}
// InitSecurityKeys derives the localized authentication and privacy keys when
// they are not set yet. It holds the parameters' mutex while doing so.
func (sp *UsmSecurityParameters) InitSecurityKeys() error {
	sp.mu.Lock()
	defer sp.mu.Unlock()
	if err := sp.initSecurityKeysNoLock(); err != nil {
		return err
	}
	return nil
}
// initSecurityKeysNoLock derives missing localized keys from their passphrases
// and the authoritative engine ID. Keys that are already populated are left
// as-is. It takes no lock itself; InitSecurityKeys wraps it with sp.mu.
func (sp *UsmSecurityParameters) initSecurityKeysNoLock() error {
	var err error
	if sp.AuthenticationProtocol > NoAuth && len(sp.SecretKey) == 0 {
		sp.SecretKey, err = genlocalkey(sp.AuthenticationProtocol, sp.AuthenticationPassphrase, sp.AuthoritativeEngineID)
		if err != nil {
			return err
		}
	}
	if sp.PrivacyProtocol <= NoPriv || len(sp.PrivacyKey) != 0 {
		return nil
	}
	switch sp.PrivacyProtocol {
	case AES, AES192, AES256, AES192C, AES256C:
		// AES keys can be longer than the auth hash output (SHA1 yields 20
		// octets), so use the protocol-aware derivation, which may apply a
		// key-extension algorithm and truncates to the AES key size.
		sp.PrivacyKey, err = genlocalPrivKey(sp.PrivacyProtocol, sp.AuthenticationProtocol, sp.PrivacyPassphrase, sp.AuthoritativeEngineID)
	default:
		sp.PrivacyKey, err = genlocalkey(sp.AuthenticationProtocol, sp.PrivacyPassphrase, sp.AuthoritativeEngineID)
	}
	return err
}
// setSecurityParameters adopts the authoritative engine ID, boots and time from
// in (which must be a *UsmSecurityParameters). If the engine ID changes, the
// localized keys, which depend on it, are cleared and re-derived.
func (sp *UsmSecurityParameters) setSecurityParameters(in SnmpV3SecurityParameters) error {
	sp.mu.Lock()
	defer sp.mu.Unlock()
	src, err := castUsmSecParams(in)
	if err != nil {
		return err
	}
	if src.AuthoritativeEngineID != sp.AuthoritativeEngineID {
		sp.AuthoritativeEngineID = src.AuthoritativeEngineID
		sp.SecretKey, sp.PrivacyKey = nil, nil
		if err := sp.initSecurityKeysNoLock(); err != nil {
			return err
		}
	}
	sp.AuthoritativeEngineBoots, sp.AuthoritativeEngineTime = src.AuthoritativeEngineBoots, src.AuthoritativeEngineTime
	return nil
}
// validate checks that sp has what the security level encoded in flags needs:
// a privacy protocol for authPriv, an authentication protocol for authNoPriv
// and authPriv, and always a user name. It also checks that every configured
// protocol whose localized key was not supplied directly has a passphrase to
// derive it from.
func (sp *UsmSecurityParameters) validate(flags SnmpV3MsgFlags) error {
	level := flags & AuthPriv // only the auth/priv bits decide the level
	if level != NoAuthNoPriv && level != AuthNoPriv && level != AuthPriv {
		return fmt.Errorf("validate: MsgFlags must be populated with an appropriate security level")
	}
	if level == AuthPriv && sp.PrivacyProtocol <= NoPriv {
		return fmt.Errorf("securityParameters.PrivacyProtocol is required")
	}
	if level != NoAuthNoPriv && sp.AuthenticationProtocol <= NoAuth {
		return fmt.Errorf("securityParameters.AuthenticationProtocol is required")
	}
	if sp.UserName == "" {
		return fmt.Errorf("securityParameters.UserName is required")
	}

	needsPrivPassphrase := sp.PrivacyProtocol > NoPriv && len(sp.PrivacyKey) == 0
	if needsPrivPassphrase && sp.PrivacyPassphrase == "" {
		return fmt.Errorf("securityParameters.PrivacyPassphrase is required when a privacy protocol is specified")
	}
	needsAuthPassphrase := sp.AuthenticationProtocol > NoAuth && len(sp.SecretKey) == 0
	if needsAuthPassphrase && sp.AuthenticationPassphrase == "" {
		return fmt.Errorf("securityParameters.AuthenticationPassphrase is required when an authentication protocol is specified")
	}
	return nil
}
// init attaches log as the logger and seeds the salt counter with
// cryptographically random bytes: 8 bytes for the AES family, 4 for DES.
// Other privacy protocols get no salt.
func (sp *UsmSecurityParameters) init(log Logger) error {
	sp.Logger = log
	var seed []byte
	switch sp.PrivacyProtocol {
	case AES, AES192, AES256, AES192C, AES256C:
		seed = make([]byte, 8)
	case DES:
		seed = make([]byte, 4)
	default:
		return nil
	}
	if _, err := crand.Read(seed); err != nil {
		return fmt.Errorf("error creating a cryptographically secure salt: %w", err)
	}
	if sp.PrivacyProtocol == DES {
		sp.localDESSalt = binary.BigEndian.Uint32(seed)
	} else {
		sp.localAESSalt = binary.BigEndian.Uint64(seed)
	}
	return nil
}
// castUsmSecParams returns secParams as a non-nil *UsmSecurityParameters, or an
// error if it holds some other implementation or a nil pointer.
func castUsmSecParams(secParams SnmpV3SecurityParameters) (*UsmSecurityParameters, error) {
	if usm, ok := secParams.(*UsmSecurityParameters); ok && usm != nil {
		return usm, nil
	}
	return nil, fmt.Errorf("param SnmpV3SecurityParameters is not of type *UsmSecurityParameters")
}
var (
	// passwordKeyHashMutex guards passwordKeyHashCache.
	passwordKeyHashMutex sync.RWMutex //nolint:gochecknoglobals
	// passwordKeyHashCache memoizes password-to-key digests, keyed by the
	// string built in cacheKey. PasswordCaching(false) sets it to nil.
	passwordKeyHashCache = map[string][]byte{} //nolint:gochecknoglobals
	// passwordCacheDisable reports whether caching has been switched off via
	// PasswordCaching; the zero value means caching is on.
	passwordCacheDisable atomic.Bool //nolint:gochecknoglobals
)
// PasswordCaching is enabled by default for performance reason. If the cache was disabled then
// re-enabled, the cache is reset.
func PasswordCaching(enable bool) {
oldCacheEnable := !passwordCacheDisable.Load()
passwordKeyHashMutex.Lock()
if !enable { // if off
passwordKeyHashCache = nil
} else if !oldCacheEnable && enable { // if off then on
passwordKeyHashCache = make(map[string][]byte)
}
passwordCacheDisable.Store(!enable)
passwordKeyHashMutex.Unlock()
}
// hashPassword implements the password-to-key step of RFC 3414 A.2. It feeds
// exactly 1,048,576 octets of password, repeated and cut off at the end, into
// hash and returns the digest (Ku). hash is not reset first; the caller
// supplies a fresh one. An empty password is rejected with an error.
func hashPassword(hash hash.Hash, password string) ([]byte, error) {
	if len(password) == 0 {
		return []byte{}, errors.New("hashPassword: password is empty")
	}
	const expandedLen = 1048576 // 1 MiB, fixed by RFC 3414 A.2
	// One reusable 64-byte buffer instead of a fresh, append-grown slice for
	// each of the 16384 chunks.
	var chunk [64]byte
	pi := 0 // position within the endlessly repeated password
	for written := 0; written < expandedLen; written += len(chunk) {
		for j := range chunk {
			chunk[j] = password[pi%len(password)]
			pi++
		}
		if _, err := hash.Write(chunk[:]); err != nil {
			return []byte{}, err
		}
	}
	return hash.Sum(nil), nil
}
// cachedPasswordToKey returns hashPassword(hash, password), memoized in
// passwordKeyHashCache under cacheKey while caching is enabled. The cached
// slice is returned as-is, so callers must not modify it.
func cachedPasswordToKey(hash hash.Hash, cacheKey string, password string) ([]byte, error) {
	// cacheKey() yields "" while caching is disabled. Never read or store under
	// that key, even if caching has been re-enabled since, or unrelated
	// passwords would share one entry.
	useCache := cacheKey != "" && !passwordCacheDisable.Load()
	if useCache {
		passwordKeyHashMutex.RLock()
		value := passwordKeyHashCache[cacheKey]
		passwordKeyHashMutex.RUnlock()
		if value != nil {
			return value, nil
		}
	}
	hashed, err := hashPassword(hash, password)
	if err != nil {
		return nil, err
	}
	if useCache {
		passwordKeyHashMutex.Lock()
		// PasswordCaching(false) may have nil'd the map since the check above;
		// writing to a nil map would panic.
		if passwordKeyHashCache != nil {
			passwordKeyHashCache[cacheKey] = hashed
		}
		passwordKeyHashMutex.Unlock()
	}
	return hashed, nil
}
// hMAC derives a localized key as in RFC 3414 A.2. It first computes
// Ku = passwordToKey(password), cached under cacheKey, and then returns
// Kul = H(Ku || engineID || Ku) using the given hash. Errors from either step
// are returned to the caller.
func hMAC(hash crypto.Hash, cacheKey string, password string, engineID string) ([]byte, error) {
	hashed, err := cachedPasswordToKey(hash.New(), cacheKey, password)
	if err != nil {
		// Propagate instead of returning an empty key with a nil error.
		return []byte{}, err
	}
	local := hash.New()
	for _, part := range [][]byte{hashed, []byte(engineID), hashed} {
		if _, err := local.Write(part); err != nil {
			return []byte{}, err
		}
	}
	return local.Sum(nil), nil
}
// cacheKey builds the passwordKeyHashCache key for passphrase under
// authProtocol: a single protocol tag byte followed by the passphrase bytes.
// It returns "" while password caching is disabled.
func cacheKey(authProtocol SnmpV3AuthProtocol, passphrase string) string {
	if passwordCacheDisable.Load() {
		return ""
	}
	// Zero length with spare capacity, so the appends fill the buffer rather
	// than landing after 1+len(passphrase) zero bytes.
	key := make([]byte, 0, 1+len(passphrase))
	key = append(key, 'h'+byte(authProtocol))
	key = append(key, passphrase...)
	return string(key)
}
// extendKeyReeder derives a localized privacy key and lengthens it with the
// Reeder key-extension algorithm. The first localized key Kul is localized
// again as if it were a passphrase, and the result is appended:
// Kul || localize(Kul).
// https://tools.ietf.org/html/draft-reeder-snmpv3-usm-3desede
// Many vendors, Cisco included, use this 3DES-style extension for keys that are
// too short for AES, AES192 and AES256; net-snmp and pysnmp implement it as
// well. Tested for AES128 and AES256.
func extendKeyReeder(authProtocol SnmpV3AuthProtocol, password string, engineID string) ([]byte, error) {
	kul, err := hMAC(authProtocol.HashType(), cacheKey(authProtocol, password), password, engineID)
	if err != nil {
		return nil, err
	}
	kulAsPassphrase := string(kul)
	extension, err := hMAC(authProtocol.HashType(), cacheKey(authProtocol, kulAsPassphrase), kulAsPassphrase, engineID)
	return append(kul, extension...), err
}
// extendKeyBlumenthal derives a localized privacy key Kul and lengthens it
// with the Blumenthal key-extension algorithm: Kul || H(Kul).
// https://tools.ietf.org/html/draft-blumenthal-aes-usm-04#page-7
// Few vendors use this algorithm; net-snmp and pysnmp implement it.
// TODO: Not tested
func extendKeyBlumenthal(authProtocol SnmpV3AuthProtocol, password string, engineID string) ([]byte, error) {
	kul, err := hMAC(authProtocol.HashType(), cacheKey(authProtocol, password), password, engineID)
	if err != nil {
		return nil, err
	}
	digest := authProtocol.HashType().New()
	_, _ = digest.Write(kul) // hash.Hash.Write never returns an error
	return append(kul, digest.Sum(nil)...), nil
}
// genlocalPrivKey derives the localized privacy key for privProtocol and
// truncates it to that cipher's key size. AES and the "C" (Cisco-style)
// variants use Reeder key extension, AES192/AES256 use Blumenthal extension,
// and everything else uses plain localization. It returns an error if the
// derived key is shorter than the required size.
func genlocalPrivKey(privProtocol SnmpV3PrivProtocol, authProtocol SnmpV3AuthProtocol, password string, engineID string) ([]byte, error) {
	var keylen int
	var derive func(SnmpV3AuthProtocol, string, string) ([]byte, error)
	switch privProtocol {
	case AES:
		keylen, derive = 16, extendKeyReeder
	case DES:
		keylen, derive = 16, genlocalkey
	case AES192:
		keylen, derive = 24, extendKeyBlumenthal
	case AES192C:
		keylen, derive = 24, extendKeyReeder
	case AES256:
		keylen, derive = 32, extendKeyBlumenthal
	case AES256C:
		keylen, derive = 32, extendKeyReeder
	default:
		keylen, derive = 0, genlocalkey
	}
	localPrivKey, err := derive(authProtocol, password, engineID)
	if err != nil {
		return nil, err
	}
	if len(localPrivKey) < keylen {
		return []byte{}, fmt.Errorf("genlocalPrivKey: privProtocol: %v len(localPrivKey): %d, keylen: %d",
			privProtocol, len(localPrivKey), keylen)
	}
	return localPrivKey[:keylen], nil
}
// genlocalkey localizes passphrase to engineID using the hash of authProtocol
// (RFC 3414 A.2). On failure it returns an empty, non-nil slice with the error.
func genlocalkey(authProtocol SnmpV3AuthProtocol, passphrase string, engineID string) ([]byte, error) {
	key, err := hMAC(authProtocol.HashType(), cacheKey(authProtocol, passphrase), passphrase, engineID)
	if err != nil {
		return []byte{}, err
	}
	return key, nil
}
// usmAllocateNewSalt advances the salt counter and returns the new value. The
// AES family yields a uint64; every other privacy protocol yields a uint32 (the
// DES counter). The salt must change on every packet:
// http://tools.ietf.org/html/rfc2574#section-8.1.1.1
func (sp *UsmSecurityParameters) usmAllocateNewSalt() any {
	sp.mu.Lock()
	defer sp.mu.Unlock()
	switch sp.PrivacyProtocol {
	case AES, AES192, AES256, AES192C, AES256C:
		return atomic.AddUint64(&sp.localAESSalt, 1)
	default:
		return atomic.AddUint32(&sp.localDESSalt, 1)
	}
}
// usmSetSalt encodes newSalt into the 8-byte msgPrivacyParameters field.
// For the AES family newSalt must be a uint64 and is written big-endian. For
// other protocols it must be a uint32 and is preceded by the engine boots,
// also big-endian. A salt of the wrong type is reported as an error.
func (sp *UsmSecurityParameters) usmSetSalt(newSalt any) error {
	sp.mu.Lock()
	defer sp.mu.Unlock()
	params := make([]byte, 8)
	switch sp.PrivacyProtocol {
	case AES, AES192, AES256, AES192C, AES256C:
		aesSalt, ok := newSalt.(uint64)
		if !ok {
			return fmt.Errorf("salt provided to usmSetSalt is not the correct type for the AES privacy protocol")
		}
		binary.BigEndian.PutUint64(params, aesSalt)
	default:
		desSalt, ok := newSalt.(uint32)
		if !ok {
			return fmt.Errorf("salt provided to usmSetSalt is not the correct type for the DES privacy protocol")
		}
		binary.BigEndian.PutUint32(params[:4], sp.AuthoritativeEngineBoots)
		binary.BigEndian.PutUint32(params[4:], desSalt)
	}
	sp.PrivacyParameters = params
	return nil
}
// InitPacket advances the encryption salt and, for packets whose flags
// request privacy, stores it in the packet's security parameters.
// The salt counter moves on every call, even for unencrypted packets:
// http://tools.ietf.org/html/rfc2574#section-8.1.1.1
func (sp *UsmSecurityParameters) InitPacket(packet *SnmpPacket) error {
	newSalt := sp.usmAllocateNewSalt()
	if packet.MsgFlags&AuthPriv <= AuthNoPriv {
		return nil
	}
	target, err := castUsmSecParams(packet.SecurityParameters)
	if err != nil {
		return err
	}
	return target.usmSetSalt(newSalt)
}
// discoveryRequired returns a blank discovery packet when the authoritative
// engine ID is still unknown, and nil otherwise. The packet is a reportable,
// noAuthNoPriv GetRequest with no varbinds, sent to learn the engine
// ID, boots and time.
func (sp *UsmSecurityParameters) discoveryRequired() *SnmpPacket {
	if sp.AuthoritativeEngineID != "" {
		return nil
	}
	var noVarbinds []SnmpPDU
	return &SnmpPacket{
		Version:            Version3,
		MsgFlags:           Reportable | NoAuthNoPriv,
		SecurityModel:      UserSecurityModel,
		SecurityParameters: &UsmSecurityParameters{Logger: sp.Logger},
		PDUType:            GetRequest,
		Logger:             sp.Logger,
		Variables:          noVarbinds,
	}
}
// calcPacketDigest computes the truncated authentication digest of packet
// using this receiver's authentication protocol and secret key.
func (sp *UsmSecurityParameters) calcPacketDigest(packet []byte) ([]byte, error) {
	digest, err := calcPacketDigest(packet, sp)
	return digest, err
}
// calcPacketDigest computes the authentication digest of packetBytes with the
// protocol and SecretKey in secParams. The result is truncated to the MAC
// length carried in msgAuthenticationParameters: the macVarbinds entry minus
// its 2-byte OCTET STRING header. It is used to sign outgoing packets
// (authenticate) and to verify incoming ones, e.g. TRAP or INFORM
// (isAuthentic).
// Supports MD5, SHA1, SHA224, SHA256, SHA384 and SHA512. A protocol with no
// macVarbinds entry, or a digest shorter than the expected MAC, yields an
// error instead of a slice-bounds panic.
func calcPacketDigest(packetBytes []byte, secParams *UsmSecurityParameters) ([]byte, error) {
	var digest []byte
	var err error
	switch secParams.AuthenticationProtocol {
	case MD5, SHA:
		digest, err = digestRFC3414(
			secParams.AuthenticationProtocol,
			packetBytes,
			secParams.SecretKey)
	case SHA224, SHA256, SHA384, SHA512:
		digest, err = digestRFC7860(
			secParams.AuthenticationProtocol,
			packetBytes,
			secParams.SecretKey)
	}
	if err != nil {
		return nil, err
	}
	if int(secParams.AuthenticationProtocol) >= len(macVarbinds) {
		return nil, fmt.Errorf("calcPacketDigest: unsupported authentication protocol %v", secParams.AuthenticationProtocol)
	}
	macLen := len(macVarbinds[secParams.AuthenticationProtocol]) - 2
	if macLen < 0 || macLen > len(digest) {
		return nil, fmt.Errorf("calcPacketDigest: digest of %d bytes cannot fill a %d-byte MAC for protocol %v",
			len(digest), macLen, secParams.AuthenticationProtocol)
	}
	return digest[:macLen], nil
}
// digestRFC7860 returns the full (untruncated) HMAC-SHA-2 digest of packet
// keyed with authKey, using the hash of protocol h, as in RFC 7860 4.2.2.
func digestRFC7860(h SnmpV3AuthProtocol, packet []byte, authKey []byte) ([]byte, error) {
	mac := hmac.New(h.HashType().New, authKey)
	if _, err := mac.Write(packet); err != nil {
		return []byte{}, err
	}
	return mac.Sum(nil), nil
}
// digestRFC3414 computes the HMAC-MD5-96 or HMAC-SHA-96 digest of packet, as
// in RFC 3414 6.3.2 and 7.3.2. authKey is zero-extended to 64 bytes and XORed
// with the ipad (0x36) and opad (0x5c) constants, and the outer hash is
// truncated to 12 bytes. Any protocol other than MD5 or SHA returns an error
// instead of calling through a nil hash.
// NOTE: a key longer than 64 bytes is truncated by the copy rather than hashed
// as in RFC 2104; localized MD5/SHA-1 keys are 16/20 bytes.
func digestRFC3414(h SnmpV3AuthProtocol, packet []byte, authKey []byte) ([]byte, error) {
	var extkey [64]byte
	var err error
	var k1, k2 [64]byte
	var h1, h2 hash.Hash
	copy(extkey[:], authKey)
	switch h {
	case MD5:
		h1 = md5.New() //nolint:gosec
		h2 = md5.New() //nolint:gosec
	case SHA:
		h1 = sha1.New() //nolint:gosec
		h2 = sha1.New() //nolint:gosec
	default:
		return []byte{}, fmt.Errorf("digestRFC3414: unsupported authentication protocol %v", h)
	}
	for i := range 64 {
		k1[i] = extkey[i] ^ 0x36 //nolint:gosec
		k2[i] = extkey[i] ^ 0x5c //nolint:gosec
	}
	// Inner hash: H(K XOR ipad || packet).
	_, err = h1.Write(k1[:])
	if err != nil {
		return []byte{}, err
	}
	_, err = h1.Write(packet)
	if err != nil {
		return []byte{}, err
	}
	d1 := h1.Sum(nil)
	// Outer hash: H(K XOR opad || inner).
	_, err = h2.Write(k2[:])
	if err != nil {
		return []byte{}, err
	}
	_, err = h2.Write(d1)
	if err != nil {
		return []byte{}, err
	}
	return h2.Sum(nil)[:12], nil
}
// authenticate computes the packet digest and writes it, in place, into the
// msgAuthenticationParameters placeholder (the macVarbinds entry for the
// configured protocol) found in packet.
func (sp *UsmSecurityParameters) authenticate(packet []byte) error {
	digest, err := sp.calcPacketDigest(packet)
	if err != nil {
		return err
	}
	placeholder := macVarbinds[sp.AuthenticationProtocol]
	pos := bytes.Index(packet, placeholder)
	if pos < 0 {
		return fmt.Errorf("unable to locate the position in packet to write authentication key")
	}
	// Skip the 2-byte OCTET STRING header and fill the value bytes.
	copy(packet[pos+2:pos+len(placeholder)], digest)
	return nil
}
// isAuthentic reports whether packet, parsed from packetBytes, belongs to the
// configured user and carries a valid digest. A user-name mismatch yields
// (false, nil).
func (sp *UsmSecurityParameters) isAuthentic(packetBytes []byte, packet *SnmpPacket) (bool, error) {
	received, err := castUsmSecParams(packet.SecurityParameters)
	if err != nil {
		return false, err
	}
	if received.UserName != sp.UserName {
		return false, nil
	}
	// TODO: investigate call chain to determine if this is really the best spot for this
	expected, err := calcPacketDigest(packetBytes, received)
	if err != nil {
		return false, err
	}
	// Constant-time comparison so timing does not reveal how many bytes matched.
	signature := []byte(received.AuthenticationParameters)
	matches := subtle.ConstantTimeCompare(expected, signature) == 1
	return matches, nil
}
// encryptPacket encrypts the serialized scopedPDU with the configured privacy
// protocol and returns it wrapped in an OCTET STRING header. For any other
// PrivacyProtocol value (e.g. NoPriv), scopedPdu is returned unchanged.
//
// AES variants use CFB-128 with IV = engineBoots || engineTime || salt
// (RFC 3826 3.1.2.1), with the full PrivacyKey as the AES key. DES uses CBC
// with IV = pre-IV (PrivacyKey[8:16]) XOR salt (RFC 3414 8.1.1.1) and
// zero-pads the plaintext; it always adds 1..8 pad bytes, a whole block when
// the input is already aligned.
func (sp *UsmSecurityParameters) encryptPacket(scopedPdu []byte) ([]byte, error) {
	var ciphertext []byte
	switch sp.PrivacyProtocol {
	case AES, AES192, AES256, AES192C, AES256C:
		var iv [16]byte
		binary.BigEndian.PutUint32(iv[:], sp.AuthoritativeEngineBoots)
		binary.BigEndian.PutUint32(iv[4:], sp.AuthoritativeEngineTime)
		copy(iv[8:], sp.PrivacyParameters)
		// The whole PrivacyKey is used, so its length selects AES-128/192/256.
		block, err := aes.NewCipher(sp.PrivacyKey)
		if err != nil {
			return nil, err
		}
		//nolint:staticcheck // RFC3826 Section 3.1.1.1 specifies CFB-128 mode for AES
		stream := cipher.NewCFBEncrypter(block, iv[:])
		ciphertext = make([]byte, len(scopedPdu))
		stream.XORKeyStream(ciphertext, scopedPdu)
	case DES:
		// The fixed-offset slicing below would panic on short inputs.
		if len(sp.PrivacyKey) < 16 {
			return nil, fmt.Errorf("error encrypting ScopedPDU: DES privacy key is %d bytes, need at least 16", len(sp.PrivacyKey))
		}
		if len(sp.PrivacyParameters) < 8 {
			return nil, fmt.Errorf("error encrypting ScopedPDU: DES privacy parameters are %d bytes, need 8", len(sp.PrivacyParameters))
		}
		preiv := sp.PrivacyKey[8:]
		var iv [8]byte
		for i := range len(iv) {
			iv[i] = preiv[i] ^ sp.PrivacyParameters[i] //nolint:gosec
		}
		block, err := des.NewCipher(sp.PrivacyKey[:8]) //nolint:gosec
		if err != nil {
			return nil, err
		}
		// Pad into a fresh buffer so the caller's slice (and any spare capacity
		// behind it) is never written to.
		padLen := des.BlockSize - len(scopedPdu)%des.BlockSize
		plaintext := make([]byte, len(scopedPdu)+padLen)
		copy(plaintext, scopedPdu)
		ciphertext = make([]byte, len(plaintext))
		cipher.NewCBCEncrypter(block, iv[:]).CryptBlocks(ciphertext, plaintext)
	default:
		return scopedPdu, nil
	}
	pduLen, err := marshalLength(len(ciphertext))
	if err != nil {
		return nil, err
	}
	out := make([]byte, 0, 1+len(pduLen)+len(ciphertext))
	out = append(out, byte(OctetString))
	out = append(out, pduLen...)
	return append(out, ciphertext...), nil
}
// decryptPacket decrypts, in place, the encryptedPDU OCTET STRING that starts
// at packet[cursor]. The plaintext scopedPDU is copied over the OCTET STRING
// header, and the returned slice is truncated to end right after it. For other
// PrivacyProtocol values, packet is returned unchanged.
//
// AES variants use CFB-128 with IV = engineBoots || engineTime || salt
// (RFC 3826 3.1.2.1). DES uses CBC with IV = PrivacyKey[8:16] XOR salt
// (RFC 3414 8.1.1.1).
func (sp *UsmSecurityParameters) decryptPacket(packet []byte, cursor int) ([]byte, error) {
	if cursor < 0 || cursor > len(packet) {
		return nil, errors.New("error decrypting ScopedPDU: truncated packet")
	}
	_, cursorTmp, err := parseLength(packet[cursor:])
	if err != nil {
		return nil, err
	}
	cursorTmp += cursor
	if cursorTmp > len(packet) {
		return nil, errors.New("error decrypting ScopedPDU: truncated packet")
	}
	ciphertext := packet[cursorTmp:]
	switch sp.PrivacyProtocol {
	case AES, AES192, AES256, AES192C, AES256C:
		var iv [16]byte
		binary.BigEndian.PutUint32(iv[:], sp.AuthoritativeEngineBoots)
		binary.BigEndian.PutUint32(iv[4:], sp.AuthoritativeEngineTime)
		copy(iv[8:], sp.PrivacyParameters)
		block, err := aes.NewCipher(sp.PrivacyKey)
		if err != nil {
			return nil, err
		}
		//nolint:staticcheck // RFC3826 Section 3.1.1.1 specifies CFB-128 mode for AES
		stream := cipher.NewCFBDecrypter(block, iv[:])
		plaintext := make([]byte, len(ciphertext))
		stream.XORKeyStream(plaintext, ciphertext)
		copy(packet[cursor:], plaintext)
		packet = packet[:cursor+len(plaintext)]
	case DES:
		if len(ciphertext)%des.BlockSize != 0 {
			return nil, errors.New("error decrypting ScopedPDU: not multiple of des block size")
		}
		// PrivacyParameters may come straight from the received message (see
		// unmarshal). RFC 3414 8.3.2 step 1 requires 8 octets, and a shorter
		// value would make the IV loop below index out of range.
		if len(sp.PrivacyParameters) < 8 {
			return nil, fmt.Errorf("error decrypting ScopedPDU: msgPrivacyParameters is %d bytes, need 8", len(sp.PrivacyParameters))
		}
		if len(sp.PrivacyKey) < 16 {
			return nil, fmt.Errorf("error decrypting ScopedPDU: DES privacy key is %d bytes, need at least 16", len(sp.PrivacyKey))
		}
		preiv := sp.PrivacyKey[8:]
		var iv [8]byte
		for i := range len(iv) {
			iv[i] = preiv[i] ^ sp.PrivacyParameters[i] //nolint:gosec
		}
		block, err := des.NewCipher(sp.PrivacyKey[:8]) //nolint:gosec
		if err != nil {
			return nil, err
		}
		mode := cipher.NewCBCDecrypter(block, iv[:])
		plaintext := make([]byte, len(ciphertext))
		mode.CryptBlocks(plaintext, ciphertext)
		copy(packet[cursor:], plaintext)
		// Truncate to drop the extra space left by the OCTET STRING header that
		// the plaintext was just shifted over.
		packet = packet[:cursor+len(plaintext)]
	}
	return packet, nil
}
// marshal encodes the USM security parameters SEQUENCE (RFC 3414 section 2.4)
// for an outgoing message. msgAuthenticationParameters is the macVarbinds
// placeholder when flags request authentication, and an empty OCTET STRING
// otherwise. msgPrivacyParameters carries sp.PrivacyParameters only when flags
// request privacy.
func (sp *UsmSecurityParameters) marshal(flags SnmpV3MsgFlags) ([]byte, error) {
	var buf bytes.Buffer
	var err error

	// writeOctetString appends an OCTET STRING TLV with a BER length from
	// marshalLength. A single raw length byte would mis-encode values of 128
	// bytes or more; the engine ID in particular may come from a remote peer.
	writeOctetString := func(value []byte) error {
		valueLen, lenErr := marshalLength(len(value))
		if lenErr != nil {
			return lenErr
		}
		buf.WriteByte(byte(OctetString))
		buf.Write(valueLen)
		buf.Write(value)
		return nil
	}

	// msgAuthoritativeEngineID
	if err = writeOctetString([]byte(sp.AuthoritativeEngineID)); err != nil {
		return nil, err
	}
	// msgAuthoritativeEngineBoots
	msgAuthoritativeEngineBoots, err := marshalUint32(sp.AuthoritativeEngineBoots)
	if err != nil {
		return nil, err
	}
	buf.Write([]byte{byte(Integer), byte(len(msgAuthoritativeEngineBoots))})
	buf.Write(msgAuthoritativeEngineBoots)
	// msgAuthoritativeEngineTime
	msgAuthoritativeEngineTime, err := marshalUint32(sp.AuthoritativeEngineTime)
	if err != nil {
		return nil, err
	}
	buf.Write([]byte{byte(Integer), byte(len(msgAuthoritativeEngineTime))})
	buf.Write(msgAuthoritativeEngineTime)
	// msgUserName
	if err = writeOctetString([]byte(sp.UserName)); err != nil {
		return nil, err
	}
	// msgAuthenticationParameters: a placeholder that authenticate later
	// overwrites with the real digest.
	if flags&AuthNoPriv > 0 {
		buf.Write(macVarbinds[sp.AuthenticationProtocol])
	} else {
		buf.Write([]byte{byte(OctetString), 0})
	}
	// msgPrivacyParameters
	if flags&AuthPriv > AuthNoPriv {
		if err = writeOctetString(sp.PrivacyParameters); err != nil {
			return nil, err
		}
	} else {
		buf.Write([]byte{byte(OctetString), 0})
	}
	// wrap security parameters in a sequence
	paramLen, err := marshalLength(buf.Len())
	if err != nil {
		return nil, err
	}
	tmpseq := append([]byte{byte(Sequence)}, paramLen...)
	tmpseq = append(tmpseq, buf.Bytes()...)
	return tmpseq, nil
}
// unmarshal parses the USM security parameters SEQUENCE starting at
// packet[cursor] (RFC 3414 section 2.4), stores each field in sp, and returns
// the offset just past msgPrivacyParameters.
//
// If msgAuthoritativeEngineID differs from the current value, the localized
// keys are cleared and re-derived via initSecurityKeysNoLock.
// NOTE(review): callers presumably hold sp.mu; confirm.
// When flags request authentication, the received msgAuthenticationParameters
// value is overwritten in place with the macVarbinds placeholder, so the digest
// can later be recomputed over packet.
func (sp *UsmSecurityParameters) unmarshal(flags SnmpV3MsgFlags, packet []byte, cursor int) (int, error) {
	if cursor >= len(packet) {
		return 0, errors.New("error parsing SNMPV3 User Security Model parameters: end of packet")
	}
	if PDUType(packet[cursor]) != Sequence {
		return 0, errors.New("error parsing SNMPV3 User Security Model parameters")
	}
	_, headerLen, err := parseLength(packet[cursor:])
	if err != nil {
		return 0, err
	}
	cursor += headerLen
	// Bounds-check the advanced cursor (not just the header length) before
	// slicing packet[cursor:] below.
	if cursor > len(packet) {
		return 0, errors.New("error parsing SNMPV3 User Security Model parameters: truncated packet")
	}
	rawMsgAuthoritativeEngineID, count, err := parseRawField(sp.Logger, packet[cursor:], "msgAuthoritativeEngineID")
	if err != nil {
		return 0, fmt.Errorf("error parsing SNMPV3 User Security Model msgAuthoritativeEngineID: %w", err)
	}
	cursor += count
	if AuthoritativeEngineID, ok := rawMsgAuthoritativeEngineID.(string); ok {
		if sp.AuthoritativeEngineID != AuthoritativeEngineID {
			// Localized keys depend on the engine ID; drop and re-derive them.
			sp.AuthoritativeEngineID = AuthoritativeEngineID
			sp.SecretKey = nil
			sp.PrivacyKey = nil
			sp.Logger.Printf("Parsed authoritativeEngineID %0x", []byte(AuthoritativeEngineID))
			err = sp.initSecurityKeysNoLock()
			if err != nil {
				return 0, err
			}
		}
	}
	rawMsgAuthoritativeEngineBoots, count, err := parseRawField(sp.Logger, packet[cursor:], "msgAuthoritativeEngineBoots")
	if err != nil {
		return 0, fmt.Errorf("error parsing SNMPV3 User Security Model msgAuthoritativeEngineBoots: %w", err)
	}
	cursor += count
	if AuthoritativeEngineBoots, ok := rawMsgAuthoritativeEngineBoots.(int); ok {
		sp.AuthoritativeEngineBoots = uint32(AuthoritativeEngineBoots) //nolint:gosec
		sp.Logger.Printf("Parsed authoritativeEngineBoots %d", AuthoritativeEngineBoots)
	}
	rawMsgAuthoritativeEngineTime, count, err := parseRawField(sp.Logger, packet[cursor:], "msgAuthoritativeEngineTime")
	if err != nil {
		return 0, fmt.Errorf("error parsing SNMPV3 User Security Model msgAuthoritativeEngineTime: %w", err)
	}
	cursor += count
	if AuthoritativeEngineTime, ok := rawMsgAuthoritativeEngineTime.(int); ok {
		sp.AuthoritativeEngineTime = uint32(AuthoritativeEngineTime) //nolint:gosec
		sp.Logger.Printf("Parsed authoritativeEngineTime %d", AuthoritativeEngineTime)
	}
	rawMsgUserName, count, err := parseRawField(sp.Logger, packet[cursor:], "msgUserName")
	if err != nil {
		return 0, fmt.Errorf("error parsing SNMPV3 User Security Model msgUserName: %w", err)
	}
	cursor += count
	if msgUserName, ok := rawMsgUserName.(string); ok {
		sp.UserName = msgUserName
		sp.Logger.Printf("Parsed userName %s", msgUserName)
	}
	rawMsgAuthParameters, count, err := parseRawField(sp.Logger, packet[cursor:], "msgAuthenticationParameters")
	if err != nil {
		return 0, fmt.Errorf("error parsing SNMPV3 User Security Model msgAuthenticationParameters: %w", err)
	}
	if msgAuthenticationParameters, ok := rawMsgAuthParameters.(string); ok {
		sp.AuthenticationParameters = msgAuthenticationParameters
		sp.Logger.Printf("Parsed authenticationParameters %s", msgAuthenticationParameters)
	}
	// blank msgAuthenticationParameters to prepare for authentication check later
	if flags&AuthNoPriv > 0 {
		// In case if the authentication protocol is not configured or set to NoAuth, then the packet cannot
		// be processed further
		if sp.AuthenticationProtocol <= NoAuth {
			return 0, errors.New("error parsing SNMPv3 User Security Model: authentication parameters are not configured to parse incoming authenticated message")
		}
		if int(sp.AuthenticationProtocol) >= len(macVarbinds) {
			return 0, fmt.Errorf("error parsing SNMPv3 User Security Model: unsupported authentication protocol %v", sp.AuthenticationProtocol)
		}
		placeholder := macVarbinds[sp.AuthenticationProtocol]
		// The received field must be exactly the expected MAC size (RFC 3414
		// 6.3.2/7.3.2 step 1). Otherwise the copy below would overwrite the
		// following fields or run past the end of packet.
		if count != len(placeholder) || cursor+len(placeholder) > len(packet) {
			return 0, fmt.Errorf("error parsing SNMPv3 User Security Model: msgAuthenticationParameters is %d bytes, expected %d",
				count, len(placeholder))
		}
		copy(packet[cursor+2:cursor+len(placeholder)], placeholder[2:])
	}
	cursor += count
	rawMsgPrivacyParameters, count, err := parseRawField(sp.Logger, packet[cursor:], "msgPrivacyParameters")
	if err != nil {
		return 0, fmt.Errorf("error parsing SNMPV3 User Security Model msgPrivacyParameters: %w", err)
	}
	cursor += count
	if msgPrivacyParameters, ok := rawMsgPrivacyParameters.(string); ok {
		sp.PrivacyParameters = []byte(msgPrivacyParameters)
		sp.Logger.Printf("Parsed privacyParameters %s", msgPrivacyParameters)
		if flags&AuthPriv >= AuthPriv {
			if sp.PrivacyProtocol <= NoPriv {
				return 0, errors.New("error parsing SNMPv3 User Security Model: privacy parameters are not configured to parse incoming encrypted message")
			}
		}
	}
	return cursor, nil
}
// Copyright 2012 The GoSNMP Authors. All rights reserved. Use of this
// source code is governed by a BSD-style license that can be found in the
// LICENSE file.
package gosnmp
import (
"fmt"
"strings"
)
// walk traverses the OID subtree rooted at rootOid. It repeatedly issues
// requests of getRequestType (GetBulkRequest, GetNextRequest or GetRequest) and
// calls walkFn for every returned varbind whose name lies under rootOid.
//
// Each new request starts from the name of the last varbind in the previous
// response. The walk ends and returns nil when any of these occurs:
//   - a response has no varbinds;
//   - the response error status is one of those listed in the switch below;
//   - a varbind has type EndOfMibView, NoSuchObject or NoSuchInstance;
//   - a varbind falls outside rootOid.
// Request errors, walkFn errors and the "OID not increasing" error are
// returned to the caller.
func (x *GoSNMP) walk(getRequestType PDUType, rootOid string, walkFn WalkFunc) error {
// If no rootOid is provided, fall back to the 'internet' subtree (.1.3.6.1).
// This ensures visibility of both standard (e.g. MIB-2) and vendor-specific branches.
// It also guarantees the OID is valid for BER encoding:
// - RFC 2578 §7.1.3: OIDs must have at least two sub-identifiers
// - X.690 §8.19: the first two arcs are encoded as (40 * arc1 + arc2)
if rootOid == "" || rootOid == "." {
// IANA 'internet' subtree under ISO OID structure per X.660.
// See https://oidref.com/1.3.6.1
rootOid = ".1.3.6.1"
}
// Normalize to a leading "." so the HasPrefix(rootOid+".") test below can match.
// NOTE(review): presumably returned pdu.Name values are also dot-prefixed; confirm.
if !strings.HasPrefix(rootOid, ".") {
rootOid = string(".") + rootOid
}
oid := rootOid
requests := 0
// Zero MaxRepetitions falls back to defaultMaxRepetitions (used only by GetBulk).
maxReps := x.MaxRepetitions
if maxReps == 0 {
maxReps = defaultMaxRepetitions
}
// AppOpt 'c: do not check returned OIDs are increasing'
// Honored only for GetBulk and GetNext walks.
checkIncreasing := true
if x.AppOpts != nil {
if _, ok := x.AppOpts["c"]; ok {
if getRequestType == GetBulkRequest || getRequestType == GetNextRequest {
checkIncreasing = false
}
}
}
RequestLoop:
for {
requests++
var response *SnmpPacket
var err error
// getRequestType may have been switched to GetRequest further below
// (the leaf-OID fallback on the first response).
switch getRequestType {
case GetBulkRequest:
response, err = x.GetBulk([]string{oid}, 0, maxReps)
case GetNextRequest:
response, err = x.GetNext([]string{oid})
case GetRequest:
response, err = x.Get([]string{oid})
default:
response, err = nil, fmt.Errorf("unsupported request type: %d", getRequestType)
}
if err != nil {
return err
}
if len(response.Variables) == 0 {
break RequestLoop
}
// The statuses listed here end the walk quietly: they are logged and nil is
// returned. Statuses not listed fall through, and the varbinds are still
// processed.
// NOTE(review): "Walk completed with NoError" is logged for every successful
// response, not only when the walk finishes.
switch response.Error {
case TooBig:
x.Logger.Print("Walk terminated with TooBig")
break RequestLoop
case NoSuchName:
x.Logger.Print("Walk terminated with NoSuchName")
break RequestLoop
case BadValue:
x.Logger.Print("Walk terminated with BadValue")
break RequestLoop
case ReadOnly:
x.Logger.Print("Walk terminated with ReadOnly")
break RequestLoop
case GenErr:
x.Logger.Print("Walk terminated with GenErr")
break RequestLoop
case NoAccess:
x.Logger.Print("Walk terminated with NoAccess")
break RequestLoop
case WrongType:
x.Logger.Print("Walk terminated with WrongType")
break RequestLoop
case WrongLength:
x.Logger.Print("Walk terminated with WrongLength")
break RequestLoop
case WrongEncoding:
x.Logger.Print("Walk terminated with WrongEncoding")
break RequestLoop
case WrongValue:
x.Logger.Print("Walk terminated with WrongValue")
break RequestLoop
case NoCreation:
x.Logger.Print("Walk terminated with NoCreation")
break RequestLoop
case InconsistentValue:
x.Logger.Print("Walk terminated with InconsistentValue")
break RequestLoop
case ResourceUnavailable:
x.Logger.Print("Walk terminated with ResourceUnavailable")
break RequestLoop
case CommitFailed:
x.Logger.Print("Walk terminated with CommitFailed")
break RequestLoop
case UndoFailed:
x.Logger.Print("Walk terminated with UndoFailed")
break RequestLoop
case AuthorizationError:
x.Logger.Print("Walk terminated with AuthorizationError")
break RequestLoop
case NotWritable:
x.Logger.Print("Walk terminated with NotWritable")
break RequestLoop
case InconsistentName:
x.Logger.Print("Walk terminated with InconsistentName")
break RequestLoop
case NoError:
x.Logger.Print("Walk completed with NoError")
}
for i, pdu := range response.Variables {
// Exception varbinds end the walk; the log says "BulkWalk" whatever the
// request type.
if pdu.Type == EndOfMibView || pdu.Type == NoSuchObject || pdu.Type == NoSuchInstance {
x.Logger.Printf("BulkWalk terminated with type 0x%x", pdu.Type)
break RequestLoop
}
if !strings.HasPrefix(pdu.Name, rootOid+".") {
// Not in the requested root range.
// if this is the first request, and the first variable in that request
// and this condition is triggered - the first result is out of range
// need to perform a regular get request
// this request has been too narrowly defined to be found with a getNext
// Issue #78 #93
if requests == 1 && i == 0 {
getRequestType = GetRequest
continue RequestLoop
} else if pdu.Name == rootOid && pdu.Type != NoSuchInstance {
// Call walk function if the pdu instance is found
// considering that the rootOid is a leafOid
// NOTE(review): NoSuchInstance varbinds already ended the loop above,
// so the Type check here is always true.
if err := walkFn(pdu); err != nil {
return err
}
}
break RequestLoop
}
// This only rejects a varbind whose name equals the OID just requested;
// it does not detect an OID that moves backwards.
if checkIncreasing && pdu.Name == oid {
return fmt.Errorf("OID not increasing: %s", pdu.Name)
}
// Report our pdu
if err := walkFn(pdu); err != nil {
return err
}
}
// Save last oid for next request
oid = response.Variables[len(response.Variables)-1].Name
}
x.Logger.Printf("BulkWalk completed in %d requests", requests)
return nil
}
// walkAll runs walk with the given request type and root OID and collects every
// reported varbind into a slice. Any varbinds gathered before an error are
// still returned alongside it.
func (x *GoSNMP) walkAll(getRequestType PDUType, rootOid string) (results []SnmpPDU, err error) {
	collect := func(pdu SnmpPDU) error {
		results = append(results, pdu)
		return nil
	}
	err = x.walk(getRequestType, rootOid, collect)
	return results, err
}