Last active
September 5, 2024 21:19
-
-
Save CarsonSlovoka/e2b8e364ac625ee03292e2b8f94e4e10 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"bytes" | |
"crypto" | |
"crypto/rand" | |
"crypto/rsa" | |
"crypto/sha256" | |
"encoding/base64" | |
"encoding/pem" | |
"fmt" | |
"golang.org/x/crypto/ssh" | |
"io" | |
"log" | |
"net" | |
"os" | |
"os/exec" | |
"path/filepath" | |
"strconv" | |
"strings" | |
"time" | |
) | |
func genServerKeys() { | |
privateKey, err := rsa.GenerateKey(rand.Reader, 2048) | |
if err != nil { | |
log.Fatal(err) | |
} | |
pemBlock, err := ssh.MarshalPrivateKey(crypto.PrivateKey(privateKey), "") | |
if err != nil { | |
log.Fatal(err) | |
} | |
bsPrivateKeyPem := pem.EncodeToMemory(pemBlock) | |
fmt.Printf("server privKey:\n%s\n", string(bsPrivateKeyPem)) | |
_ = os.WriteFile("keys/server_rsa", bsPrivateKeyPem, os.ModePerm) | |
sshPublicKey, err := ssh.NewPublicKey(&privateKey.PublicKey) | |
if err != nil { | |
log.Fatal(err) | |
} | |
publicKeyBase64 := base64.StdEncoding.EncodeToString(sshPublicKey.Marshal()) | |
fmt.Println(publicKeyBase64) | |
bsPublicKeyMarshal := ssh.MarshalAuthorizedKey(sshPublicKey) | |
_ = os.WriteFile("keys/server_rsa.pub", bsPublicKeyMarshal, os.ModePerm) | |
} | |
func init() { | |
_, err1 := os.Stat("./keys/server_rsa") | |
_, err2 := os.Stat("./keys/server_rsa.pub") | |
if err1 != nil || err2 != nil { | |
genServerKeys() | |
} | |
bs, err := os.ReadFile("keys/server_rsa.pub") | |
if err != nil { | |
log.Fatal(err) | |
} | |
iPublicKey, _, _, _, err := ssh.ParseAuthorizedKey(bs) | |
if err != nil { | |
log.Fatal(err) | |
} | |
// fingerprint is SHA256:D8eM... | |
hasher := sha256.New() | |
hasher.Write(iPublicKey.Marshal()) | |
fingerprint := base64.RawStdEncoding.EncodeToString(hasher.Sum(nil)) | |
fmt.Printf("RSA key fingerprint is SHA256:%s\n", fingerprint) | |
} | |
func main() { | |
bsPrivateKeyPem, err := os.ReadFile("keys/server_rsa") | |
if err != nil { | |
log.Fatal(err) | |
} | |
// 配置 SSH 服務器 | |
config := &ssh.ServerConfig{ | |
PasswordCallback: func(conn ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) { | |
_, ok := users[conn.User()] | |
if !ok { | |
return nil, fmt.Errorf("unknown user %q", conn.User()) | |
} | |
if string(pass) == "xxx" { | |
return nil, nil | |
} | |
return nil, fmt.Errorf("invalid authtication") | |
}, | |
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { | |
user, ok := users[conn.User()] | |
if !ok { | |
return nil, fmt.Errorf("unknown user %q", conn.User()) | |
} | |
if user.IPublicKey.Type() != key.Type() || // Type: ssh-rsa, ssh-rd25519, ... | |
!bytes.Equal(user.IPublicKey.Marshal(), key.Marshal()) { | |
return nil, fmt.Errorf("invalid key for %q", conn.User()) | |
} | |
return &ssh.Permissions{ | |
Extensions: map[string]string{ | |
"user": user.Name, | |
}, | |
}, nil | |
}, | |
} | |
signer, err := ssh.ParsePrivateKey(bsPrivateKeyPem) | |
if err != nil { | |
log.Fatal(err) | |
} | |
config.AddHostKey(signer) | |
listener, err := net.Listen("tcp", ":12345") | |
if err != nil { | |
log.Fatal(err) | |
} | |
fmt.Println("SSH server listening on port 2222...") | |
for { | |
conn, err := listener.Accept() | |
log.Println("收到了一個連線") | |
if err != nil { | |
log.Printf("Failed to accept connection: %v", err) | |
continue | |
} | |
go handleConnection(conn, config) | |
} | |
} | |
func handleConnection(conn net.Conn, config *ssh.ServerConfig) { | |
sshConn, chans, reqs, err := ssh.NewServerConn(conn, config) | |
if err != nil { | |
log.Printf("Failed to handshake: %v", err) | |
return | |
} | |
defer func() { | |
_ = sshConn.Close() | |
}() | |
log.Printf("New SSH connection from %s (%s)", sshConn.RemoteAddr(), sshConn.ClientVersion()) | |
go ssh.DiscardRequests(reqs) | |
for newChannel := range chans { | |
if newChannel.ChannelType() != "session" { | |
_ = newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") | |
continue | |
} | |
channel, requests, err := newChannel.Accept() | |
if err != nil { | |
log.Printf("Could not accept channel: %v", err) | |
continue | |
} | |
go handleChannel(channel, requests) | |
} | |
} | |
func handleChannel( | |
channel ssh.Channel, // 可以轉換成ssh.Session | |
requests <-chan *ssh.Request, | |
) { | |
defer func() { | |
log.Println("Connection closed") | |
_ = channel.Close() | |
}() | |
var err error | |
for req := range requests { | |
switch req.Type { | |
case "exec": | |
if len(req.Payload) < 4 { | |
_ = req.Reply(false, nil) | |
continue | |
} | |
cmd := string(req.Payload[4:]) | |
log.Printf("收到的exec指令%q", cmd) | |
switch strings.Split(cmd, " ")[0] { | |
case "scp": | |
handleSCPSession(channel, cmd) | |
case "git-upload-pack": // git fetch // git-upload-pack '/xxx/qoo.git' | |
err = handleGitUploadPack(channel, cmd) | |
case "git-receive-pack": // git push // git-receive-pack '/xxx/qoo.git' | |
err = handleGitReceivePack(channel, cmd) | |
default: // 一般命令 | |
err = executeCommand(channel, cmd) // ssh -p 12345 [email protected] hello | |
} | |
if err != nil { | |
log.Println(err) | |
} | |
_, _ = channel.SendRequest("exit-status", false, []byte{0, 0, 0, 0}) | |
_ = req.Reply(true, nil) | |
return | |
case "shell": | |
_ = req.Reply(true, nil) | |
_, _ = io.WriteString(channel, "Welcome to the Go SSH server!\r\n") | |
_, _ = io.WriteString(channel, "Type 'exit' to disconnect.\r\n") | |
handleShell(channel) | |
return | |
} | |
} | |
} | |
func executeCommand(channel ssh.Channel, cmd string) error { | |
switch cmd { | |
case "list": | |
_, _ = io.WriteString(channel, "Available commands: list, hello, time\r\n") | |
case "hello": | |
_, _ = io.WriteString(channel, "Hello, SSH user!\r\n") | |
case "time": | |
_, _ = io.WriteString(channel, fmt.Sprintf("Current server time: %s\r\n", time.Now().Format(time.RFC3339))) | |
default: | |
return fmt.Errorf("unknown command: %s", cmd) | |
} | |
return nil | |
} | |
func handleShell(channel ssh.Channel) { | |
try := 0 | |
for { | |
_, _ = io.WriteString(channel, "> ") | |
var cmd string | |
_, _ = fmt.Fscanf(channel, "%s\n", &cmd) | |
cmd = strings.TrimSpace(cmd) | |
if cmd == "exit" { | |
return | |
} | |
if err := executeCommand(channel, cmd); err != nil { | |
_, _ = io.WriteString(channel, fmt.Sprintf("Unknown command: %s\r\n", err)) | |
try++ | |
if try >= 3 { | |
_, _ = io.WriteString(channel, "過多的嘗試,終止shell") | |
return | |
} | |
} else { | |
try = 0 | |
} | |
} | |
} | |
// https://www.mkssoftware.com/docs/man1/scp.1.asp | |
func handleSCPSession(channel ssh.Channel, command string) { | |
parts := strings.Fields(command) | |
if len(parts) < 3 { | |
_, _ = fmt.Fprintf(channel, "scp: invalid command\n") | |
return | |
} | |
isFMode := strings.Contains(command, "-f") | |
isTMode := strings.Contains(command, "-t") | |
isRecursive := strings.Contains(command, "-r") | |
path := parts[len(parts)-1] // 取最後一筆得到路徑 | |
var err error | |
if isFMode { // -f | |
// Download file | |
if isRecursive { | |
handleDownloadDir(channel, | |
filepath.Dir(path), filepath.Base(path), | |
) // scp -r -f C:/../xxxDir | |
} else { | |
handleDownloadFile(channel, path) | |
} | |
} else if isTMode { // -t | |
// Upload file | |
if isRecursive { | |
err = handleUploadDir(channel, path) | |
} else { | |
err = handleUploadFile(channel, path, nil) | |
} | |
} else { | |
_, _ = fmt.Fprintf(channel, "Unsupported SCP option: %s\n", parts[1]) | |
} | |
if err != nil { | |
log.Printf("error %s\n", err) | |
_, _ = fmt.Fprintf(channel, "%s\n", err) | |
} | |
} | |
func handleDownloadFile(channel ssh.Channel, path string) { | |
log.Printf("Attempting to download file: %s", path) | |
file, err := os.Open(path) | |
if err != nil { | |
_, _ = fmt.Fprintf(channel, "Failed to open file: %v\n", err) | |
return | |
} | |
defer func() { | |
_ = file.Close() | |
}() | |
stat, err := file.Stat() | |
if err != nil { | |
_, _ = fmt.Fprintf(channel, "Failed to stat file: %v\n", err) | |
return | |
} | |
log.Printf("Sending file info: mode=%o, size=%d, name=%s", | |
stat.Mode().Perm(), stat.Size(), filepath.Base(path)) // %o 表示8進位 | |
_, _ = fmt.Fprintf(channel, "C%04o %d %s\n", stat.Mode().Perm(), // C0700 128123 C:\\xxx\my.file | |
stat.Size(), | |
filepath.Base(path), | |
) | |
bytesCopied, err := io.Copy(channel, file) | |
if err != nil { | |
log.Printf("Error during file transfer: %v", err) | |
} | |
log.Printf("Bytes copied: %d", bytesCopied) | |
_, _ = channel.Write([]byte{0}) // Transfer complete | |
log.Printf("Download complete") | |
} | |
// handleDownloadDir | |
func handleDownloadDir(channel ssh.Channel, | |
basePath, relativePath string, | |
) { | |
fullPath := filepath.Join(basePath, relativePath) | |
log.Printf("Attempting to download directory: %s", fullPath) | |
entries, err := os.ReadDir(fullPath) | |
if err != nil { | |
_, _ = fmt.Fprintf(channel, "Failed to read directory: %v\n", err) | |
return | |
} | |
// 先標記為目錄,D%04o %d %s | |
stat, err := os.Stat(fullPath) | |
if err != nil { | |
_, _ = fmt.Fprintf(channel, "Failed to stat directory: %v\n", err) | |
return | |
} | |
// _, _ = fmt.Fprintf(channel, "D0755 0 %s\n", filepath.Base(relativePath)) | |
_, _ = fmt.Fprintf(channel, "D%04o 0 %s\n", stat.Mode().Perm(), | |
filepath.Base(relativePath), | |
) | |
for _, entry := range entries { | |
entryPath := filepath.Join(relativePath, entry.Name()) | |
fullEntryPath := filepath.Join(basePath, entryPath) | |
entryInfo, err := entry.Info() | |
if err != nil { | |
log.Printf("Failed to get info for %s: %v", fullEntryPath, err) | |
continue | |
} | |
if entryInfo.IsDir() { | |
handleDownloadDir(channel, basePath, entryPath) | |
} else { | |
handleDownloadFile(channel, fullEntryPath) | |
} | |
} | |
log.Printf("Download of directory %q complete", fullPath) | |
} | |
func handleUploadFile(channel ssh.Channel, path string, header *Header) error { | |
_, _ = channel.Write([]byte{0}) // Ready to receive | |
if header == nil { | |
var buf [1024]byte | |
n, err := channel.Read(buf[:]) | |
if err != nil { | |
return fmt.Errorf("failed to read file info: %w", err) | |
} | |
header, err = parseHeader(buf[:n], "") | |
if err != nil { | |
return err | |
} | |
} | |
if err := os.MkdirAll(filepath.Dir(path), header.Mode); err != nil { | |
return err | |
} | |
file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, header.Mode) | |
if err != nil { | |
return fmt.Errorf("failed to create file: %w", err) | |
} | |
defer func() { | |
_ = file.Close() | |
}() | |
if len(header.Content) > 0 { | |
if _, err = io.WriteString(file, header.Content); err != nil { | |
return err | |
} | |
_, _ = channel.Write([]byte{0}) | |
return nil | |
} | |
_, _ = channel.Write([]byte{0}) | |
_, err = io.CopyN(file, channel, header.Size) | |
if err != nil { | |
return fmt.Errorf("failed to write file: %w", err) | |
} | |
buffer := make([]byte, 1) | |
_, err = channel.Read(buffer) | |
if err != nil || buffer[0] != 0 { | |
log.Printf("Failed to receive transfer completion signal") | |
_, _ = fmt.Fprintf(channel.Stderr(), "Failed to receive transfer completion signal\n") | |
return err | |
} | |
_, _ = channel.Write([]byte{0}) | |
log.Printf("Upload complete") | |
return nil | |
} | |
// Header is one parsed SCP control record ("C<mode> <size> <name>" or
// "D<mode> <size> <name>").
type Header struct {
	// Tag is the record type: 'C' (67) for a file or 'D' (68) for a directory.
	Tag byte
	// Mode is the value of the octal mode field.
	Mode os.FileMode
	// Size is the value of the decimal size field.
	Size int64
	// RelativePath is the name field; for a directory record that was
	// followed by a nested record it is the joined path of the nested entry.
	RelativePath string
	// Content is whatever followed the first newline of a 'C' record in the
	// parsed buffer.
	Content string
}

