// Package mgodbc is a database/sql driver for ODBC data sources, built on the
// ODBC C API through cgo. Its init function registers the driver under the
// name "mgodbc" when an ODBC environment handle can be allocated.
package mgodbc

/*
#cgo darwin LDFLAGS: -lodbc
#cgo freebsd LDFLAGS: -lodbc
#cgo linux LDFLAGS: -lodbc
#cgo windows LDFLAGS: -lodbc32

// NOTE(review): the header names in this preamble were lost when the file was
// flattened onto a single line; these are the headers the code relies on
// (malloc/free, the core ODBC API and the wide-character "W" entry points).
// Confirm against the original source.
#ifdef __MINGW32__
#include <windows.h>
#endif
#include <stdlib.h>
#include <sql.h>
#include <sqlext.h>
#include <sqlucode.h>
*/
import "C"

import (
	"database/sql"
	"database/sql/driver"
	"errors"
	"fmt"
	"io"
	"strings"
	"time"
	"unicode/utf16"
	"unsafe"
)

// init

// init registers the driver with database/sql as "mgodbc". Registration is
// skipped silently if the ODBC environment cannot be set up.
func init() {
	d, err := NewDriver()
	if err == nil && d != nil {
		sql.Register("mgodbc", d)
	}
}

// RegisterMGODBC does nothing; the driver is registered by the package's init
// function. It presumably exists so programs can reference the package
// explicitly instead of using a blank import.
func RegisterMGODBC() {}

// utility

// success reports whether an ODBC return code indicates success, with or
// without additional diagnostic information.
func success(ret C.SQLRETURN) bool {
	return ret == C.SQL_SUCCESS || ret == C.SQL_SUCCESS_WITH_INFO
}

// executeSucceeded reports whether SQLExecute completed without error. Under
// ODBC 3.x, SQL_NO_DATA means a searched UPDATE or DELETE affected no rows,
// which is a normal outcome rather than a failure.
func executeSucceeded(ret C.SQLRETURN) bool {
	return success(ret) || ret == C.SQL_NO_DATA
}

// NOTE(review): the UTF-16 helpers below assume SQLWCHAR is 2 bytes wide (as
// with unixODBC and Windows). A driver manager built with a 4-byte SQLWCHAR
// would not work with them — confirm for the target platforms.

// utf16ToString decodes UTF-16 code units into a Go string.
func utf16ToString(utf16Data []uint16) string {
	return string(utf16.Decode(utf16Data))
}

// stringToUTF16 encodes a Go string as UTF-16 code units (no terminator).
func stringToUTF16(data string) []uint16 {
	return utf16.Encode([]rune(data))
}

// stringToUTF16Z encodes data as UTF-16 followed by a NUL terminator. The
// result is never empty, so &buf[0] is always a valid pointer to pass to C.
func stringToUTF16Z(data string) []uint16 {
	return append(stringToUTF16(data), 0)
}

// ODBC Error

// statusRecord is one ODBC diagnostic record.
type statusRecord struct {
	state       string // five-character SQLSTATE
	nativeError int    // driver-specific native error code
	message     string
}

// String formats the record as "{SQLSTATE} message".
func (sr *statusRecord) String() string {
	return fmt.Sprintf("{%s} %s", sr.state, sr.message)
}

// odbcError carries every diagnostic record that could be read for a failed
// ODBC call.
type odbcError struct {
	statusRecords []statusRecord
}

// Error joins the diagnostic records, one per line. A call that failed
// without leaving any diagnostic record still yields a non-empty message.
func (e *odbcError) Error() string {
	if len(e.statusRecords) == 0 {
		return "odbc call failed without diagnostic records"
	}
	lines := make([]string, len(e.statusRecords))
	for i := range e.statusRecords {
		lines[i] = e.statusRecords[i].String()
	}
	return strings.Join(lines, "\n")
}

