Skip to content

Instantly share code, notes, and snippets.

@CarsonSlovoka
Last active September 5, 2024 21:19
Show Gist options
  • Save CarsonSlovoka/e2b8e364ac625ee03292e2b8f94e4e10 to your computer and use it in GitHub Desktop.
Save CarsonSlovoka/e2b8e364ac625ee03292e2b8f94e4e10 to your computer and use it in GitHub Desktop.
package main
import (
"bytes"
"crypto"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"encoding/base64"
"encoding/pem"
"fmt"
"golang.org/x/crypto/ssh"
"io"
"log"
"net"
"os"
"os/exec"
"path/filepath"
"strconv"
"strings"
"time"
)
// genServerKeys generates a fresh 2048-bit RSA host key pair and stores it
// under the keys/ directory:
//
//   - keys/server_rsa     – private key, PEM encoded (mode 0600)
//   - keys/server_rsa.pub – public key in authorized_keys format (mode 0644)
//
// The private key PEM and the base64 of the public key blob are also printed
// to stdout. Every failure is fatal: the server cannot start without a
// matching key pair on disk, so a write error must not be ignored.
func genServerKeys() {
	privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		log.Fatal(err)
	}

	// An empty comment string is passed for the key.
	pemBlock, err := ssh.MarshalPrivateKey(crypto.PrivateKey(privateKey), "")
	if err != nil {
		log.Fatal(err)
	}
	bsPrivateKeyPem := pem.EncodeToMemory(pemBlock)
	// NOTE(review): this prints the private host key to stdout; kept for
	// backward compatibility, but consider removing it outside of demos.
	fmt.Printf("server privKey:\n%s\n", string(bsPrivateKeyPem))

	// The directory may not exist on first run; without it both writes fail.
	if err = os.MkdirAll("keys", 0700); err != nil {
		log.Fatal(err)
	}
	// A private key must not be world-readable (previously os.ModePerm).
	if err = os.WriteFile("keys/server_rsa", bsPrivateKeyPem, 0600); err != nil {
		log.Fatal(err)
	}

	sshPublicKey, err := ssh.NewPublicKey(&privateKey.PublicKey)
	if err != nil {
		log.Fatal(err)
	}
	publicKeyBase64 := base64.StdEncoding.EncodeToString(sshPublicKey.Marshal())
	fmt.Println(publicKeyBase64)

	bsPublicKeyMarshal := ssh.MarshalAuthorizedKey(sshPublicKey)
	if err = os.WriteFile("keys/server_rsa.pub", bsPublicKeyMarshal, 0644); err != nil {
		log.Fatal(err)
	}
}
// init makes sure a host key pair exists on disk (generating one when either
// file is missing) and prints the SHA256 fingerprint of the public host key.
func init() {
	needKeys := false
	for _, keyFile := range [...]string{"./keys/server_rsa", "./keys/server_rsa.pub"} {
		if _, statErr := os.Stat(keyFile); statErr != nil {
			needKeys = true
		}
	}
	if needKeys {
		genServerKeys()
	}

	pubBytes, err := os.ReadFile("keys/server_rsa.pub")
	if err != nil {
		log.Fatal(err)
	}
	hostKey, _, _, _, err := ssh.ParseAuthorizedKey(pubBytes)
	if err != nil {
		log.Fatal(err)
	}

	// The fingerprint is printed as "SHA256:" followed by the unpadded
	// base64 of the SHA-256 digest of the marshalled public key.
	digest := sha256.Sum256(hostKey.Marshal())
	fmt.Printf("RSA key fingerprint is SHA256:%s\n", base64.RawStdEncoding.EncodeToString(digest[:]))
}
// main loads the host private key from keys/server_rsa, configures password
// and public-key authentication against the users table, and serves SSH
// connections on TCP port 12345, one goroutine per connection.
func main() {
	bsPrivateKeyPem, err := os.ReadFile("keys/server_rsa")
	if err != nil {
		log.Fatal(err)
	}

	// Configure the SSH server.
	config := &ssh.ServerConfig{
		PasswordCallback: func(conn ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) {
			_, ok := users[conn.User()]
			if !ok {
				return nil, fmt.Errorf("unknown user %q", conn.User())
			}
			// NOTE(review): a single hard-coded password is accepted for
			// every known user; replace with per-user credentials before
			// using this outside of a demo.
			if string(pass) == "xxx" {
				return nil, nil
			}
			return nil, fmt.Errorf("invalid authtication")
		},
		PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
			user, ok := users[conn.User()]
			if !ok {
				return nil, fmt.Errorf("unknown user %q", conn.User())
			}
			// Both the key type (ssh-rsa, ssh-ed25519, ...) and the
			// marshalled key bytes must match the registered key.
			if user.IPublicKey.Type() != key.Type() ||
				!bytes.Equal(user.IPublicKey.Marshal(), key.Marshal()) {
				return nil, fmt.Errorf("invalid key for %q", conn.User())
			}
			return &ssh.Permissions{
				Extensions: map[string]string{
					"user": user.Name,
				},
			}, nil
		},
	}

	signer, err := ssh.ParsePrivateKey(bsPrivateKeyPem)
	if err != nil {
		log.Fatal(err)
	}
	config.AddHostKey(signer)

	// The banner is derived from the address actually listened on; it used
	// to announce port 2222 while the listener was bound to 12345.
	const listenAddr = ":12345"
	listener, err := net.Listen("tcp", listenAddr)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("SSH server listening on %s...\n", listenAddr)

	for {
		conn, err := listener.Accept()
		if err != nil {
			log.Printf("Failed to accept connection: %v", err)
			continue
		}
		// Only log once a connection has really been accepted.
		log.Println("收到了一個連線")
		go handleConnection(conn, config)
	}
}
// handleConnection performs the SSH handshake on conn and then serves every
// "session" channel the client opens, each in its own goroutine. Channels of
// any other type are rejected and global requests are discarded.
func handleConnection(conn net.Conn, config *ssh.ServerConfig) {
	serverConn, newChannels, globalReqs, err := ssh.NewServerConn(conn, config)
	if err != nil {
		log.Printf("Failed to handshake: %v", err)
		return
	}
	defer func() {
		_ = serverConn.Close()
	}()

	log.Printf("New SSH connection from %s (%s)", serverConn.RemoteAddr(), serverConn.ClientVersion())

	go ssh.DiscardRequests(globalReqs)

	for candidate := range newChannels {
		if candidate.ChannelType() == "session" {
			ch, chReqs, acceptErr := candidate.Accept()
			if acceptErr != nil {
				log.Printf("Could not accept channel: %v", acceptErr)
				continue
			}
			go handleChannel(ch, chReqs)
			continue
		}
		_ = candidate.Reject(ssh.UnknownChannelType, "unknown channel type")
	}
}
// handleChannel serves the requests of one session channel.
//
// The first "exec" or "shell" request decides what the channel does; after it
// has been served an exit-status is sent and the channel is closed. Requests
// of any other type are answered with a failure so that a client waiting for
// a reply does not hang.
//
// For "exec" the request is acknowledged before the command runs, because a
// client may wait for that reply before it starts sending the data the
// command needs on stdin (scp uploads, git push). The exit status is 1 when
// the handler returned an error and 0 otherwise.
func handleChannel(
	channel ssh.Channel, // can be converted to an ssh.Session
	requests <-chan *ssh.Request,
) {
	defer func() {
		log.Println("Connection closed")
		_ = channel.Close()
	}()

	for req := range requests {
		switch req.Type {
		case "exec":
			// The payload is a 4-byte length prefix followed by the command.
			if len(req.Payload) < 4 {
				_ = req.Reply(false, nil)
				continue
			}
			cmd := string(req.Payload[4:])
			log.Printf("收到的exec指令%q", cmd)
			_ = req.Reply(true, nil)

			var err error
			switch strings.Split(cmd, " ")[0] {
			case "scp":
				handleSCPSession(channel, cmd)
			case "git-upload-pack": // git fetch // git-upload-pack '/xxx/qoo.git'
				err = handleGitUploadPack(channel, cmd)
			case "git-receive-pack": // git push // git-receive-pack '/xxx/qoo.git'
				err = handleGitReceivePack(channel, cmd)
			default: // ordinary command
				err = executeCommand(channel, cmd) // ssh -p 12345 [email protected] hello
			}

			// exit-status carries a big-endian uint32.
			status := []byte{0, 0, 0, 0}
			if err != nil {
				log.Println(err)
				status[3] = 1
			}
			_, _ = channel.SendRequest("exit-status", false, status)
			return
		case "shell":
			_ = req.Reply(true, nil)
			_, _ = io.WriteString(channel, "Welcome to the Go SSH server!\r\n")
			_, _ = io.WriteString(channel, "Type 'exit' to disconnect.\r\n")
			handleShell(channel)
			_, _ = channel.SendRequest("exit-status", false, []byte{0, 0, 0, 0})
			return
		default:
			// Unsupported request (pty-req, env, ...): Reply is a no-op when
			// the client did not ask for a reply.
			_ = req.Reply(false, nil)
		}
	}
}
// executeCommand runs one of the built-in commands ("list", "hello", "time")
// and writes its output to channel. Any other command yields an error and
// nothing is written.
func executeCommand(channel ssh.Channel, cmd string) error {
	var reply string
	switch {
	case cmd == "list":
		reply = "Available commands: list, hello, time\r\n"
	case cmd == "hello":
		reply = "Hello, SSH user!\r\n"
	case cmd == "time":
		reply = "Current server time: " + time.Now().Format(time.RFC3339) + "\r\n"
	default:
		return fmt.Errorf("unknown command: %s", cmd)
	}
	// Write failures are deliberately ignored: the reply is best effort.
	_, _ = io.WriteString(channel, reply)
	return nil
}
// handleShell runs a minimal interactive prompt on channel: it reads one
// word per line, executes it with executeCommand and stops when the user
// types "exit", when the client closes its side of the channel, or after
// three consecutive unknown commands.
func handleShell(channel ssh.Channel) {
	try := 0
	for {
		_, _ = io.WriteString(channel, "> ")

		var cmd string
		_, err := fmt.Fscanf(channel, "%s\n", &cmd)
		if err == io.EOF || err == io.ErrUnexpectedEOF {
			// The client is gone; do not keep prompting a dead channel.
			// Other scan errors (e.g. an empty line) fall through and are
			// treated as an unknown command, as before.
			return
		}

		cmd = strings.TrimSpace(cmd)
		if cmd == "exit" {
			return
		}

		if err := executeCommand(channel, cmd); err != nil {
			_, _ = io.WriteString(channel, fmt.Sprintf("Unknown command: %s\r\n", err))
			try++
			if try >= 3 {
				_, _ = io.WriteString(channel, "過多的嘗試,終止shell")
				return
			}
		} else {
			try = 0
		}
	}
}
// handleSCPSession serves one "scp ..." exec request. The command must have
// at least three fields; its last field is the path. The option fields in
// between select the mode:
//
//	-f  the client downloads: this side sends the file or directory
//	-t  the client uploads: this side receives
//	-r  recursive (directory) transfer
//
// Options are parsed per field (combined forms such as "-rt" are accepted)
// rather than by substring search over the whole command, so a path such as
// "/tmp/my-file" is no longer mistaken for the -f option.
//
// See https://www.mkssoftware.com/docs/man1/scp.1.asp
func handleSCPSession(channel ssh.Channel, command string) {
	parts := strings.Fields(command)
	if len(parts) < 3 {
		_, _ = fmt.Fprintf(channel, "scp: invalid command\n")
		return
	}

	var isFMode, isTMode, isRecursive bool
	for _, arg := range parts[1 : len(parts)-1] {
		if arg == "--" { // explicit end of options
			break
		}
		if !strings.HasPrefix(arg, "-") {
			continue
		}
		for _, flag := range arg[1:] {
			switch flag {
			case 'f':
				isFMode = true
			case 't':
				isTMode = true
			case 'r':
				isRecursive = true
			}
		}
	}

	path := parts[len(parts)-1] // the last field is the path

	var err error
	switch {
	case isFMode: // -f: download
		if isRecursive {
			handleDownloadDir(channel,
				filepath.Dir(path), filepath.Base(path),
			) // scp -r -f C:/../xxxDir
		} else {
			handleDownloadFile(channel, path)
		}
	case isTMode: // -t: upload
		if isRecursive {
			err = handleUploadDir(channel, path)
		} else {
			err = handleUploadFile(channel, path, nil)
		}
	default:
		_, _ = fmt.Fprintf(channel, "Unsupported SCP option: %s\n", parts[1])
	}

	if err != nil {
		log.Printf("error %s\n", err)
		_, _ = fmt.Fprintf(channel, "%s\n", err)
	}
}
// handleDownloadFile sends the regular file at path to the scp client as one
// "C<mode> <size> <name>\n" record followed by the file bytes and a
// terminating zero byte.
//
// Problems are reported to the client as scp error records (a 0x01 byte, the
// message and a newline) so that they are not mistaken for a control record.
// Anything that is not a regular file is refused. If the copy fails part way
// through, the terminating byte is not sent, because the transfer did not
// complete.
//
// NOTE(review): the acknowledgement bytes the client sends are never read
// here; confirm whether that is intentional.
func handleDownloadFile(channel ssh.Channel, path string) {
	log.Printf("Attempting to download file: %s", path)

	file, err := os.Open(path)
	if err != nil {
		_, _ = fmt.Fprintf(channel, "\x01Failed to open file: %v\n", err)
		return
	}
	defer func() {
		_ = file.Close()
	}()

	stat, err := file.Stat()
	if err != nil {
		_, _ = fmt.Fprintf(channel, "\x01Failed to stat file: %v\n", err)
		return
	}
	if !stat.Mode().IsRegular() {
		// e.g. a directory requested without -r: its "size" is meaningless
		// and copying from it fails after the header has been sent.
		_, _ = fmt.Fprintf(channel, "\x01%s: not a regular file\n", path)
		return
	}

	log.Printf("Sending file info: mode=%o, size=%d, name=%s",
		stat.Mode().Perm(), stat.Size(), filepath.Base(path)) // %o prints octal

	// Header example: C0700 128123 my.file
	if _, err = fmt.Fprintf(channel, "C%04o %d %s\n",
		stat.Mode().Perm(),
		stat.Size(),
		filepath.Base(path),
	); err != nil {
		log.Printf("Error sending file header: %v", err)
		return
	}

	bytesCopied, err := io.Copy(channel, file)
	log.Printf("Bytes copied: %d", bytesCopied)
	if err != nil {
		log.Printf("Error during file transfer: %v", err)
		return
	}

	_, _ = channel.Write([]byte{0}) // Transfer complete
	log.Printf("Download complete")
}
// handleDownloadDir sends the directory basePath/relativePath to the scp
// client, recursing into sub-directories.
//
// A directory is framed as a "D<mode> 0 <name>\n" record, the records of its
// entries, and a closing "E\n" record. The closing record was missing: without
// it the client stays inside the directory, so everything sent after a
// sub-directory ends up inside that sub-directory.
//
// Failures to read or stat the directory are reported to the client as scp
// error records (0x01 byte, message, newline). Entries whose info cannot be
// read are logged and skipped.
func handleDownloadDir(channel ssh.Channel,
	basePath, relativePath string,
) {
	fullPath := filepath.Join(basePath, relativePath)
	log.Printf("Attempting to download directory: %s", fullPath)

	entries, err := os.ReadDir(fullPath)
	if err != nil {
		_, _ = fmt.Fprintf(channel, "\x01Failed to read directory: %v\n", err)
		return
	}

	// Announce the directory first: D%04o %d %s
	stat, err := os.Stat(fullPath)
	if err != nil {
		_, _ = fmt.Fprintf(channel, "\x01Failed to stat directory: %v\n", err)
		return
	}
	_, _ = fmt.Fprintf(channel, "D%04o 0 %s\n", stat.Mode().Perm(),
		filepath.Base(relativePath),
	)

	for _, entry := range entries {
		entryPath := filepath.Join(relativePath, entry.Name())
		fullEntryPath := filepath.Join(basePath, entryPath)

		entryInfo, err := entry.Info()
		if err != nil {
			log.Printf("Failed to get info for %s: %v", fullEntryPath, err)
			continue
		}

		if entryInfo.IsDir() {
			handleDownloadDir(channel, basePath, entryPath)
		} else {
			handleDownloadFile(channel, fullEntryPath)
		}
	}

	// End of this directory: the client returns to the parent directory.
	_, _ = fmt.Fprintf(channel, "E\n")

	log.Printf("Download of directory %q complete", fullPath)
}
// handleUploadFile receives one file from an scp client and stores it at
// path.
//
// When header is nil the function speaks the whole exchange itself: it sends
// the initial "ready" byte, reads and parses the "C<mode> <size> <name>"
// record and acknowledges it. When the caller already read the record and
// passes it in, the single zero byte written at the top is the
// acknowledgement of that record, and no second acknowledgement is sent
// before the data (an extra one lets the client run one step ahead of the
// server).
//
// If the record arrived together with file data (header.Content), at most
// header.Size bytes of it are written to the file; whatever is still missing
// is read from the channel. Any bytes in header.Content beyond the file data
// and its completion byte are dropped.
//
// The destination is truncated before writing, and missing parent
// directories are created with mode 0755 (the file's own mode usually lacks
// the execute bit a directory needs).
//
// It returns an error when the header cannot be read or parsed, the file
// cannot be created or written, or the client does not finish the transfer
// with a zero byte.
func handleUploadFile(channel ssh.Channel, path string, header *Header) error {
	// header == nil: "ready to receive"; otherwise: ack of the caller's record.
	_, _ = channel.Write([]byte{0})

	headerReadHere := header == nil
	if headerReadHere {
		var buf [1024]byte
		n, err := channel.Read(buf[:])
		if err != nil {
			return fmt.Errorf("failed to read file info: %w", err)
		}
		header, err = parseHeader(buf[:n], "")
		if err != nil {
			return err
		}
		if header == nil { // parseHeader returns nil, nil for an empty read
			return fmt.Errorf("failed to read file info: empty header")
		}
	}

	if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
		return err
	}

	file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, header.Mode)
	if err != nil {
		return fmt.Errorf("failed to create file: %w", err)
	}
	defer func() {
		_ = file.Close()
	}()

	// awaitCompletion consumes the zero byte the client sends after the data.
	awaitCompletion := func() error {
		status := make([]byte, 1)
		_, err := channel.Read(status)
		if err == nil && status[0] == 0 {
			return nil
		}
		log.Printf("Failed to receive transfer completion signal")
		_, _ = fmt.Fprintf(channel.Stderr(), "Failed to receive transfer completion signal\n")
		if err != nil {
			return fmt.Errorf("failed to receive transfer completion signal: %w", err)
		}
		return fmt.Errorf("failed to receive transfer completion signal: got status %d", status[0])
	}

	if len(header.Content) > 0 {
		inline := header.Content
		// More inline bytes than the declared size means the completion
		// byte (and possibly more) was read together with the data.
		completionInline := int64(len(inline)) > header.Size
		if completionInline {
			inline = inline[:header.Size]
		}
		if _, err = io.WriteString(file, inline); err != nil {
			return err
		}
		if !completionInline {
			if _, err = io.CopyN(file, channel, header.Size-int64(len(inline))); err != nil {
				return fmt.Errorf("failed to write file: %w", err)
			}
			if err = awaitCompletion(); err != nil {
				return err
			}
		}
		_, _ = channel.Write([]byte{0})
		return nil
	}

	if headerReadHere {
		_, _ = channel.Write([]byte{0}) // ack the record read above
	}

	if _, err = io.CopyN(file, channel, header.Size); err != nil {
		return fmt.Errorf("failed to write file: %w", err)
	}
	if err = awaitCompletion(); err != nil {
		return err
	}

	_, _ = channel.Write([]byte{0})
	log.Printf("Upload complete")
	return nil
}
// Header is a parsed scp control record.
type Header struct {
	// Tag is 'C' (67) for a file record or 'D' (68) for a directory record.
	Tag byte
	// Mode is the permission field of the record, parsed as octal.
	Mode os.FileMode
	// Size is the size field of the record, parsed as decimal.
	Size int64
	// RelativePath is the name field; for a directory record that was read
	// together with the records following it, it is the joined path of the
	// innermost record (see parseHeader).
	RelativePath string
	// Content holds whatever followed the first newline of a 'C' record in
	// the same buffer.
	Content string
}

