Created
April 5, 2025 07:46
-
-
Save cjbrigato/da1f928677c69a5f1b31f03b13556d94 to your computer and use it in GitHub Desktop.
Golang nbd-server in < 275 LOC (default export, byte slice "backend")
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"encoding/binary" | |
"errors" | |
"flag" | |
"fmt" | |
"io" | |
"log" | |
"net" | |
"sync" | |
"syscall" | |
) | |
// NBD protocol constants. The ALL_CAPS names deliberately mirror the
// identifiers used by the NBD protocol specification.
const (
	// Magic numbers.
	NBD_MAGIC                   = 0x4e42444d41474943 // ASCII "NBDMAGIC", first word of the server greeting
	NBD_OPTS_MAGIC              = 0x49484156454F5054 // ASCII "IHAVEOPT", sent in the greeting and before every client option
	NBD_REQUEST_MAGIC           = 0x25609513         // starts every transmission request header
	NBD_SIMPLE_REPLY_MAGIC      = 0x67446698         // starts every transmission simple reply
	NBD_NEGOTIATION_REPLY_MAGIC = 0x0003e889045565a9 // starts every handshake option reply

	// Transmission flags, sent to the client together with the export size.
	NBD_FLAG_HAS_FLAGS         = 1 << 0
	NBD_FLAG_READ_ONLY         = 1 << 1 // not advertised: the export is writable
	NBD_FLAG_SEND_FLUSH        = 1 << 2
	NBD_FLAG_SEND_FUA          = 1 << 3
	NBD_FLAG_ROTATIONAL        = 1 << 4 // not advertised
	NBD_FLAG_SEND_TRIM         = 1 << 5
	NBD_FLAG_SEND_WRITE_ZEROES = 1 << 6
	NBD_FLAG_CAN_MULTI_CONN    = 1 << 8 // not advertised

	// Handshake flags (16 bits, sent in the server greeting). This is a
	// separate namespace from the transmission flags above, so the bit
	// values overlap on purpose.
	NBD_FLAG_FIXED_NEWSTYLE = 1 << 0
	NBD_FLAG_NO_ZEROES      = 1 << 1

	// Client flags (32 bits, sent by the client in reply to the greeting).
	NBD_FLAG_C_FIXED_NEWSTYLE = 1 << 0
	NBD_FLAG_C_NO_ZEROES      = 1 << 1

	// Transmission commands (nbdRequest.Type).
	NBD_CMD_READ         = 0
	NBD_CMD_WRITE        = 1
	NBD_CMD_DISC         = 2
	NBD_CMD_FLUSH        = 3
	NBD_CMD_TRIM         = 4
	NBD_CMD_WRITE_ZEROES = 6

	// Handshake options. Options without a dedicated handler in the
	// negotiation loop (STARTTLS, STRUCTURED_REPLY, *_META_CONTEXT) are
	// answered as unsupported.
	NBD_OPT_EXPORT_NAME       = 1
	NBD_OPT_ABORT             = 2
	NBD_OPT_LIST              = 3
	NBD_OPT_STARTTLS          = 5
	NBD_OPT_INFO              = 6
	NBD_OPT_GO                = 7
	NBD_OPT_STRUCTURED_REPLY  = 8
	NBD_OPT_LIST_META_CONTEXT = 9
	NBD_OPT_SET_META_CONTEXT  = 10

	// Option reply types.
	NBD_REP_ACK          = 1
	NBD_REP_SERVER       = 2
	NBD_REP_INFO         = 3
	NBD_REP_META_CONTEXT = 4

	// Option error reply types: bit 31 set marks an error.
	NBD_REP_ERR_UNSUP           = (1 << 31) + 1
	NBD_REP_ERR_POLICY          = (1 << 31) + 2
	NBD_REP_ERR_INVALID         = (1 << 31) + 3
	NBD_REP_ERR_PLATFORM        = (1 << 31) + 4
	NBD_REP_ERR_TLS_REQD        = (1 << 31) + 5
	NBD_REP_ERR_UNKNOWN         = (1 << 31) + 6
	NBD_REP_ERR_SHUTDOWN        = (1 << 31) + 7
	NBD_REP_ERR_BLOCK_SIZE_REQD = (1 << 31) + 8
	NBD_REP_ERR_TOO_BIG         = (1 << 31) + 9

	// Information types carried in NBD_REP_INFO replies.
	NBD_INFO_EXPORT      = 0
	NBD_INFO_NAME        = 1
	NBD_INFO_DESCRIPTION = 2
	NBD_INFO_BLOCK_SIZE  = 3
)
// nbdRequest mirrors the fixed-size header of a simple NBD transmission
// request. Field order and widths are the wire layout: the struct is decoded
// directly with encoding/binary in big-endian order, so do not reorder it.
type nbdRequest struct {
	Magic  uint32 // expected to equal NBD_REQUEST_MAGIC
	Flags  uint16 // command flags; not inspected by this server
	Type   uint16 // one of the NBD_CMD_* values
	Handle uint64 // opaque cookie, echoed back in the matching reply
	Offset uint64 // byte offset into the export
	Length uint32 // number of bytes the command covers
}
// nbdReply mirrors the header of a simple NBD transmission reply. It is
// encoded directly with encoding/binary in big-endian order, so the field
// order is the wire layout.
type nbdReply struct {
	Magic  uint32 // set to NBD_SIMPLE_REPLY_MAGIC by the senders
	Error  uint32 // 0 on success, otherwise an errno value
	Handle uint64 // copied from the request being answered
}
// Process-wide state shared by every client connection.
var (
	// storage is the in-memory contents of the export; main allocates it
	// with exactly size bytes.
	storage []byte
	// size is the export size in bytes.
	size uint64
	// mutex guards storage: the read path takes RLock, the write, trim and
	// write-zeroes paths take Lock.
	mutex sync.RWMutex
	// defaultExportName is the export name clients must request; its zero
	// value (the empty string) selects the default export.
	defaultExportName string
)
func main() { | |
sizeMB := flag.Uint64("size", 128, "Size of the block device in MiB") | |
listenAddr := flag.String("addr", ":10809", "Address and port to listen on") | |
flag.Parse() | |
size = *sizeMB * 1024 * 1024 | |
if size == 0 { log.Fatal("Size cannot be zero") } | |
storage = make([]byte, size) | |
log.Printf("NBD server on %s, size %d MiB, export '%s'", *listenAddr, *sizeMB, defaultExportName) | |
listener, err := net.Listen("tcp", *listenAddr) | |
if err != nil { log.Fatalf("Failed to listen on %s: %v", *listenAddr, err) } | |
defer listener.Close() | |
for { | |
conn, err := listener.Accept() | |
if err != nil { log.Printf("Accept failed: %v", err); continue } | |
go handleConnection(conn) | |
} | |
} | |
func handleConnection(conn net.Conn) { | |
remoteAddr := conn.RemoteAddr().String() | |
log.Printf("[%s] Connected", remoteAddr) | |
defer conn.Close() | |
defer log.Printf("[%s] Disconnected", remoteAddr) | |
if err := handleNegotiation(conn, size, defaultExportName, remoteAddr); err != nil { | |
log.Printf("[%s] Handshake failed: %v", remoteAddr, err) | |
return | |
} | |
log.Printf("[%s] Handshake successful", remoteAddr) | |
if err := handleTransmission(conn, remoteAddr); err != nil { | |
if !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) && !errors.Is(err, syscall.EPIPE) { | |
log.Printf("[%s] Transmission error: %v", remoteAddr, err) | |
} | |
} | |
} | |
// readBE fills data from r, decoding it in big-endian (network) byte order
// with encoding/binary. As encoding/binary requires, data must point to a
// fixed-size value or a struct made only of fixed-size fields.
func readBE(r io.Reader, data interface{}) error {
	var order binary.ByteOrder = binary.BigEndian
	return binary.Read(r, order, data)
}
// writeBE encodes data in big-endian (network) byte order with
// encoding/binary and writes the result to w.
func writeBE(w io.Writer, data interface{}) error {
	order := binary.BigEndian
	if err := binary.Write(w, order, data); err != nil {
		return err
	}
	return nil
}
func sendOptReply(c net.Conn, opt, repType uint32, data []byte) error { | |
dataLen := uint32(0); if data != nil { dataLen = uint32(len(data)) } | |
if writeBE(c, uint64(NBD_NEGOTIATION_REPLY_MAGIC)) != nil { return fmt.Errorf("write rep magic") } | |
if writeBE(c, opt) != nil { return fmt.Errorf("write rep opt") } | |
if writeBE(c, repType) != nil { return fmt.Errorf("write rep type") } | |
if writeBE(c, dataLen) != nil { return fmt.Errorf("write rep len") } | |
if dataLen > 0 { if _, err := c.Write(data); err != nil { return fmt.Errorf("write rep data: %w", err) } } | |
return nil | |
} | |
func sendSimpleOptReply(c net.Conn, opt, repType uint32) error { return sendOptReply(c, opt, repType, nil) } | |
func handleNegotiation(conn net.Conn, expSize uint64, expName, remoteAddr string) error { | |
serverFlags := uint16(NBD_FLAG_HAS_FLAGS | NBD_FLAG_FIXED_NEWSTYLE | NBD_FLAG_SEND_FLUSH | NBD_FLAG_SEND_FUA | NBD_FLAG_SEND_TRIM | NBD_FLAG_SEND_WRITE_ZEROES) | |
if writeBE(conn, uint64(NBD_MAGIC)) != nil { return fmt.Errorf("write NBD_MAGIC failed") } | |
if writeBE(conn, uint64(NBD_OPTS_MAGIC)) != nil { return fmt.Errorf("write NBD_OPTS_MAGIC failed") } | |
if writeBE(conn, serverFlags) != nil { return fmt.Errorf("write server flags failed") } | |
var clientFlags uint32 | |
if readBE(conn, &clientFlags) != nil { return fmt.Errorf("read client flags failed") } | |
if (serverFlags&NBD_FLAG_FIXED_NEWSTYLE) != 0 && (clientFlags&NBD_FLAG_C_FIXED_NEWSTYLE) == 0 { | |
return fmt.Errorf("client did not agree to fixed newstyle") | |
} | |
transmissionFlags := uint16(NBD_FLAG_SEND_FLUSH | NBD_FLAG_SEND_FUA | NBD_FLAG_SEND_TRIM | NBD_FLAG_SEND_WRITE_ZEROES) | |
for { | |
var clientOptMagic uint64; var optType, optLen uint32 | |
if readBE(conn, &clientOptMagic) != nil { return fmt.Errorf("read option magic failed") } | |
if clientOptMagic != NBD_OPTS_MAGIC { _ = sendSimpleOptReply(conn, 0, NBD_REP_ERR_INVALID); return fmt.Errorf("invalid option magic 0x%x", clientOptMagic) } | |
if readBE(conn, &optType) != nil { return fmt.Errorf("read option type failed") } | |
if readBE(conn, &optLen) != nil { return fmt.Errorf("read option length failed") } | |
optData := make([]byte, optLen) // TODO: Protect against large optLen | |
if optLen > 0 { if _, err := io.ReadFull(conn, optData); err != nil { return fmt.Errorf("read option data (type %d, len %d) failed: %w", optType, optLen, err) } } | |
switch optType { | |
case NBD_OPT_EXPORT_NAME: | |
reqName := string(optData) | |
if reqName != expName { _ = sendSimpleOptReply(conn, optType, NBD_REP_ERR_UNKNOWN); continue } | |
if sendSimpleOptReply(conn, optType, NBD_REP_ACK) != nil { return fmt.Errorf("send export ack failed") } | |
if writeBE(conn, expSize) != nil { return fmt.Errorf("write export size failed") } | |
if writeBE(conn, transmissionFlags) != nil { return fmt.Errorf("write transmission flags failed") } | |
if _, err := conn.Write(make([]byte, 124)); err != nil { return fmt.Errorf("write reserved zeros failed: %w", err) } | |
return nil // Handshake complete | |
case NBD_OPT_LIST: | |
nameBytes, descBytes := []byte(expName), []byte("In-memory NBD") | |
replyData := make([]byte, 4+len(nameBytes)+4+len(descBytes)) | |
binary.BigEndian.PutUint32(replyData[0:], uint32(len(nameBytes))); copy(replyData[4:], nameBytes) | |
offset := 4 + len(nameBytes) | |
binary.BigEndian.PutUint32(replyData[offset:], uint32(len(descBytes))); copy(replyData[offset+4:], descBytes) | |
if sendOptReply(conn, optType, NBD_REP_SERVER, replyData) != nil { return fmt.Errorf("send NBD_REP_SERVER failed") } | |
if sendSimpleOptReply(conn, optType, NBD_REP_ACK) != nil { return fmt.Errorf("send list ack failed") } | |
continue | |
case NBD_OPT_ABORT: | |
_ = sendSimpleOptReply(conn, optType, NBD_REP_ACK) | |
return fmt.Errorf("client aborted handshake") | |
case NBD_OPT_INFO, NBD_OPT_GO: | |
if optLen < 6 { _ = sendSimpleOptReply(conn, optType, NBD_REP_ERR_INVALID); continue } // 4(name_len) + 2(num_info_reqs) | |
nameLen := binary.BigEndian.Uint32(optData[0:4]) | |
if uint32(len(optData)) < 4+nameLen+2 { _ = sendSimpleOptReply(conn, optType, NBD_REP_ERR_INVALID); continue } | |
reqName := string(optData[4 : 4+nameLen]) | |
// numInfoReqs := binary.BigEndian.Uint16(optData[4+nameLen:]) // Ignored for minimal server | |
if reqName != expName { _ = sendSimpleOptReply(conn, optType, NBD_REP_ERR_UNKNOWN); continue } | |
replyData := make([]byte, 2+8+2) // type(u16)+size(u64)+flags(u16) | |
binary.BigEndian.PutUint16(replyData[0:], uint16(NBD_INFO_EXPORT)) | |
binary.BigEndian.PutUint64(replyData[2:], expSize) | |
binary.BigEndian.PutUint16(replyData[10:], transmissionFlags) | |
if sendOptReply(conn, optType, NBD_REP_INFO, replyData) != nil { return fmt.Errorf("send NBD_REP_INFO failed") } | |
if sendSimpleOptReply(conn, optType, NBD_REP_ACK) != nil { return fmt.Errorf("send info/go ack failed") } | |
if optType == NBD_OPT_GO { return nil } // Handshake complete | |
continue | |
default: | |
log.Printf("[%s] Unsupported option: Type=%d", remoteAddr, optType) | |
if sendSimpleOptReply(conn, optType, NBD_REP_ERR_UNSUP) != nil { return fmt.Errorf("send NBD_REP_ERR_UNSUP failed") } | |
continue | |
} | |
} | |
} | |
func sendErrorReply(conn net.Conn, handle uint64, ecode syscall.Errno) error { | |
reply := nbdReply{Magic: NBD_SIMPLE_REPLY_MAGIC, Error: uint32(ecode), Handle: handle} | |
return writeBE(conn, reply) | |
} | |
func handleTransmission(conn net.Conn, remoteAddr string) error { | |
req := nbdRequest{} | |
reply := nbdReply{Magic: NBD_SIMPLE_REPLY_MAGIC} | |
for { | |
if err := readBE(conn, &req); err != nil { | |
if errors.Is(err, io.EOF) { return io.EOF } | |
return fmt.Errorf("read request header failed: %w", err) | |
} | |
if req.Magic != NBD_REQUEST_MAGIC { return fmt.Errorf("invalid request magic 0x%x", req.Magic) } // Cannot send reply reliably | |
reply.Handle = req.Handle; reply.Error = 0 // Prepare reply for success case | |
outOfBounds := req.Offset+uint64(req.Length) > size || req.Length == 0 | |
switch req.Type { | |
case NBD_CMD_READ: | |
if outOfBounds { _ = sendErrorReply(conn, req.Handle, syscall.EINVAL); return fmt.Errorf("read out of bounds") } | |
if err := writeBE(conn, reply); err != nil { return fmt.Errorf("send read reply header: %w", err) } | |
mutex.RLock() | |
_, err := conn.Write(storage[req.Offset : req.Offset+uint64(req.Length)]) | |
mutex.RUnlock() | |
if err != nil { return fmt.Errorf("send read data: %w", err) } | |
case NBD_CMD_WRITE: | |
if outOfBounds { | |
_, dErr := io.CopyN(io.Discard, conn, int64(req.Length)) // Discard data | |
sErr := sendErrorReply(conn, req.Handle, syscall.EINVAL) | |
if sErr != nil { return fmt.Errorf("send write bounds error reply: %w", sErr) } | |
if dErr != nil && !errors.Is(dErr, io.EOF) { return fmt.Errorf("discard data failed: %w", dErr) } | |
return fmt.Errorf("write out of bounds") | |
} | |
mutex.Lock() | |
_, err := io.ReadFull(conn, storage[req.Offset:req.Offset+uint64(req.Length)]) | |
mutex.Unlock() | |
if err != nil { _ = sendErrorReply(conn, req.Handle, syscall.EIO); return fmt.Errorf("read write data: %w", err) } | |
if err := writeBE(conn, reply); err != nil { return fmt.Errorf("send write reply header: %w", err) } | |
case NBD_CMD_DISC: | |
return nil // No reply needed, close gracefully | |
case NBD_CMD_FLUSH: | |
if err := writeBE(conn, reply); err != nil { return fmt.Errorf("send flush reply: %w", err) } // In-memory is no-op | |
case NBD_CMD_TRIM, NBD_CMD_WRITE_ZEROES: | |
if outOfBounds { _ = sendErrorReply(conn, req.Handle, syscall.EINVAL); return fmt.Errorf("trim/zero out of bounds") } | |
mutex.Lock() | |
clear(storage[req.Offset : req.Offset+uint64(req.Length)]) // Inlined zeroRange | |
mutex.Unlock() | |
if err := writeBE(conn, reply); err != nil { return fmt.Errorf("send trim/zero reply: %w", err) } | |
default: | |
_ = sendErrorReply(conn, req.Handle, syscall.EINVAL) | |
return fmt.Errorf("unknown command type %d", req.Type) | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
An extremely quick NBD server implementation I use for PoCs.
It exports only the default export, backed by an in-memory byte slice. The size and listening address are configurable. That's it, nothing more: just enough for the (fixed) newstyle protocol.