// newError collects the diagnostic records attached to an environment,
// connection or statement handle into an *odbcError. Records that cannot be
// read are skipped.
func newError(sqlHandle interface{}) error {
	var handleType C.SQLSMALLINT
	var handle C.SQLHANDLE
	switch h := sqlHandle.(type) {
	case C.SQLHENV:
		handleType, handle = C.SQL_HANDLE_ENV, C.SQLHANDLE(h)
	case C.SQLHDBC:
		handleType, handle = C.SQL_HANDLE_DBC, C.SQLHANDLE(h)
	case C.SQLHSTMT:
		handleType, handle = C.SQL_HANDLE_STMT, C.SQLHANDLE(h)
	default:
		return errors.New("unknown odbc handle type")
	}

	// SQL_DIAG_NUMBER is an SQLINTEGER; reading it into a wider variable
	// would only work by accident of byte order.
	var recordCount C.SQLINTEGER
	ret := C.SQLGetDiagField(
		handleType,
		handle,
		0,
		C.SQL_DIAG_NUMBER,
		C.SQLPOINTER(&recordCount),
		C.SQLSMALLINT(unsafe.Sizeof(recordCount)),
		nil,
	)
	if !success(ret) {
		return errors.New("failed to retrieve diagnostic header information")
	}

	n := int(recordCount)
	if n < 0 {
		n = 0
	}
	records := make([]statusRecord, 0, n)
	for i := 1; i <= n; i++ {
		recNum := C.SQLSMALLINT(i)
		var state [6]uint16
		var nativeError C.SQLINTEGER
		var msgLen C.SQLSMALLINT

		// Ask the W entry point for the length so it is counted in UTF-16
		// units, matching the buffer allocated below.
		ret = C.SQLGetDiagRecW(
			handleType,
			handle,
			recNum,
			(*C.SQLWCHAR)(&state[0]),
			&nativeError,
			nil,
			0,
			&msgLen,
		)
		if !success(ret) {
			continue
		}
		if msgLen < 0 {
			msgLen = 0
		}

		msg := make([]uint16, int(msgLen)+1)
		ret = C.SQLGetDiagRecW(
			handleType,
			handle,
			recNum,
			(*C.SQLWCHAR)(&state[0]),
			&nativeError,
			(*C.SQLWCHAR)(&msg[0]),
			C.SQLSMALLINT(len(msg)),
			&msgLen,
		)
		if !success(ret) {
			continue
		}
		// Use the length actually reported, clamped to what fits in msg
		// (excluding the NUL terminator).
		end := int(msgLen)
		if end < 0 || end > len(msg)-1 {
			end = len(msg) - 1
		}

		records = append(records, statusRecord{
			state:       utf16ToString(state[:5]),
			nativeError: int(nativeError),
			message:     utf16ToString(msg[:end]),
		})
	}
	return &odbcError{statusRecords: records}
}

// Driver

// Driver is the database/sql driver. It owns one ODBC environment handle from
// which every connection it opens is allocated.
type Driver struct {
	envHandle C.SQLHENV
}

// NewDriver allocates an ODBC environment handle, configures it for ODBC 3.x
// behaviour and wraps it in a Driver.
func NewDriver() (driver.Driver, error) {
	d := &Driver{}

	ret := C.SQLAllocHandle(
		C.SQL_HANDLE_ENV,
		nil,
		(*C.SQLHANDLE)(&d.envHandle),
	)
	if !success(ret) {
		return nil, errors.New("failed to allocate environment handle")
	}

	ret = C.SQLSetEnvAttr(
		C.SQLHENV(d.envHandle),
		C.SQL_ATTR_ODBC_VERSION,
		C.SQLPOINTER(uintptr(C.SQL_OV_ODBC3)),
		0,
	)
	if !success(ret) {
		err := newError(d.envHandle)
		C.SQLFreeHandle(C.SQL_HANDLE_ENV, C.SQLHANDLE(d.envHandle))
		return nil, err
	}

	return d, nil
}

// Open connects using name as an ODBC connection string, passed to
// SQLDriverConnect without prompting the user.
func (d *Driver) Open(name string) (driver.Conn, error) {
	if d.envHandle == nil {
		return nil, errors.New("driver has been closed")
	}

	c := &conn{}
	ret := C.SQLAllocHandle(
		C.SQL_HANDLE_DBC,
		C.SQLHANDLE(d.envHandle),
		(*C.SQLHANDLE)(&c.dbc),
	)
	if !success(ret) {
		return nil, newError(d.envHandle)
	}

	// NUL-terminated so &name16[0] stays valid for an empty string; the
	// terminator is excluded from the length given to the driver.
	name16 := stringToUTF16Z(name)
	var outLen C.SQLSMALLINT
	ret = C.SQLDriverConnectW(
		c.dbc,
		nil,
		(*C.SQLWCHAR)(&name16[0]),
		C.SQLSMALLINT(len(name16)-1),
		nil,
		0,
		&outLen,
		C.SQL_DRIVER_NOPROMPT,
	)
	if !success(ret) {
		err := newError(c.dbc)
		C.SQLFreeHandle(C.SQL_HANDLE_DBC, C.SQLHANDLE(c.dbc))
		return nil, err
	}

	return c, nil
}

// Close frees the environment handle. The Driver cannot open connections
// afterwards.
func (d *Driver) Close() error {
	if d.envHandle == nil {
		return errors.New("environment handle already freed")
	}

	ret := C.SQLFreeHandle(C.SQL_HANDLE_ENV, C.SQLHANDLE(d.envHandle))
	d.envHandle = nil
	if !success(ret) {
		return errors.New("failed to free environment handle")
	}
	return nil
}

// Conn

// conn is one ODBC connection; t is the transaction in progress, if any.
type conn struct {
	dbc C.SQLHDBC
	t   *tx
}

// prepareStmt allocates a statement handle on the connection and prepares
// query on it. The caller owns the returned handle and must free it.
func (c *conn) prepareStmt(query string) (C.SQLHSTMT, error) {
	if c.dbc == nil {
		return nil, errors.New("connection has been closed")
	}

	var h C.SQLHSTMT
	ret := C.SQLAllocHandle(
		C.SQL_HANDLE_STMT,
		C.SQLHANDLE(c.dbc),
		(*C.SQLHANDLE)(&h),
	)
	if !success(ret) {
		return nil, newError(c.dbc)
	}

	// NUL-terminated so &query16[0] stays valid for an empty query; the
	// terminator is excluded from the length given to the driver.
	query16 := stringToUTF16Z(query)
	ret = C.SQLPrepareW(
		h,
		(*C.SQLWCHAR)(&query16[0]),
		C.SQLINTEGER(len(query16)-1),
	)
	if !success(ret) {
		err := newError(h)
		C.SQLFreeHandle(C.SQL_HANDLE_STMT, C.SQLHANDLE(h))
		return nil, err
	}

	return h, nil
}

// Prepare validates query by preparing it once and records how many
// parameters it takes. The handle used for validation is freed; Exec and
// Query prepare the query again on their own handle.
func (c *conn) Prepare(query string) (driver.Stmt, error) {
	if c.dbc == nil {
		return nil, errors.New("connection has been closed")
	}

	h, err := c.prepareStmt(query)
	if err != nil {
		return nil, err
	}

	// -1 tells database/sql the parameter count is unknown, so it does not
	// reject arguments when the driver cannot report the count.
	s := &stmt{c: c, query: query, numInput: -1}
	var paramCount C.SQLSMALLINT
	if success(C.SQLNumParams(h, &paramCount)) {
		s.numInput = int(paramCount)
	}

	if ret := C.SQLFreeHandle(C.SQL_HANDLE_STMT, C.SQLHANDLE(h)); !success(ret) {
		return nil, newError(h)
	}
	return s, nil
}

// Close disconnects and frees the connection handle. The handle is freed and
// the connection marked closed even when SQLDisconnect fails.
func (c *conn) Close() error {
	if c.dbc == nil {
		return errors.New("connection already closed")
	}

	ret := C.SQLDisconnect(c.dbc)
	if !success(ret) {
		err := newError(c.dbc)
		C.SQLFreeHandle(C.SQL_HANDLE_DBC, C.SQLHANDLE(c.dbc))
		c.dbc = nil
		return err
	}

	ret = C.SQLFreeHandle(C.SQL_HANDLE_DBC, C.SQLHANDLE(c.dbc))
	if !success(ret) {
		err := newError(c.dbc)
		c.dbc = nil
		return err
	}

	c.dbc = nil
	return nil
}

