diff options
Diffstat (limited to 'tls-getcerts.go')
-rw-r--r-- | tls-getcerts.go | 88 |
1 files changed, 86 insertions, 2 deletions
diff --git a/tls-getcerts.go b/tls-getcerts.go index 49e15a2..7032199 100644 --- a/tls-getcerts.go +++ b/tls-getcerts.go @@ -4,17 +4,93 @@ import ( "crypto/tls" "crypto/x509" "encoding/pem" + "encoding/xml" "fmt" + "io" "net" "os" + "strings" ) +type xmppStreamsFeatures struct { + XMLName xml.Name `xml:"http://etherx.jabber.org/streams features"` +} + +type xmppTlsProceed struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-tls proceed"` +} + +func tlsDial(snet, saddr string) (*tls.Conn, error) { + switch snet { + case "tcp": + conn, err := tls.Dial(snet, saddr, &tls.Config{InsecureSkipVerify: true}) + if err != nil { + return nil, err + } + return conn, nil + case "xmpp": + host, _, err := net.SplitHostPort(saddr) + connTCP, err := net.Dial("tcp", saddr) + if err != nil { + return nil, err + } + + decoder := xml.NewDecoder(connTCP) + + // send <stream> start + _, err = fmt.Fprintf(connTCP, "<stream:stream xmlns='jabber:client' xmlns:stream='http://etherx.jabber.org/streams' to='%s' version='1.0'>", host) + if err != nil { + return nil, err + } + // read <stream> start + for { + t, err := decoder.Token() + if err != nil || t == nil { + return nil, err + } + if se, ok := t.(xml.StartElement); ok { + if se.Name.Local != "stream" { + return nil, xml.UnmarshalError(fmt.Sprintf("expected element of type <%s> but have <%s>", "stream", se.Name.Local)) + } + break + } + } + // read <features> + var features xmppStreamsFeatures + err = decoder.DecodeElement(&features, nil) + if err != nil { + return nil, err + } + // send <starttls> + _, err = io.WriteString(connTCP, "<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>") + if err != nil { + return nil, err + } + // read <proceed> + var proceed xmppTlsProceed + err = decoder.DecodeElement(&proceed, nil) + if err != nil { + return nil, err + } + + connTLS := tls.Client(connTCP, &tls.Config{InsecureSkipVerify: true}) + err = connTLS.Handshake() + if err != nil { + return nil, err + } + return connTLS, nil + default: + return nil, fmt.Errorf("Unknown TLS network: %q", snet) + } +} + func getcert(socket string) (*x509.Certificate, error){ - host, _, err := net.SplitHostPort(socket) + snet, saddr := split(socket) + host, _, err := net.SplitHostPort(saddr) if err != nil { return nil, err } - conn, err := tls.Dial("tcp", socket, &tls.Config{InsecureSkipVerify: true}) + conn, err := tlsDial(snet, saddr) if err != nil { return nil, err } @@ -34,6 +110,14 @@ func getcert(socket string) (*x509.Certificate, error){ return cert, err } +func split(socket string) (net, addr string) { + ary := strings.SplitN(socket, ":", 2) + if len(ary) == 1 { + return "tcp", ary[0] + } + return ary[0], ary[1] +} + func main() { for _, socket := range os.Args[1:] { cert, err := getcert(socket) |