// String renders the header as a record line terminated by a newline.
func (h *Header) String() string {
	return fmt.Sprintf("%c%04o %d %s\n", h.Tag, h.Mode, h.Size, h.RelativePath)
}

// parseHeader parses the SCP control record at the start of buf.
//
// For a 'C' record the text after the first newline is returned in Content.
// For a 'D' record followed by more data, that data is parsed recursively and
// the nested header is returned with its RelativePath joined onto
// basePath/<directory name>. A record without a terminating newline is
// accepted and treated as having nothing after the name.
//
// An empty buf yields (nil, nil), which callers use as "nothing more to
// read". An error is returned for an unknown tag, a record with fewer than
// three space-separated fields, or a non-numeric mode or size.
func parseHeader(buf []byte, basePath string) (*Header, error) {
	if len(buf) == 0 {
		return nil, nil
	}
	var header Header
	header.Tag = buf[0]

	var (
		parts        []string
		relativePath string
		isSub        bool // header was replaced by a nested record
	)
	switch header.Tag {
	case 'C':
		parts = strings.SplitN(string(buf), " ", 3)
		if len(parts) != 3 {
			return nil, fmt.Errorf("invalid format not 'C0xxx size name' or 'C0xxx size name\\Ncontent'")
		}
		name, rest := splitRecordLine(parts[2])
		header.Content = rest
		relativePath = name
	case 'D':
		parts = strings.SplitN(string(buf), " ", 3)
		if len(parts) != 3 {
			return nil, fmt.Errorf("invalid format not 'D0xxx size name'")
		}
		name, rest := splitRecordLine(parts[2])
		if len(rest) > 0 {
			basePath = filepath.Join(basePath, name)
			h, err := parseHeader([]byte(rest), basePath)
			if err != nil {
				return nil, err
			}
			relativePath = filepath.Join(basePath, h.RelativePath)
			header = *h
			isSub = true
		} else {
			relativePath = name
		}
	default:
		return nil, fmt.Errorf("invalid header tag: %c", header.Tag)
	}

	// A nested header already carries its own mode and size.
	if !isSub {
		modeInt, err := strconv.ParseUint(parts[0][1:], 8, 32)
		if err != nil {
			return nil, err
		}
		header.Mode = os.FileMode(modeInt)
		header.Size, err = strconv.ParseInt(parts[1], 10, 64)
		if err != nil {
			return nil, err
		}
	}
	header.RelativePath = relativePath
	return &header, nil
}