// Begin starts a transaction by turning autocommit off. It fails if the
// driver reports no transaction support or a transaction is already active.
func (c *conn) Begin() (driver.Tx, error) {
	if c.dbc == nil {
		return nil, errors.New("connection has been closed")
	}
	if c.t != nil {
		return nil, errors.New("a transaction has already been started")
	}

	var txnCapable C.SQLUSMALLINT
	ret := C.SQLGetInfo(
		c.dbc,
		C.SQLUSMALLINT(C.SQL_TXN_CAPABLE),
		C.SQLPOINTER(&txnCapable),
		2,
		nil,
	)
	if !success(ret) {
		return nil, newError(c.dbc)
	}
	if txnCapable == C.SQL_TC_NONE {
		return nil, errors.New("transactions are not supported by this ODBC driver")
	}

	ret = C.SQLSetConnectAttr(
		c.dbc,
		C.SQL_ATTR_AUTOCOMMIT,
		C.SQLPOINTER(uintptr(C.SQL_AUTOCOMMIT_OFF)),
		C.SQL_IS_UINTEGER,
	)
	if !success(ret) {
		return nil, newError(c.dbc)
	}

	// Record the transaction so a second Begin is rejected until Commit or
	// Rollback clears it.
	t := &tx{c: c}
	c.t = t
	return t, nil
}

// Stmt

// stmt is a prepared statement. The query text is prepared again on a fresh
// statement handle for every Exec or Query.
type stmt struct {
	c        *conn
	query    string
	numInput int // -1 when the driver could not report the count
}

// paramBuffers tracks C heap allocations that back bound statement
// parameters. ODBC reads bound buffers when the statement executes, after
// SQLBindParameter has returned, and cgo forbids C from keeping Go pointers
// past a call — so these buffers must live in C memory until the statement
// handle is freed.
type paramBuffers []unsafe.Pointer

// alloc returns size bytes (at least one) of uninitialised C memory owned by p.
func (p *paramBuffers) alloc(size int) unsafe.Pointer {
	if size < 1 {
		size = 1
	}
	ptr := C.malloc(C.size_t(size))
	*p = append(*p, ptr)
	return ptr
}

// copyUTF16 copies data into C memory owned by p.
func (p *paramBuffers) copyUTF16(data []uint16) unsafe.Pointer {
	ptr := p.alloc(len(data) * 2)
	if len(data) > 0 {
		copy((*[1 << 28]uint16)(ptr)[:len(data):len(data)], data)
	}
	return ptr
}

// copyBytes copies data into C memory owned by p.
func (p *paramBuffers) copyBytes(data []byte) unsafe.Pointer {
	ptr := p.alloc(len(data))
	if len(data) > 0 {
		copy((*[1 << 30]byte)(ptr)[:len(data):len(data)], data)
	}
	return ptr
}

// free releases every buffer owned by p.
func (p *paramBuffers) free() {
	for _, ptr := range *p {
		C.free(ptr)
	}
	*p = nil
}

// releaseStmt frees a statement handle and then the parameter buffers bound
// to it. The buffers are deliberately kept (leaked) if the handle could not
// be freed, so the driver is never left referencing released memory.
func releaseStmt(h C.SQLHSTMT, bufs *paramBuffers) error {
	if ret := C.SQLFreeHandle(C.SQL_HANDLE_STMT, C.SQLHANDLE(h)); !success(ret) {
		return newError(h)
	}
	bufs.free()
	return nil
}

