summaryrefslogtreecommitdiff
path: root/tls-getcerts.go
diff options
context:
space:
mode:
Diffstat (limited to 'tls-getcerts.go')
-rw-r--r--tls-getcerts.go88
1 files changed, 86 insertions, 2 deletions
diff --git a/tls-getcerts.go b/tls-getcerts.go
index 49e15a2..7032199 100644
--- a/tls-getcerts.go
+++ b/tls-getcerts.go
@@ -4,17 +4,93 @@ import (
"crypto/tls"
"crypto/x509"
"encoding/pem"
+ "encoding/xml"
"fmt"
+ "io"
"net"
"os"
+ "strings"
)
+type xmppStreamsFeatures struct {
+ XMLName xml.Name `xml:"http://etherx.jabber.org/streams features"`
+}
+
+type xmppTlsProceed struct {
+ XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-tls proceed"`
+}
+
+func tlsDial(snet, saddr string) (*tls.Conn, error) {
+ switch snet {
+ case "tcp":
+ conn, err := tls.Dial(snet, saddr, &tls.Config{InsecureSkipVerify: true})
+ if err != nil {
+ return nil, err
+ }
+ return conn, nil
+ case "xmpp":
+ host, _, err := net.SplitHostPort(saddr)
+ connTCP, err := net.Dial("tcp", saddr)
+ if err != nil {
+ return nil, err
+ }
+
+ decoder := xml.NewDecoder(connTCP)
+
+ // send <stream> start
+ _, err = fmt.Fprintf(connTCP, "<stream:stream xmlns='jabber:client' xmlns:stream='http://etherx.jabber.org/streams' to='%s' version='1.0'>", host)
+ if err != nil {
+ return nil, err
+ }
+ // read <stream> start
+ for {
+ t, err := decoder.Token()
+ if err != nil || t == nil {
+ return nil, err
+ }
+ if se, ok := t.(xml.StartElement); ok {
+ if se.Name.Local != "stream" {
+ return nil, xml.UnmarshalError(fmt.Sprintf("expected element of type <%s> but have <%s>", "stream", se.Name.Local))
+ }
+ break
+ }
+ }
+ // read <features>
+ var features xmppStreamsFeatures
+ err = decoder.DecodeElement(&features, nil)
+ if err != nil {
+ return nil, err
+ }
+ // send <starttls>
+ _, err = io.WriteString(connTCP, "<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>")
+ if err != nil {
+ return nil, err
+ }
+ // read <proceed>
+ var proceed xmppTlsProceed
+ err = decoder.DecodeElement(&proceed, nil)
+ if err != nil {
+ return nil, err
+ }
+
+ connTLS := tls.Client(connTCP, &tls.Config{InsecureSkipVerify: true})
+ err = connTLS.Handshake()
+ if err != nil {
+ return nil, err
+ }
+ return connTLS, nil
+ default:
+ return nil, fmt.Errorf("Unknown TLS network: %q", snet)
+ }
+}
+
func getcert(socket string) (*x509.Certificate, error){
- host, _, err := net.SplitHostPort(socket)
+ snet, saddr := split(socket)
+ host, _, err := net.SplitHostPort(saddr)
if err != nil {
return nil, err
}
- conn, err := tls.Dial("tcp", socket, &tls.Config{InsecureSkipVerify: true})
+ conn, err := tlsDial(snet, saddr)
if err != nil {
return nil, err
}
@@ -34,6 +110,14 @@ func getcert(socket string) (*x509.Certificate, error){
return cert, err
}
+func split(socket string) (net, addr string) {
+ ary := strings.SplitN(socket, ":", 2)
+ if len(ary) == 1 {
+ return "tcp", ary[0]
+ }
+ return ary[0], ary[1]
+}
+
func main() {
for _, socket := range os.Args[1:] {
cert, err := getcert(socket)