// splitRecordLine splits s at its first newline into the name and the data
// that follows. Without a newline the whole of s is the name.
func splitRecordLine(s string) (name, rest string) {
	if i := strings.IndexByte(s, '\n'); i >= 0 {
		return s[:i], s[i+1:]
	}
	return s, ""
}
// handleUploadDir | |
func handleUploadDir(channel ssh.Channel, | |
path string, | |
) error { | |
_, _ = channel.Write([]byte{0}) | |
var buf [1024]byte | |
n, err := channel.Read(buf[:]) | |
if err != nil { | |
return fmt.Errorf("failed to read directory info: %w", err) | |
} | |
var header *Header | |
header, err = parseHeader(buf[:n], "") | |
if err != nil { | |
return err | |
} | |
dirName := header.RelativePath | |
err = os.MkdirAll(filepath.Join(path, dirName), header.Mode) | |
if err != nil { | |
_, _ = fmt.Fprintf(channel, "Failed to create directory: %v\n", err) | |
return nil | |
} | |
_, _ = channel.Write([]byte{0}) | |
for { | |
n, err = channel.Read(buf[:]) | |
header, err = parseHeader(buf[:n], dirName) | |
if err != nil { | |
return err | |
} | |
if n == 0 { | |
break | |
} | |
if header.Tag == 'C' { | |
/* | |
D0777 0 sprite | |
前一層如果是 | |
D0777 0 sub | |
C0666 15 qoo.txt | |
qoo1 | |
接下來會是 | |
D0777 0 sub2 | |
C0666 9 apple.txt | |
o1 | |
todo: 但是目前的寫法,他是把sub2拆開,實際上sub2應該要在sub底下,也就是parseHeader如果遇到有子內容,還要考慮父資料夾不能只檔案 | |
*/ | |
if err = handleUploadFile(channel, filepath.Join(path, dirName, header.RelativePath), header); err != nil { | |
return err | |
} | |
} else if header.Tag == 'D' { | |
err = handleUploadDir(channel, filepath.Join(path, dirName, header.RelativePath)) | |
if err != nil { | |
return err | |
} | |
} else { | |
return fmt.Errorf("unknown header Tag: %s", string(header.Tag)) | |
} | |
} | |
buffer := make([]byte, 1) | |
_, err = channel.Read(buffer) | |
if err != nil || buffer[0] != 0 { | |
return fmt.Errorf("failed to receive transfer completion signal") | |
} | |
_, _ = channel.Write([]byte{0}) | |
return nil | |
} | |
// resolvePath turns the quoted repository argument of a git command into a
// rooted, slash-separated path:
//
//	'/github/userXX/qoo.git' => /github/userXX/qoo.git
//
// The argument comes from the client, so it is cleaned against a virtual
// root: "." and ".." elements are resolved and can never climb above "/",
// and a path given without a leading slash is rooted as well. Callers that
// prefix the result with "." therefore stay inside the working directory.
func resolvePath(p string) string {
	unquoted := strings.Trim(p, "'")
	return filepath.ToSlash(filepath.Clean("/" + unquoted))
}
func handleGitUploadPack(channel ssh.Channel, cmd string) error { | |
parts := strings.Fields(cmd) | |
if len(parts) < 2 { | |
return fmt.Errorf("invalid git-upload-pack command") | |
} | |
repoPath := "." + resolvePath(parts[1]) | |
gitCmd := exec.Command("git", "upload-pack", repoPath) | |
gitCmd.Stdout = channel | |
gitCmd.Stderr = channel.Stderr() | |
stdin, err := gitCmd.StdinPipe() | |
if err != nil { | |
return fmt.Errorf("error creating stdin pipe: %w", err) | |
} | |
if err = gitCmd.Start(); err != nil { | |
return fmt.Errorf("error starting git-upload-pack: %w", err) | |
} | |
go func() { | |
_, _ = io.Copy(stdin, channel) | |
_ = stdin.Close() | |
}() | |
if err = gitCmd.Wait(); err != nil { | |
return fmt.Errorf("error waiting for git-upload-pack: %w", err) | |
} | |
return nil | |
} | |
func handleGitReceivePack(channel ssh.Channel, cmd string) error { | |
parts := strings.Fields(cmd) | |
if len(parts) < 2 { | |
return fmt.Errorf("invalid git-receive-pack command") | |
} | |
repoPath := "." + resolvePath(parts[1]) | |
gitCmd := exec.Command("git", "receive-pack", repoPath) | |
gitCmd.Stdout = channel | |
gitCmd.Stderr = channel.Stderr() | |
stdin, err := gitCmd.StdinPipe() | |
if err != nil { | |
return fmt.Errorf("error creating stdin pipe: %w", err) | |
} | |
if err = gitCmd.Start(); err != nil { | |
return fmt.Errorf("error starting git-receive-pack: %w", err) | |
} | |
go func() { | |
_, _ = io.Copy(stdin, channel) | |
_ = stdin.Close() | |
}() | |
if err = gitCmd.Wait(); err != nil { | |
return fmt.Errorf("error waiting for git-receive-pack: %w", err) | |
} | |
return nil | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment