diff --git a/src/main/java/net/sourceforge/plantuml/servlet/ProxyServlet.java b/src/main/java/net/sourceforge/plantuml/servlet/ProxyServlet.java index 2902568..566be37 100644 --- a/src/main/java/net/sourceforge/plantuml/servlet/ProxyServlet.java +++ b/src/main/java/net/sourceforge/plantuml/servlet/ProxyServlet.java @@ -23,26 +23,27 @@ */ package net.sourceforge.plantuml.servlet; +import java.io.BufferedReader; import java.io.IOException; +import java.io.InputStreamReader; +import java.net.HttpURLConnection; +import java.net.MalformedURLException; import java.net.URL; -import java.util.regex.Matcher; -import java.util.regex.Pattern; import javax.servlet.ServletException; import javax.servlet.http.HttpServlet; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; -import HTTPClient.CookieModule; -import HTTPClient.HTTPConnection; -import HTTPClient.HTTPResponse; -import HTTPClient.ModuleException; -import HTTPClient.ParseException; - import net.sourceforge.plantuml.FileFormat; import net.sourceforge.plantuml.FileFormatOption; import net.sourceforge.plantuml.SourceStringReader; +import java.security.cert.Certificate; + +import javax.net.ssl.HttpsURLConnection; +import javax.net.ssl.SSLPeerUnverifiedException; + /* * Proxy servlet of the webapp. * This servlet retrieves the diagram source of a web resource (web html page) @@ -58,33 +59,44 @@ public class ProxyServlet extends HttpServlet { final String source = request.getParameter("src"); final String index = request.getParameter("idx"); - - // TODO Check if the src URL is valid - + final URL srcUrl; + // Check if the src URL is valid + try { + srcUrl = new URL(source); + } catch (MalformedURLException mue) { + mue.printStackTrace(); + return; + } + // generate the response - SourceStringReader reader = new SourceStringReader(getSource(source)); + String diagmarkup = getSource(srcUrl); + System.out.println("getSource=>" + diagmarkup); + SourceStringReader reader = new SourceStringReader(diagmarkup); int n = index == null ? 0 : Integer.parseInt(index); + reader.generateImage(response.getOutputStream(), n, new FileFormatOption(getOutputFormat(), false)); } - private String getSource(String uri) throws IOException { - CookieModule.setCookiePolicyHandler(null); - - final Pattern p = Pattern.compile("http://[^/]+(/?.*)"); - final Matcher m = p.matcher(uri); - if (m.find() == false) { - throw new IOException(uri); - } - final URL url = new URL(uri); - final HTTPConnection httpConnection = new HTTPConnection(url); + private String getSource(URL url) throws IOException { + String line; + BufferedReader rd; + StringBuilder sb; try { - final HTTPResponse resp = httpConnection.Get(m.group(1)); - return resp.getText(); - } catch (ModuleException e) { - throw new IOException(e.toString()); - } catch (ParseException e) { - throw new IOException(e.toString()); + HttpURLConnection con = getConnection(url); + rd = new BufferedReader(new InputStreamReader(con.getInputStream())); + sb = new StringBuilder(); + + while ((line = rd.readLine()) != null) { + sb.append(line + '\n'); + } + rd.close(); + return sb.toString(); + } catch (IOException e) { + e.printStackTrace(); + } finally{ + rd = null; } + return ""; } private FileFormat getOutputFormat() { @@ -100,4 +112,48 @@ public class ProxyServlet extends HttpServlet { return FileFormat.PNG; } + private HttpURLConnection getConnection(URL url) throws IOException { + if (url.getProtocol().startsWith("https")) { + HttpsURLConnection con = (HttpsURLConnection) url.openConnection(); + con.setRequestMethod("GET"); + con.setReadTimeout(10000); // 10 seconds + // printHttpsCert(con); + con.connect(); + return con; + } else { + HttpURLConnection con = (HttpURLConnection) url.openConnection(); + con.setRequestMethod("GET"); + con.setReadTimeout(10000); // 10 seconds + con.connect(); + return con; + } + } + + /** + * Debug method used to dump the certificate info + * @param con the https connection + */ + private void printHttpsCert(HttpsURLConnection con) { + if (con != null) { + try { + System.out.println("Response Code : " + con.getResponseCode()); + System.out.println("Cipher Suite : " + con.getCipherSuite()); + System.out.println("\n"); + + Certificate[] certs = con.getServerCertificates(); + for (Certificate cert : certs) { + System.out.println("Cert Type : " + cert.getType()); + System.out.println("Cert Hash Code : " + cert.hashCode()); + System.out.println("Cert Public Key Algorithm : " + cert.getPublicKey().getAlgorithm()); + System.out.println("Cert Public Key Format : " + cert.getPublicKey().getFormat()); + System.out.println("\n"); + } + + } catch (SSLPeerUnverifiedException e) { + e.printStackTrace(); + } catch (IOException e) { + e.printStackTrace(); + } + } + } } diff --git a/src/test/java/net/sourceforge/plantuml/servlet/TestProxy.java b/src/test/java/net/sourceforge/plantuml/servlet/TestProxy.java index 34a1b55..e1804ec 100644 --- a/src/test/java/net/sourceforge/plantuml/servlet/TestProxy.java +++ b/src/test/java/net/sourceforge/plantuml/servlet/TestProxy.java @@ -60,6 +60,6 @@ public class TestProxy extends WebappTestCase { WebRequest request = new GetMethodWebRequest(getServerUrl() + "proxy?src=invalidURL"); WebResponse response = conversation.getResource(request); // Analyze response, it must be HTTP error 500 - assertEquals("Response HTTP status is not 500", response.getResponseCode(), 500); + assertEquals("Bad HTTP status received", 500, response.getResponseCode()); } }