001    /**
002     * Licensed to the Apache Software Foundation (ASF) under one
003     * or more contributor license agreements. See the NOTICE file
004     * distributed with this work for additional information
005     * regarding copyright ownership. The ASF licenses this file
006     * to you under the Apache License, Version 2.0 (the
007     * "License"); you may not use this file except in compliance
008     * with the License. You may obtain a copy of the License at
009     *
010     * http://www.apache.org/licenses/LICENSE-2.0
011     *
012     * Unless required by applicable law or agreed to in writing,
013     * software distributed under the License is distributed on an
014     * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
015     * KIND, either express or implied. See the License for the
016     * specific language governing permissions and limitations
017     * under the License.
018     */
019    
020    package org.apache.geronimo.axis2;
021    
022    import java.io.FileNotFoundException;
023    import java.io.OutputStream;
024    import java.net.MalformedURLException;
025    import java.net.URL;
026    import java.util.Collection;
027    import java.util.List;
028    import java.util.Map;
029    import java.util.concurrent.ConcurrentHashMap;
030    
031    import javax.wsdl.Definition;
032    import javax.wsdl.Import;
033    import javax.wsdl.Types;
034    import javax.wsdl.extensions.ExtensibilityElement;
035    import javax.wsdl.extensions.schema.Schema;
036    import javax.wsdl.extensions.schema.SchemaImport;
037    import javax.wsdl.extensions.schema.SchemaReference;
038    import javax.wsdl.factory.WSDLFactory;
039    import javax.wsdl.xml.WSDLReader;
040    import javax.wsdl.xml.WSDLWriter;
041    import javax.xml.transform.OutputKeys;
042    import javax.xml.transform.Source;
043    import javax.xml.transform.Transformer;
044    import javax.xml.transform.TransformerException;
045    import javax.xml.transform.TransformerFactory;
046    import javax.xml.transform.dom.DOMSource;
047    import javax.xml.transform.stream.StreamResult;
048    
049    import org.apache.axis2.description.AxisService;
050    import org.apache.commons.logging.Log;
051    import org.apache.commons.logging.LogFactory;
052    import org.apache.geronimo.jaxws.WSDLUtils;
053    import org.w3c.dom.Element;
054    import org.w3c.dom.Node;
055    import org.w3c.dom.NodeList;
056    
057    public class WSDLQueryHandler {
058    
059        private static final Log LOG = LogFactory.getLog(WSDLQueryHandler.class);
060        
061        private static TransformerFactory transformerFactory = TransformerFactory.newInstance();
062        
063        private Map<String, Definition> mp = new ConcurrentHashMap<String, Definition>();
064        private Map<String, SchemaReference> smp = new ConcurrentHashMap<String, SchemaReference>();
065        private AxisService service;
066        
067        public WSDLQueryHandler(AxisService service) {
068            this.service = service;
069        }
070        
071        public void writeResponse(String baseUri, String wsdlUri, OutputStream os) throws Exception {
072    
073            String base = null;
074            String wsdl = "";
075            String xsd = null;
076            
077            int idx = baseUri.toLowerCase().indexOf("?wsdl");
078            if (idx != -1) {
079                base = baseUri.substring(0, idx);
080                wsdl = baseUri.substring(idx + 5);
081                if (wsdl.length() > 0) {
082                    wsdl = wsdl.substring(1);
083                }
084            } else {
085                idx = baseUri.toLowerCase().indexOf("?xsd");
086                if (idx != -1) {
087                    base = baseUri.substring(0, idx);
088                    xsd = baseUri.substring(idx + 4);
089                    if (xsd.length() > 0) {
090                        xsd = xsd.substring(1);
091                    }
092                } else {
093                    throw new Exception("Invalid request: " + baseUri);
094                }
095            }
096    
097            if (!mp.containsKey(wsdl)) {
098                WSDLFactory factory = WSDLFactory.newInstance();
099                WSDLReader reader = factory.newWSDLReader();
100                reader.setFeature("javax.wsdl.importDocuments", true);
101                reader.setFeature("javax.wsdl.verbose", false);
102                Definition def = reader.readWSDL(wsdlUri);
103                updateDefinition(def, mp, smp, base);
104                // remove other services and ports from wsdl
105                WSDLUtils.trimDefinition(def, this.service.getName(), this.service.getEndpointName());
106                mp.put("", def);
107            }
108    
109            Element rootElement;
110    
111            if (xsd == null) {
112                Definition def = mp.get(wsdl);
113    
114                if (def == null) {
115                    throw new FileNotFoundException("WSDL not found: " + wsdl);
116                }
117                
118                // update service port location on each request
119                if (wsdl.equals("")) {
120                    WSDLUtils.updateLocations(def, base);
121                }
122                
123                WSDLFactory factory = WSDLFactory.newInstance();
124                WSDLWriter writer = factory.newWSDLWriter();
125    
126                rootElement = writer.getDocument(def).getDocumentElement();
127            } else {
128                SchemaReference si = smp.get(xsd);
129                
130                if (si == null) {
131                    throw new FileNotFoundException("Schema not found: " + xsd);
132                }
133                
134                rootElement = si.getReferencedSchema().getElement();
135            }
136    
137            NodeList nl = rootElement.getElementsByTagNameNS("http://www.w3.org/2001/XMLSchema",
138                    "import");
139            for (int x = 0; x < nl.getLength(); x++) {
140                Element el = (Element) nl.item(x);
141                String sl = el.getAttribute("schemaLocation");
142                if (smp.containsKey(sl)) {
143                    el.setAttribute("schemaLocation", base + "?xsd=" + sl);
144                }
145            }
146            nl = rootElement.getElementsByTagNameNS("http://www.w3.org/2001/XMLSchema", "include");
147            for (int x = 0; x < nl.getLength(); x++) {
148                Element el = (Element) nl.item(x);
149                String sl = el.getAttribute("schemaLocation");
150                if (smp.containsKey(sl)) {
151                    el.setAttribute("schemaLocation", base + "?xsd=" + sl);
152                }
153            }
154            nl = rootElement.getElementsByTagNameNS("http://schemas.xmlsoap.org/wsdl/", "import");
155            for (int x = 0; x < nl.getLength(); x++) {
156                Element el = (Element) nl.item(x);
157                String sl = el.getAttribute("location");
158                if (mp.containsKey(sl)) {
159                    el.setAttribute("location", base + "?wsdl=" + sl);
160                }
161            }
162    
163            writeTo(rootElement, os);
164        }
165           
166        protected void updateDefinition(Definition def,
167                                        Map<String, Definition> done,
168                                        Map<String, SchemaReference> doneSchemas,
169                                        String base) {
170            Collection<List> imports = def.getImports().values();
171            for (List lst : imports) {
172                List<Import> impLst = lst;
173                for (Import imp : impLst) {
174                    String start = imp.getLocationURI();
175                    try {
176                        //check to see if it's aleady in a URL format.  If so, leave it.
177                        new URL(start);
178                    } catch (MalformedURLException e) {
179                        done.put(start, imp.getDefinition());
180                        updateDefinition(imp.getDefinition(), done, doneSchemas, base);
181                    }
182                }
183            }      
184            
185            
186            /* This doesn't actually work.   Setting setSchemaLocationURI on the import
187            * for some reason doesn't actually result in the new URI being written
188            * */
189            Types types = def.getTypes();
190            if (types != null) {
191                for (ExtensibilityElement el : (List<ExtensibilityElement>)types.getExtensibilityElements()) {
192                    if (el instanceof Schema) {
193                        Schema see = (Schema)el;
194                        updateSchemaImports(see, doneSchemas, base);
195                    }
196                }
197            }
198        }
199        
200        protected void updateSchemaImports(Schema schema,
201                                           Map<String, SchemaReference> doneSchemas,
202                                           String base) {
203            Collection<List>  imports = schema.getImports().values();
204            for (List lst : imports) {
205                List<SchemaImport> impLst = lst;
206                for (SchemaImport imp : impLst) {
207                    String start = imp.getSchemaLocationURI();
208                    if (start != null) {
209                        try {
210                            //check to see if it's aleady in a URL format.  If so, leave it.
211                            new URL(start);
212                        } catch (MalformedURLException e) {
213                            if (!doneSchemas.containsKey(start)) {
214                                doneSchemas.put(start, imp);
215                                updateSchemaImports(imp.getReferencedSchema(), doneSchemas, base);
216                            }
217                        }
218                    }
219                }
220            }
221            List<SchemaReference> includes = schema.getIncludes();
222            for (SchemaReference included : includes) {
223                String start = included.getSchemaLocationURI();
224                if (start != null) {
225                    try {
226                        //check to see if it's aleady in a URL format.  If so, leave it.
227                        new URL(start);
228                    } catch (MalformedURLException e) {
229                        if (!doneSchemas.containsKey(start)) {
230                            doneSchemas.put(start, included);
231                            updateSchemaImports(included.getReferencedSchema(), doneSchemas, base);
232                        }
233                    }
234                }
235            }
236        }
237        
238        public static void writeTo(Node node, OutputStream os) {
239            writeTo(new DOMSource(node), os);
240        }
241        
242        public static void writeTo(Source src, OutputStream os) {
243            Transformer it;
244            try {
245                it = transformerFactory.newTransformer();
246                it.setOutputProperty(OutputKeys.METHOD, "xml");
247                it.setOutputProperty(OutputKeys.INDENT, "yes");
248                it.setOutputProperty("{http://xml.apache.org/xslt}indent-amount", "4");
249                it.setOutputProperty(OutputKeys.OMIT_XML_DECLARATION, "no");
250                it.setOutputProperty(OutputKeys.ENCODING, "utf-8");
251                it.transform(src, new StreamResult(os));
252            } catch (TransformerException e) {
253                // TODO Auto-generated catch block
254                e.printStackTrace();
255            }
256        }
257        
258    }