// String renders the header as the control line "<tag><mode> <size> <path>\n"
// with the mode in octal, padded to at least four digits.
func (h *Header) String() string {
	return fmt.Sprintf("%c%04o %d %s\n", h.Tag, h.Mode, h.Size, h.RelativePath)
}

// parseHeader parses one scp control record from buf.
//
// Supported records:
//
//	C<mode> <size> <name>\n[content]
//	D<mode> <size> <name>\n[following records]
//
// For a 'C' record everything after the first newline is returned in
// Header.Content. For a 'D' record followed by further data in the same
// buffer, that data is parsed recursively with basePath/<name> as the new
// base, and the result describes the innermost record with RelativePath set
// to its path joined onto that base; the directory's own mode and size are
// not parsed in that case.
//
// An empty buf yields (nil, nil). A record without a trailing newline is
// accepted and treated as having nothing after it (this used to panic with an
// index out of range). An error is returned for an unknown tag, a record with
// fewer than three space-separated fields, a mode that is not octal, or a
// size that is not a non-negative decimal number.
func parseHeader(buf []byte, basePath string) (*Header, error) {
	if len(buf) == 0 {
		return nil, nil
	}

	tag := buf[0]
	if tag != 'C' && tag != 'D' {
		return nil, fmt.Errorf("invalid header tag: %c", tag)
	}

	parts := strings.SplitN(string(buf), " ", 3)
	if len(parts) != 3 {
		if tag == 'C' {
			return nil, fmt.Errorf("invalid format not 'C0xxx size name' or 'C0xxx size name\\Ncontent'")
		}
		return nil, fmt.Errorf("invalid format not 'D0xxx size name'")
	}

	// Split the third field at the first newline; the newline is optional.
	name, rest := parts[2], ""
	if i := strings.IndexByte(parts[2], '\n'); i >= 0 {
		name, rest = parts[2][:i], parts[2][i+1:]
	}

	if tag == 'D' && len(rest) > 0 {
		basePath = filepath.Join(basePath, name)
		inner, err := parseHeader([]byte(rest), basePath)
		if err != nil {
			return nil, err
		}
		sub := *inner
		sub.RelativePath = filepath.Join(basePath, inner.RelativePath)
		return &sub, nil
	}

	mode, err := strconv.ParseUint(parts[0][1:], 8, 32)
	if err != nil {
		return nil, err
	}
	size, err := strconv.ParseInt(parts[1], 10, 64)
	if err != nil {
		return nil, err
	}
	if size < 0 {
		return nil, fmt.Errorf("invalid header size: %d", size)
	}

	header := Header{
		Tag:          tag,
		Mode:         os.FileMode(mode),
		Size:         size,
		RelativePath: name,
	}
	if tag == 'C' {
		header.Content = rest
	}
	return &header, nil
}
// handleUploadDir receives a directory tree from an scp client and recreates
// it below path.
//
// It sends the initial "ready" byte, reads the first record, creates
// path/<name> with the record's mode, acknowledges the record and then
// processes the directory's entries until the matching "E" record arrives.
//
// It returns an error when a record cannot be read or parsed, a directory or
// file cannot be created, or the stream ends before the directory is closed.
// A failure to create the top directory is returned as an error instead of
// being written to the channel and reported as success.
//
// NOTE(review): the first record is assumed to be a directory ('D') record;
// a 'C' record at this point is not handled specially.
func handleUploadDir(channel ssh.Channel,
	path string,
) error {
	_, _ = channel.Write([]byte{0}) // ready to receive

	var buf [1024]byte
	n, err := channel.Read(buf[:])
	if err != nil {
		return fmt.Errorf("failed to read directory info: %w", err)
	}

	header, err := parseHeader(buf[:n], "")
	if err != nil {
		return err
	}
	if header == nil { // parseHeader returns nil, nil for an empty read
		return fmt.Errorf("failed to read directory info: empty header")
	}

	dirName := header.RelativePath
	if err = os.MkdirAll(filepath.Join(path, dirName), header.Mode); err != nil {
		return fmt.Errorf("failed to create directory: %w", err)
	}
	_, _ = channel.Write([]byte{0}) // ack the directory record

	return receiveDirEntries(channel, path, dirName)
}

// receiveDirEntries processes the records that belong to the already created
// directory path/dirName until its closing "E" record:
//
//   - 'C' records are handed to handleUploadFile, which acknowledges them;
//   - 'D' records create the sub-directory, are acknowledged and recursed
//     into (previously the recursion started over and mistook the record
//     after the 'D' for the directory's own header);
//   - the 'E' record is acknowledged and ends the directory successfully
//     (previously it was rejected as an invalid header tag).
//
// NOTE(review): when several records arrive in one read, parseHeader folds a
// leading 'D' record into the returned path, which already contains dirName;
// that case is passed through unchanged here and still needs a proper fix.
func receiveDirEntries(channel ssh.Channel, path, dirName string) error {
	var buf [1024]byte
	for {
		n, err := channel.Read(buf[:])
		if n == 0 {
			// The stream ended (or failed) before the directory was closed.
			if err != nil {
				return fmt.Errorf("failed to receive transfer completion signal: %w", err)
			}
			return fmt.Errorf("failed to receive transfer completion signal")
		}

		if buf[0] == 'E' {
			_, _ = channel.Write([]byte{0})
			return nil
		}

		header, err := parseHeader(buf[:n], dirName)
		if err != nil {
			return err
		}

		switch header.Tag {
		case 'C':
			if err = handleUploadFile(channel, filepath.Join(path, dirName, header.RelativePath), header); err != nil {
				return err
			}
		case 'D':
			subDir := filepath.Join(dirName, header.RelativePath)
			if err = os.MkdirAll(filepath.Join(path, subDir), header.Mode); err != nil {
				return fmt.Errorf("failed to create directory: %w", err)
			}
			_, _ = channel.Write([]byte{0}) // ack the directory record
			if err = receiveDirEntries(channel, path, subDir); err != nil {
				return err
			}
		default:
			return fmt.Errorf("unknown header Tag: %s", string(header.Tag))
		}
	}
}
// resolvePath strips every leading and trailing single quote from p, e.g.
// '/github/userXX/qoo.git' becomes /github/userXX/qoo.git. Quotes inside the
// string are kept.
func resolvePath(p string) string {
	isQuote := func(r rune) bool { return r == '\'' }
	return strings.TrimFunc(p, isQuote)
}
// handleGitUploadPack serves "git-upload-pack '<repo>'" (git fetch/clone) by
// running "git upload-pack" on the repository and wiring it to the channel:
// the process's stdout goes to the channel, its stderr to the channel's
// stderr, and the channel's data is copied to its stdin.
//
// The repository path comes from the client, so it is anchored at "/" and
// cleaned before it is made relative to the working directory. That removes
// any ".." elements that would otherwise escape the working directory, and
// the resulting argument always starts with "." so it cannot be read as a
// git option.
//
// It returns an error when the command is malformed or git cannot be started
// or exits unsuccessfully.
func handleGitUploadPack(channel ssh.Channel, cmd string) error {
	parts := strings.Fields(cmd)
	if len(parts) < 2 {
		return fmt.Errorf("invalid git-upload-pack command")
	}

	repoPath := "." + filepath.Clean("/"+resolvePath(parts[1]))

	gitCmd := exec.Command("git", "upload-pack", repoPath)
	gitCmd.Stdout = channel
	gitCmd.Stderr = channel.Stderr()
	stdin, err := gitCmd.StdinPipe()
	if err != nil {
		return fmt.Errorf("error creating stdin pipe: %w", err)
	}

	if err = gitCmd.Start(); err != nil {
		return fmt.Errorf("error starting git-upload-pack: %w", err)
	}

	// Feed the client's data to git in the background; the copy ends when
	// the channel reports EOF or is closed.
	go func() {
		_, _ = io.Copy(stdin, channel)
		_ = stdin.Close()
	}()

	if err = gitCmd.Wait(); err != nil {
		return fmt.Errorf("error waiting for git-upload-pack: %w", err)
	}
	return nil
}
// handleGitReceivePack serves "git-receive-pack '<repo>'" (git push) by
// running "git receive-pack" on the repository and wiring it to the channel:
// the process's stdout goes to the channel, its stderr to the channel's
// stderr, and the channel's data is copied to its stdin.
//
// The repository path comes from the client, so it is anchored at "/" and
// cleaned before it is made relative to the working directory. That removes
// any ".." elements that would otherwise escape the working directory, and
// the resulting argument always starts with "." so it cannot be read as a
// git option.
//
// It returns an error when the command is malformed or git cannot be started
// or exits unsuccessfully.
func handleGitReceivePack(channel ssh.Channel, cmd string) error {
	parts := strings.Fields(cmd)
	if len(parts) < 2 {
		return fmt.Errorf("invalid git-receive-pack command")
	}

	repoPath := "." + filepath.Clean("/"+resolvePath(parts[1]))

	gitCmd := exec.Command("git", "receive-pack", repoPath)
	gitCmd.Stdout = channel
	gitCmd.Stderr = channel.Stderr()
	stdin, err := gitCmd.StdinPipe()
	if err != nil {
		return fmt.Errorf("error creating stdin pipe: %w", err)
	}

	if err = gitCmd.Start(); err != nil {
		return fmt.Errorf("error starting git-receive-pack: %w", err)
	}

	// Feed the client's data to git in the background; the copy ends when
	// the channel reports EOF or is closed.
	go func() {
		_, _ = io.Copy(stdin, channel)
		_ = stdin.Close()
	}()

	if err = gitCmd.Wait(); err != nil {
		return fmt.Errorf("error waiting for git-receive-pack: %w", err)
	}
	return nil
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment