Skip to content

Instantly share code, notes, and snippets.

@cjbrigato
Created April 5, 2025 07:46
Show Gist options
  • Save cjbrigato/da1f928677c69a5f1b31f03b13556d94 to your computer and use it in GitHub Desktop.
Save cjbrigato/da1f928677c69a5f1b31f03b13556d94 to your computer and use it in GitHub Desktop.
Golang nbd-server in < 275 LOC (default export, byte slice "backend")
package main
import (
"encoding/binary"
"errors"
"flag"
"fmt"
"io"
"log"
"net"
"sync"
"syscall"
)
// NBD protocol constants (newstyle negotiation and simple replies).
//
// The names follow the NBD protocol specification rather than Go naming
// conventions so they can be matched against the spec directly.
const (
	// Magic numbers.
	NBD_MAGIC                   = 0x4e42444d41474943 // ASCII "NBDMAGIC"
	NBD_OPTS_MAGIC              = 0x49484156454F5054 // ASCII "IHAVEOPT"
	NBD_REQUEST_MAGIC           = 0x25609513
	NBD_SIMPLE_REPLY_MAGIC      = 0x67446698
	NBD_NEGOTIATION_REPLY_MAGIC = 0x0003e889045565a9

	// Transmission flags (16 bits, sent together with the export size).
	NBD_FLAG_HAS_FLAGS         = (1 << 0)
	NBD_FLAG_READ_ONLY         = (1 << 1) // Unused but kept for context if needed later
	NBD_FLAG_SEND_FLUSH        = (1 << 2)
	NBD_FLAG_SEND_FUA          = (1 << 3)
	NBD_FLAG_ROTATIONAL        = (1 << 4) // Unused
	NBD_FLAG_SEND_TRIM         = (1 << 5)
	NBD_FLAG_SEND_WRITE_ZEROES = (1 << 6)
	NBD_FLAG_CAN_MULTI_CONN    = (1 << 8) // Unused

	// Handshake flags (16 bits, sent by the server in its greeting). These
	// live in their own bit-space, separate from the transmission flags:
	// bit 0 is FIXED_NEWSTYLE and bit 1 is NO_ZEROES.
	NBD_FLAG_FIXED_NEWSTYLE = (1 << 0)
	NBD_FLAG_NO_ZEROES      = (1 << 1) // Unused

	// Client flags (32 bits, sent by the client in answer to the greeting).
	// They mirror the handshake flags bit for bit.
	NBD_FLAG_C_FIXED_NEWSTYLE = (1 << 0)
	NBD_FLAG_C_NO_ZEROES      = (1 << 1) // Unused

	// Transmission-phase command types.
	NBD_CMD_READ         = 0
	NBD_CMD_WRITE        = 1
	NBD_CMD_DISC         = 2
	NBD_CMD_FLUSH        = 3
	NBD_CMD_TRIM         = 4
	NBD_CMD_WRITE_ZEROES = 6

	// Negotiation-phase option types.
	NBD_OPT_EXPORT_NAME       = 1
	NBD_OPT_ABORT             = 2
	NBD_OPT_LIST              = 3
	NBD_OPT_STARTTLS          = 5 // Unused
	NBD_OPT_INFO              = 6
	NBD_OPT_GO                = 7
	NBD_OPT_STRUCTURED_REPLY  = 8  // Unused
	NBD_OPT_LIST_META_CONTEXT = 9  // Unused
	NBD_OPT_SET_META_CONTEXT  = 10 // Unused

	// Option reply types. Error replies have bit 31 set.
	NBD_REP_ACK                 = 1
	NBD_REP_SERVER              = 2
	NBD_REP_INFO                = 3
	NBD_REP_META_CONTEXT        = 4 // Unused
	NBD_REP_ERR_UNSUP           = (1 << 31) + 1
	NBD_REP_ERR_POLICY          = (1 << 31) + 2 // Unused
	NBD_REP_ERR_INVALID         = (1 << 31) + 3
	NBD_REP_ERR_PLATFORM        = (1 << 31) + 4 // Unused
	NBD_REP_ERR_TLS_REQD        = (1 << 31) + 5 // Unused
	NBD_REP_ERR_UNKNOWN         = (1 << 31) + 6
	NBD_REP_ERR_SHUTDOWN        = (1 << 31) + 7 // Unused
	NBD_REP_ERR_BLOCK_SIZE_REQD = (1 << 31) + 8 // Unused
	NBD_REP_ERR_TOO_BIG         = (1 << 31) + 9 // Unused

	// Information types carried inside NBD_REP_INFO replies.
	NBD_INFO_EXPORT      = 0
	NBD_INFO_NAME        = 1 // Unused
	NBD_INFO_DESCRIPTION = 2 // Unused
	NBD_INFO_BLOCK_SIZE  = 3 // Unused
)
// nbdRequest is the fixed-size header of a transmission-phase request.
// It is decoded with encoding/binary, so the field order and widths below
// define the 28-byte big-endian wire layout and must not be rearranged.
type nbdRequest struct {
	Magic  uint32 // checked against NBD_REQUEST_MAGIC
	Flags  uint16 // command flags
	Type   uint16 // one of the NBD_CMD_* values
	Handle uint64 // opaque value copied into the matching reply
	Offset uint64 // byte offset into the export
	Length uint32 // number of bytes the command covers
}
// nbdReply is the header of a simple reply sent during the transmission
// phase. It is encoded with encoding/binary, so the field order and widths
// below define the 16-byte big-endian wire layout.
type nbdReply struct {
	Magic  uint32 // NBD_SIMPLE_REPLY_MAGIC
	Error  uint32 // 0 on success, otherwise an errno-style code
	Handle uint64 // copied from the request being answered
}
// Process-wide state shared by every client connection.