// bindStmt binds args as input parameters 1..len(args) of s. Supported types
// are nil, bool, int64, float64, string, []byte and time.Time. The parameter
// memory is allocated in bufs, which must outlive the execution of s.
func bindStmt(s C.SQLHSTMT, args []interface{}, bufs *paramBuffers) error {
	for i, arg := range args {
		var (
			valueType     C.SQLSMALLINT
			parameterType C.SQLSMALLINT
			columnSize    C.SQLULEN
			decimalDigits C.SQLSMALLINT
			valuePtr      unsafe.Pointer
			bufferLength  C.SQLLEN
		)
		// The length/indicator is read at execute time too, so it also
		// lives in C memory.
		ind := (*C.SQLLEN)(bufs.alloc(int(unsafe.Sizeof(C.SQLLEN(0)))))
		*ind = 0

		switch v := arg.(type) {
		case nil:
			valueType = C.SQL_C_DEFAULT
			parameterType = C.SQL_WCHAR
			columnSize = 1
			*ind = C.SQL_NULL_DATA
		case bool:
			var bit byte
			if v {
				bit = 1
			}
			valuePtr = bufs.alloc(1)
			*(*byte)(valuePtr) = bit
			valueType = C.SQL_C_BIT
			parameterType = C.SQL_BIT
			bufferLength = 1
		case int64:
			valuePtr = bufs.alloc(int(unsafe.Sizeof(C.longlong(0))))
			*(*C.longlong)(valuePtr) = C.longlong(v)
			valueType = C.SQL_C_SBIGINT
			parameterType = C.SQL_BIGINT
			bufferLength = C.SQLLEN(unsafe.Sizeof(C.longlong(0)))
		case float64:
			valuePtr = bufs.alloc(int(unsafe.Sizeof(C.double(0))))
			*(*C.double)(valuePtr) = C.double(v)
			valueType = C.SQL_C_DOUBLE
			parameterType = C.SQL_DOUBLE
			bufferLength = C.SQLLEN(unsafe.Sizeof(C.double(0)))
		case string:
			units := stringToUTF16(v)
			valuePtr = bufs.copyUTF16(units)
			valueType = C.SQL_C_WCHAR
			parameterType = C.SQL_WVARCHAR
			// ColumnSize is measured in characters (UTF-16 units here); the
			// buffer length and indicator are in bytes. Some drivers reject
			// a column size of 0.
			columnSize = C.SQLULEN(len(units))
			if columnSize == 0 {
				columnSize = 1
			}
			bufferLength = C.SQLLEN(len(units) * 2)
			*ind = C.SQLLEN(len(units) * 2)
		case []byte:
			valuePtr = bufs.copyBytes(v)
			valueType = C.SQL_C_BINARY
			parameterType = C.SQL_BINARY
			// Binary sizes are plain byte counts.
			columnSize = C.SQLULEN(len(v))
			if columnSize == 0 {
				columnSize = 1
			}
			bufferLength = C.SQLLEN(len(v))
			*ind = C.SQLLEN(len(v))
		case time.Time:
			var zero C.SQL_TIMESTAMP_STRUCT
			valuePtr = bufs.alloc(int(unsafe.Sizeof(zero)))
			ts := (*C.SQL_TIMESTAMP_STRUCT)(valuePtr)
			ts.year = C.SQLSMALLINT(v.Year())
			ts.month = C.SQLUSMALLINT(v.Month())
			ts.day = C.SQLUSMALLINT(v.Day())
			ts.hour = C.SQLUSMALLINT(v.Hour())
			ts.minute = C.SQLUSMALLINT(v.Minute())
			ts.second = C.SQLUSMALLINT(v.Second())
			ts.fraction = C.SQLUINTEGER(v.Nanosecond())
			valueType = C.SQL_C_TYPE_TIMESTAMP
			parameterType = C.SQL_TYPE_TIMESTAMP
			bufferLength = C.SQLLEN(unsafe.Sizeof(zero))
		default:
			return errors.New("invalid parameter type to bind to")
		}

		ret := C.SQLBindParameter(
			s,
			C.SQLUSMALLINT(i+1),
			C.SQL_PARAM_INPUT,
			valueType,
			parameterType,
			columnSize,
			decimalDigits,
			C.SQLPOINTER(valuePtr),
			bufferLength,
			ind,
		)
		if !success(ret) {
			return newError(s)
		}
	}
	return nil
}

// Close marks the statement closed; it holds no ODBC handle of its own.
func (s *stmt) Close() error {
	if s.c == nil {
		return errors.New("statement has already been closed")
	}
	s.c = nil
	return nil
}

// NumInput returns the parameter count reported at Prepare time, or -1.
func (s *stmt) NumInput() int {
	return s.numInput
}

// Exec prepares, binds and executes the statement, returning the number of
// affected rows. The statement handle and parameter memory are released
// before returning.
func (s *stmt) Exec(args []interface{}) (driver.Result, error) {
	if s.c == nil {
		return nil, errors.New("statement has been closed")
	}

	h, err := s.c.prepareStmt(s.query)
	if err != nil {
		return nil, err
	}

	var bufs paramBuffers
	fail := func(cause error) (driver.Result, error) {
		// Best effort: cause is more useful than a failure to free.
		_ = releaseStmt(h, &bufs)
		return nil, cause
	}

	if err := bindStmt(h, args, &bufs); err != nil {
		return fail(err)
	}

	ret := C.SQLExecute(h)
	if !executeSucceeded(ret) {
		return fail(newError(h))
	}

	// SQL_NO_DATA from SQLExecute means no rows were affected.
	var rowCount C.SQLLEN
	if ret != C.SQL_NO_DATA {
		if countRet := C.SQLRowCount(h, &rowCount); !success(countRet) {
			return fail(newError(h))
		}
	}

	if err := releaseStmt(h, &bufs); err != nil {
		return nil, err
	}
	return &result{rowsAffected: int64(rowCount)}, nil
}

// Query prepares, binds and executes the statement and returns its result
// set. The returned rows own the statement handle and its parameter memory.
func (s *stmt) Query(args []interface{}) (driver.Rows, error) {
	if s.c == nil {
		return nil, errors.New("statement has been closed")
	}

	h, err := s.c.prepareStmt(s.query)
	if err != nil {
		return nil, err
	}

	var bufs paramBuffers
	fail := func(cause error) (driver.Rows, error) {
		// Best effort: cause is more useful than a failure to free.
		_ = releaseStmt(h, &bufs)
		return nil, cause
	}

	if err := bindStmt(h, args, &bufs); err != nil {
		return fail(err)
	}

	if ret := C.SQLExecute(h); !executeSucceeded(ret) {
		return fail(newError(h))
	}

	columns, err := columnNames(h)
	if err != nil {
		return fail(err)
	}

	return &rows{s: h, columns: columns, params: bufs}, nil
}

// columnNames returns the result-set column names of an executed statement.
// A column whose name has zero length maps to "".
func columnNames(h C.SQLHSTMT) ([]string, error) {
	var numCols C.SQLSMALLINT
	if ret := C.SQLNumResultCols(h, &numCols); !success(ret) {
		return nil, newError(h)
	}

	names := make([]string, int(numCols))
	for i := range names {
		col := C.SQLUSMALLINT(i + 1)

		// First ask for the name's length, which is reported in bytes.
		var nameLen C.SQLSMALLINT
		ret := C.SQLColAttributeW(h, col, C.SQL_DESC_NAME, nil, 0, &nameLen, nil)
		if !success(ret) {
			return nil, newError(h)
		}
		chars := int(nameLen) / 2
		if chars <= 0 {
			continue
		}

		buf := make([]uint16, chars+1)
		ret = C.SQLColAttributeW(
			h,
			col,
			C.SQL_DESC_NAME,
			C.SQLPOINTER(&buf[0]),
			C.SQLSMALLINT(len(buf)*2),
			nil,
			nil,
		)
		if !success(ret) {
			return nil, newError(h)
		}
		names[i] = utf16ToString(buf[:chars])
	}
	return names, nil
}

// Result

// result is the outcome of Exec.
type result struct {
	rowsAffected int64
}

// LastInsertId is not supported by this driver.
func (*result) LastInsertId() (int64, error) {
	return 0, errors.New("not supported")
}

// RowsAffected returns the count reported by SQLRowCount.
func (r *result) RowsAffected() (int64, error) {
	return r.rowsAffected, nil
}

// Rows

// rows is an open result set.
type rows struct {
	s       C.SQLHSTMT
	columns []string
	params  paramBuffers // parameter memory bound to s; freed together with s
}

// Columns returns the result-set column names.
func (r *rows) Columns() []string {
	return r.columns
}

// Close frees the statement handle and the parameter memory bound to it.
func (r *rows) Close() error {
	if r.s == nil {
		return errors.New("statement has already been closed")
	}
	if err := releaseStmt(r.s, &r.params); err != nil {
		return err
	}
	r.s = nil
	return nil
}

// Next fetches the next row into dest. It returns io.EOF when the rows are
// exhausted or the statement produced no result set.
func (r *rows) Next(dest []interface{}) error {
	if r.s == nil {
		return errors.New("statement has been closed")
	}

	// Special check in case there was no result set generated.
	var columnCount C.SQLSMALLINT
	if ret := C.SQLNumResultCols(r.s, &columnCount); !success(ret) {
		return newError(r.s)
	}
	if columnCount == 0 {
		return io.EOF
	}

	ret := C.SQLFetch(r.s)
	if ret == C.SQL_NO_DATA {
		return io.EOF
	}
	if !success(ret) {
		return newError(r.s)
	}

	for i := range dest {
		v, err := r.columnValue(i + 1)
		if err != nil {
			return err
		}
		dest[i] = v
	}
	return nil
}

// columnValue reads column col (1-based) of the current row. SQL NULL, and
// column types with no mapping below, yield nil.
func (r *rows) columnValue(col int) (interface{}, error) {
	var colType C.SQLLEN
	ret := C.SQLColAttribute(
		r.s,
		C.SQLUSMALLINT(col),
		C.SQL_DESC_CONCISE_TYPE,
		nil,
		0,
		nil,
		&colType,
	)
	if !success(ret) {
		return nil, newError(r.s)
	}

	switch int(colType) {
	case C.SQL_BIT:
		var bit byte
		isNull, err := r.getFixed(col, C.SQL_C_BIT, unsafe.Pointer(&bit), int(unsafe.Sizeof(bit)))
		if err != nil || isNull {
			return nil, err
		}
		return bit != 0, nil

	case C.SQL_TINYINT, C.SQL_SMALLINT, C.SQL_INTEGER, C.SQL_BIGINT:
		var n C.longlong
		isNull, err := r.getFixed(col, C.SQL_C_SBIGINT, unsafe.Pointer(&n), int(unsafe.Sizeof(n)))
		if err != nil || isNull {
			return nil, err
		}
		return int64(n), nil

	case C.SQL_REAL, C.SQL_FLOAT, C.SQL_DOUBLE:
		var f C.double
		isNull, err := r.getFixed(col, C.SQL_C_DOUBLE, unsafe.Pointer(&f), int(unsafe.Sizeof(f)))
		if err != nil || isNull {
			return nil, err
		}
		return float64(f), nil

	case C.SQL_CHAR, C.SQL_VARCHAR, C.SQL_LONGVARCHAR,
		C.SQL_WCHAR, C.SQL_WVARCHAR, C.SQL_WLONGVARCHAR,
		C.SQL_DECIMAL, C.SQL_NUMERIC:
		// Exact numerics are returned as text so no precision is lost;
		// database/sql converts []byte when scanning.
		text, isNull, err := r.getWideString(col)
		if err != nil || isNull {
			return nil, err
		}
		return text, nil

	case C.SQL_BINARY, C.SQL_VARBINARY, C.SQL_LONGVARBINARY:
		data, isNull, err := r.getBinary(col)
		if err != nil || isNull {
			return nil, err
		}
		return data, nil

	case C.SQL_TYPE_TIMESTAMP, C.SQL_TYPE_DATE:
		// ODBC converts DATE to a timestamp structure with zeroed time fields.
		var ts C.SQL_TIMESTAMP_STRUCT
		isNull, err := r.getFixed(col, C.SQL_C_TYPE_TIMESTAMP, unsafe.Pointer(&ts), int(unsafe.Sizeof(ts)))
		if err != nil || isNull {
			return nil, err
		}
		return time.Date(
			int(ts.year),
			time.Month(ts.month),
			int(ts.day),
			int(ts.hour),
			int(ts.minute),
			int(ts.second),
			int(ts.fraction),
			time.UTC,
		), nil

	default:
		return nil, nil
	}
}

// getFixed reads a fixed-size value of C type cType from column col into the
// size bytes at target, reporting whether the value was SQL NULL.
func (r *rows) getFixed(col int, cType C.SQLSMALLINT, target unsafe.Pointer, size int) (bool, error) {
	var ind C.SQLLEN
	ret := C.SQLGetData(
		r.s,
		C.SQLUSMALLINT(col),
		cType,
		C.SQLPOINTER(target),
		C.SQLLEN(size),
		&ind,
	)
	if !success(ret) {
		return false, newError(r.s)
	}
	return ind == C.SQL_NULL_DATA, nil
}

// readChunks retrieves a variable-length column in pieces. Each SQLGetData
// call writes into the bufBytes bytes at buf, and emit receives the number of
// valid bytes in that piece. termBytes is the space the driver reserves for a
// NUL terminator (2 for SQL_C_WCHAR, 0 for SQL_C_BINARY). It reports whether
// the value was SQL NULL. Reading in pieces copes with SQL_NO_TOTAL and with
// empty values, where a single probe call may already consume all the data.
func (r *rows) readChunks(col int, cType C.SQLSMALLINT, buf unsafe.Pointer, bufBytes, termBytes int, emit func(n int)) (bool, error) {
	avail := bufBytes - termBytes
	for {
		var ind C.SQLLEN
		ret := C.SQLGetData(
			r.s,
			C.SQLUSMALLINT(col),
			cType,
			C.SQLPOINTER(buf),
			C.SQLLEN(bufBytes),
			&ind,
		)
		switch {
		case ret == C.SQL_NO_DATA:
			// Everything was returned by earlier calls.
			return false, nil
		case !success(ret):
			return false, newError(r.s)
		case ind == C.SQL_NULL_DATA:
			return true, nil
		case ret == C.SQL_SUCCESS_WITH_INFO && (ind == C.SQL_NO_TOTAL || int(ind) > avail):
			// Truncated piece: the buffer is full and more data remains.
			emit(avail)
		default:
			// Final piece: ind is the number of bytes written.
			n := int(ind)
			if n < 0 || n > avail {
				n = avail
			}
			emit(n)
			return false, nil
		}
	}
}

// getWideString reads a column as SQL_C_WCHAR and returns it as UTF-8 bytes,
// reporting whether the value was SQL NULL.
func (r *rows) getWideString(col int) ([]byte, bool, error) {
	chunk := make([]uint16, 2048)
	var text []uint16
	isNull, err := r.readChunks(col, C.SQL_C_WCHAR, unsafe.Pointer(&chunk[0]), len(chunk)*2, 2, func(n int) {
		text = append(text, chunk[:n/2]...)
	})
	if err != nil || isNull {
		return nil, isNull, err
	}
	// Decode only after all pieces are joined, so a surrogate pair split
	// across two pieces is decoded correctly.
	return []byte(utf16ToString(text)), false, nil
}

// getBinary reads a column as SQL_C_BINARY, reporting whether the value was
// SQL NULL. An empty value yields a non-nil empty slice.
func (r *rows) getBinary(col int) ([]byte, bool, error) {
	chunk := make([]byte, 4096)
	data := make([]byte, 0)
	isNull, err := r.readChunks(col, C.SQL_C_BINARY, unsafe.Pointer(&chunk[0]), len(chunk), 0, func(n int) {
		data = append(data, chunk[:n]...)
	})
	if err != nil || isNull {
		return nil, isNull, err
	}
	return data, false, nil
}

// LastInsertId is not supported by this driver.
func (*rows) LastInsertId() (int64, error) {
	return 0, errors.New("not supported")
}

// RowsAffected returns SQLRowCount for the statement behind these rows.
func (r *rows) RowsAffected() (int64, error) {
	if r.s == nil {
		return 0, errors.New("statement has been closed")
	}

	var rowCount C.SQLLEN
	if ret := C.SQLRowCount(r.s, &rowCount); !success(ret) {
		return 0, newError(r.s)
	}
	return int64(rowCount), nil
}

// Tx

// tx is a manual-commit transaction on c.
type tx struct {
	c *conn
}

// Commit commits the transaction and restores autocommit.
func (t *tx) Commit() error {
	return t.end(C.SQL_COMMIT)
}

// Rollback rolls the transaction back and restores autocommit.
func (t *tx) Rollback() error {
	return t.end(C.SQL_ROLLBACK)
}

// end completes the transaction with completion (SQL_COMMIT or SQL_ROLLBACK)
// and turns autocommit back on. If SQLEndTran fails, the transaction stays
// open so the caller can retry or roll back. Once SQLEndTran succeeds, the
// transaction is detached from the connection even if restoring autocommit
// fails.
func (t *tx) end(completion C.SQLSMALLINT) error {
	if t.c == nil {
		return errors.New("transaction has already ended")
	}

	ret := C.SQLEndTran(C.SQL_HANDLE_DBC, C.SQLHANDLE(t.c.dbc), completion)
	if !success(ret) {
		return newError(t.c.dbc)
	}

	var err error
	ret = C.SQLSetConnectAttr(
		t.c.dbc,
		C.SQL_ATTR_AUTOCOMMIT,
		C.SQLPOINTER(uintptr(C.SQL_AUTOCOMMIT_ON)),
		0,
	)
	if !success(ret) {
		err = newError(t.c.dbc)
	}

	t.c.t = nil
	t.c = nil
	return err
}