// storage is the in-memory backing store of the single export; main
// allocates it once at start-up.
var storage []byte

// size is the export size in bytes (len(storage) once main has run).
var size uint64

// mutex guards the contents of storage against concurrent access from
// different connections.
var mutex sync.RWMutex

// defaultExportName is the name of the only export; the empty string is
// the protocol's default export.
var defaultExportName string
// main parses the command-line flags, allocates the in-memory backing store
// and then accepts NBD clients, serving each one on its own goroutine.
//
// Flags:
//
//	-size  size of the block device in MiB (default 128)
//	-addr  address and port to listen on (default ":10809")
func main() {
	sizeMB := flag.Uint64("size", 128, "Size of the block device in MiB")
	listenAddr := flag.String("addr", ":10809", "Address and port to listen on")
	flag.Parse()

	// Largest MiB count whose byte size still fits in an int, which is the
	// limit for a slice length. Checking before multiplying also prevents the
	// uint64 product from silently wrapping around to a small value.
	const maxSizeMB = uint64(^uint(0)>>1) >> 20
	if *sizeMB == 0 {
		log.Fatal("Size cannot be zero")
	}
	if *sizeMB > maxSizeMB {
		log.Fatalf("Size %d MiB is too large (maximum %d MiB)", *sizeMB, maxSizeMB)
	}
	size = *sizeMB << 20
	storage = make([]byte, size)

	log.Printf("NBD server on %s, size %d MiB, export '%s'", *listenAddr, *sizeMB, defaultExportName)
	listener, err := net.Listen("tcp", *listenAddr)
	if err != nil {
		log.Fatalf("Failed to listen on %s: %v", *listenAddr, err)
	}
	defer listener.Close()

	for {
		conn, err := listener.Accept()
		if err != nil {
			// A closed listener never recovers; retrying would only spin
			// and flood the log.
			if errors.Is(err, net.ErrClosed) {
				log.Printf("Listener closed: %v", err)
				return
			}
			log.Printf("Accept failed: %v", err)
			continue
		}
		go handleConnection(conn)
	}
}
// handleConnection serves one client from connect to disconnect: it runs
// the handshake, then the transmission loop, and always closes conn before
// returning. Progress and failures are logged with the peer address as a
// prefix.
func handleConnection(conn net.Conn) {
	peer := conn.RemoteAddr().String()
	log.Printf("[%s] Connected", peer)
	defer func() {
		// Log first, then close, on every exit path.
		log.Printf("[%s] Disconnected", peer)
		conn.Close()
	}()

	negErr := handleNegotiation(conn, size, defaultExportName, peer)
	if negErr != nil {
		log.Printf("[%s] Handshake failed: %v", peer, negErr)
		return
	}
	log.Printf("[%s] Handshake successful", peer)

	txErr := handleTransmission(conn, peer)
	switch {
	case txErr == nil:
		// Clean disconnect requested by the client.
	case errors.Is(txErr, io.EOF), errors.Is(txErr, net.ErrClosed), errors.Is(txErr, syscall.EPIPE):
		// The peer simply went away; not worth a log line.
	default:
		log.Printf("[%s] Transmission error: %v", peer, txErr)
	}
}
// readBE decodes big-endian binary data from r into data, which must be a
// pointer to a fixed-size value (or a slice of such values) as required by
// binary.Read.
func readBE(r io.Reader, data any) error {
	return binary.Read(r, binary.BigEndian, data)
}
// writeBE encodes data in big-endian binary form and writes it to w. data
// must be a fixed-size value (or a slice or pointer to one) as required by
// binary.Write.
func writeBE(w io.Writer, data any) error {
	return binary.Write(w, binary.BigEndian, data)
}
func sendOptReply(c net.Conn, opt, repType uint32, data []byte) error {
dataLen := uint32(0); if data != nil { dataLen = uint32(len(data)) }
if writeBE(c, uint64(NBD_NEGOTIATION_REPLY_MAGIC)) != nil { return fmt.Errorf("write rep magic") }
if writeBE(c, opt) != nil { return fmt.Errorf("write rep opt") }
if writeBE(c, repType) != nil { return fmt.Errorf("write rep type") }
if writeBE(c, dataLen) != nil { return fmt.Errorf("write rep len") }
if dataLen > 0 { if _, err := c.Write(data); err != nil { return fmt.Errorf("write rep data: %w", err) } }
return nil
}
// sendSimpleOptReply sends an option reply that carries no payload, such as
// an acknowledgement or an error code. It returns whatever sendOptReply
// returns.
func sendSimpleOptReply(c net.Conn, opt, repType uint32) error {
	var noPayload []byte
	return sendOptReply(c, opt, repType, noPayload)
}
func handleNegotiation(conn net.Conn, expSize uint64, expName, remoteAddr string) error {
serverFlags := uint16(NBD_FLAG_HAS_FLAGS | NBD_FLAG_FIXED_NEWSTYLE | NBD_FLAG_SEND_FLUSH | NBD_FLAG_SEND_FUA | NBD_FLAG_SEND_TRIM | NBD_FLAG_SEND_WRITE_ZEROES)
if writeBE(conn, uint64(NBD_MAGIC)) != nil { return fmt.Errorf("write NBD_MAGIC failed") }
if writeBE(conn, uint64(NBD_OPTS_MAGIC)) != nil { return fmt.Errorf("write NBD_OPTS_MAGIC failed") }
if writeBE(conn, serverFlags) != nil { return fmt.Errorf("write server flags failed") }
var clientFlags uint32
if readBE(conn, &clientFlags) != nil { return fmt.Errorf("read client flags failed") }
if (serverFlags&NBD_FLAG_FIXED_NEWSTYLE) != 0 && (clientFlags&NBD_FLAG_C_FIXED_NEWSTYLE) == 0 {
return fmt.Errorf("client did not agree to fixed newstyle")
}
transmissionFlags := uint16(NBD_FLAG_SEND_FLUSH | NBD_FLAG_SEND_FUA | NBD_FLAG_SEND_TRIM | NBD_FLAG_SEND_WRITE_ZEROES)
for {
var clientOptMagic uint64; var optType, optLen uint32
if readBE(conn, &clientOptMagic) != nil { return fmt.Errorf("read option magic failed") }
if clientOptMagic != NBD_OPTS_MAGIC { _ = sendSimpleOptReply(conn, 0, NBD_REP_ERR_INVALID); return fmt.Errorf("invalid option magic 0x%x", clientOptMagic) }
if readBE(conn, &optType) != nil { return fmt.Errorf("read option type failed") }
if readBE(conn, &optLen) != nil { return fmt.Errorf("read option length failed") }
optData := make([]byte, optLen) // TODO: Protect against large optLen
if optLen > 0 { if _, err := io.ReadFull(conn, optData); err != nil { return fmt.Errorf("read option data (type %d, len %d) failed: %w", optType, optLen, err) } }
switch optType {
case NBD_OPT_EXPORT_NAME:
reqName := string(optData)
if reqName != expName { _ = sendSimpleOptReply(conn, optType, NBD_REP_ERR_UNKNOWN); continue }
if sendSimpleOptReply(conn, optType, NBD_REP_ACK) != nil { return fmt.Errorf("send export ack failed") }
if writeBE(conn, expSize) != nil { return fmt.Errorf("write export size failed") }
if writeBE(conn, transmissionFlags) != nil { return fmt.Errorf("write transmission flags failed") }
if _, err := conn.Write(make([]byte, 124)); err != nil { return fmt.Errorf("write reserved zeros failed: %w", err) }
return nil // Handshake complete
case NBD_OPT_LIST:
nameBytes, descBytes := []byte(expName), []byte("In-memory NBD")
replyData := make([]byte, 4+len(nameBytes)+4+len(descBytes))
binary.BigEndian.PutUint32(replyData[0:], uint32(len(nameBytes))); copy(replyData[4:], nameBytes)
offset := 4 + len(nameBytes)
binary.BigEndian.PutUint32(replyData[offset:], uint32(len(descBytes))); copy(replyData[offset+4:], descBytes)
if sendOptReply(conn, optType, NBD_REP_SERVER, replyData) != nil { return fmt.Errorf("send NBD_REP_SERVER failed") }
if sendSimpleOptReply(conn, optType, NBD_REP_ACK) != nil { return fmt.Errorf("send list ack failed") }
continue
case NBD_OPT_ABORT:
_ = sendSimpleOptReply(conn, optType, NBD_REP_ACK)
return fmt.Errorf("client aborted handshake")
case NBD_OPT_INFO, NBD_OPT_GO:
if optLen < 6 { _ = sendSimpleOptReply(conn, optType, NBD_REP_ERR_INVALID); continue } // 4(name_len) + 2(num_info_reqs)
nameLen := binary.BigEndian.Uint32(optData[0:4])
if uint32(len(optData)) < 4+nameLen+2 { _ = sendSimpleOptReply(conn, optType, NBD_REP_ERR_INVALID); continue }
reqName := string(optData[4 : 4+nameLen])
// numInfoReqs := binary.BigEndian.Uint16(optData[4+nameLen:]) // Ignored for minimal server
if reqName != expName { _ = sendSimpleOptReply(conn, optType, NBD_REP_ERR_UNKNOWN); continue }
replyData := make([]byte, 2+8+2) // type(u16)+size(u64)+flags(u16)
binary.BigEndian.PutUint16(replyData[0:], uint16(NBD_INFO_EXPORT))
binary.BigEndian.PutUint64(replyData[2:], expSize)
binary.BigEndian.PutUint16(replyData[10:], transmissionFlags)
if sendOptReply(conn, optType, NBD_REP_INFO, replyData) != nil { return fmt.Errorf("send NBD_REP_INFO failed") }
if sendSimpleOptReply(conn, optType, NBD_REP_ACK) != nil { return fmt.Errorf("send info/go ack failed") }
if optType == NBD_OPT_GO { return nil } // Handshake complete
continue
default:
log.Printf("[%s] Unsupported option: Type=%d", remoteAddr, optType)
if sendSimpleOptReply(conn, optType, NBD_REP_ERR_UNSUP) != nil { return fmt.Errorf("send NBD_REP_ERR_UNSUP failed") }
continue
}
}
}
// sendErrorReply writes a simple reply for the request identified by handle
// with its error field set to the numeric value of ecode. It returns the
// error from the write, if any.
func sendErrorReply(conn net.Conn, handle uint64, ecode syscall.Errno) error {
	var failure nbdReply
	failure.Magic = NBD_SIMPLE_REPLY_MAGIC
	failure.Handle = handle
	failure.Error = uint32(ecode)
	return writeBE(conn, failure)
}
// handleTransmission serves transmission-phase requests on conn until the
// client disconnects or a fatal error occurs.
//
// It returns nil after NBD_CMD_DISC, io.EOF when the peer closes the
// connection between requests, and a descriptive error otherwise. Requests
// that are zero-length or fall outside the export are answered with EINVAL
// and end the session; an unknown command does the same because the stream
// position can no longer be trusted. remoteAddr is currently unused.
//
// Data moves between the socket and the shared storage through a
// per-connection staging buffer, so the global mutex is only ever held for
// an in-memory copy and never across a blocking network read or write. A
// stalled client therefore cannot block other connections.
func handleTransmission(conn net.Conn, remoteAddr string) error {
	// Upper bound on how many bytes are staged per copy step.
	const chunkSize = 1 << 20

	var (
		req     nbdRequest
		reply   = nbdReply{Magic: NBD_SIMPLE_REPLY_MAGIC}
		scratch = make([]byte, chunkSize)
	)

	for {
		if err := readBE(conn, &req); err != nil {
			if errors.Is(err, io.EOF) {
				return io.EOF
			}
			return fmt.Errorf("read request header failed: %w", err)
		}
		if req.Magic != NBD_REQUEST_MAGIC {
			return fmt.Errorf("invalid request magic 0x%x", req.Magic) // Cannot send reply reliably
		}
		reply.Handle, reply.Error = req.Handle, 0 // Prepare reply for success case

		// Written so that neither side can wrap around: a huge Offset must
		// not make Offset+Length look small and slip past the check.
		length := uint64(req.Length)
		outOfBounds := length == 0 || req.Offset > size || length > size-req.Offset
		end := req.Offset + length // only used when !outOfBounds

		switch req.Type {
		case NBD_CMD_READ:
			if outOfBounds {
				_ = sendErrorReply(conn, req.Handle, syscall.EINVAL) // best effort; closing anyway
				return fmt.Errorf("read out of bounds")
			}
			if err := writeBE(conn, reply); err != nil {
				return fmt.Errorf("send read reply header: %w", err)
			}
			for off := req.Offset; off < end; {
				n := min(end-off, chunkSize)
				mutex.RLock()
				copy(scratch[:n], storage[off:off+n])
				mutex.RUnlock()
				if _, err := conn.Write(scratch[:n]); err != nil {
					return fmt.Errorf("send read data: %w", err)
				}
				off += n
			}

		case NBD_CMD_WRITE:
			if outOfBounds {
				// Drain the payload so the error reply is not sent into a
				// stream the client is still writing to.
				_, dErr := io.CopyN(io.Discard, conn, int64(req.Length))
				sErr := sendErrorReply(conn, req.Handle, syscall.EINVAL)
				if sErr != nil {
					return fmt.Errorf("send write bounds error reply: %w", sErr)
				}
				if dErr != nil && !errors.Is(dErr, io.EOF) {
					return fmt.Errorf("discard data failed: %w", dErr)
				}
				return fmt.Errorf("write out of bounds")
			}
			for off := req.Offset; off < end; {
				n := min(end-off, chunkSize)
				if _, err := io.ReadFull(conn, scratch[:n]); err != nil {
					_ = sendErrorReply(conn, req.Handle, syscall.EIO) // best effort; closing anyway
					return fmt.Errorf("read write data: %w", err)
				}
				mutex.Lock()
				copy(storage[off:off+n], scratch[:n])
				mutex.Unlock()
				off += n
			}
			if err := writeBE(conn, reply); err != nil {
				return fmt.Errorf("send write reply header: %w", err)
			}

		case NBD_CMD_DISC:
			return nil // No reply needed, close gracefully

		case NBD_CMD_FLUSH:
			// Nothing to flush for an in-memory store; just acknowledge.
			if err := writeBE(conn, reply); err != nil {
				return fmt.Errorf("send flush reply: %w", err)
			}

		case NBD_CMD_TRIM, NBD_CMD_WRITE_ZEROES:
			if outOfBounds {
				_ = sendErrorReply(conn, req.Handle, syscall.EINVAL) // best effort; closing anyway
				return fmt.Errorf("trim/zero out of bounds")
			}
			mutex.Lock()
			clear(storage[req.Offset:end])
			mutex.Unlock()
			if err := writeBE(conn, reply); err != nil {
				return fmt.Errorf("send trim/zero reply: %w", err)
			}

		default:
			_ = sendErrorReply(conn, req.Handle, syscall.EINVAL) // best effort; closing anyway
			return fmt.Errorf("unknown command type %d", req.Type)
		}
	}
}
@cjbrigato
Copy link
Author

Extremely quick NBD server implementation I use for PoCs.
Exports only the default export, with a byte-slice backend. Configurable size and listening address. That's it. Nothing more. Just enough for the NewStyle protocol.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment