001/*
002 * Forge Mod Loader
003 * Copyright (c) 2012-2013 cpw.
004 * All rights reserved. This program and the accompanying materials
005 * are made available under the terms of the GNU Lesser Public License v2.1
006 * which accompanies this distribution, and is available at
007 * http://www.gnu.org/licenses/old-licenses/gpl-2.0.html
008 * 
009 * Contributors:
010 *     cpw - implementation
011 */
012
013package cpw.mods.fml.common.asm.transformers;
014
import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.ByteArrayOutputStream;
import java.io.Closeable;
import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.Hashtable;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.zip.ZipEntry;
import java.util.zip.ZipFile;
import java.util.zip.ZipOutputStream;
037
038import org.objectweb.asm.ClassReader;
039import org.objectweb.asm.ClassWriter;
040import org.objectweb.asm.Type;
041import org.objectweb.asm.tree.AnnotationNode;
042import org.objectweb.asm.tree.ClassNode;
043import org.objectweb.asm.tree.FieldNode;
044import org.objectweb.asm.tree.MethodNode;
045
046import com.google.common.base.Objects;
047import com.google.common.collect.Lists;
048import com.google.common.collect.Sets;
049
050import cpw.mods.fml.relauncher.Side;
051import cpw.mods.fml.relauncher.SideOnly;
052
053public class MCPMerger
054{
055    private static Hashtable<String, ClassInfo> clients = new Hashtable<String, ClassInfo>();
056    private static Hashtable<String, ClassInfo> shared  = new Hashtable<String, ClassInfo>();
057    private static Hashtable<String, ClassInfo> servers = new Hashtable<String, ClassInfo>();
058    private static HashSet<String> copyToServer = new HashSet<String>();
059    private static HashSet<String> copyToClient = new HashSet<String>();
060    private static HashSet<String> dontAnnotate = new HashSet<String>();
061    private static final boolean DEBUG = false;
062
063    public static void main(String[] args)
064    {
065        if (args.length != 3)
066        {
067            System.out.println("Usage: MCPMerger <MapFile> <minecraft.jar> <minecraft_server.jar>");
068            System.exit(1);
069        }
070
071        File map_file = new File(args[0]);
072        File client_jar = new File(args[1]);
073        File server_jar = new File(args[2]);
074        File client_jar_tmp = new File(args[1] + ".MergeBack");
075        File server_jar_tmp = new File(args[2] + ".MergeBack");
076
077
078        if (client_jar_tmp.exists() && !client_jar_tmp.delete())
079        {
080            System.out.println("Could not delete temp file: " + client_jar_tmp);
081        }
082
083        if (server_jar_tmp.exists() && !server_jar_tmp.delete())
084        {
085            System.out.println("Could not delete temp file: " + server_jar_tmp);
086        }
087
088        if (!client_jar.exists())
089        {
090            System.out.println("Could not find minecraft.jar: " + client_jar);
091            System.exit(1);
092        }
093
094        if (!server_jar.exists())
095        {
096            System.out.println("Could not find minecraft_server.jar: " + server_jar);
097            System.exit(1);
098        }
099
100        if (!client_jar.renameTo(client_jar_tmp))
101        {
102            System.out.println("Could not rename file: " + client_jar + " -> " + client_jar_tmp);
103            System.exit(1);
104        }
105
106        if (!server_jar.renameTo(server_jar_tmp))
107        {
108            System.out.println("Could not rename file: " + server_jar + " -> " + server_jar_tmp);
109            System.exit(1);
110        }
111
112        if (!readMapFile(map_file))
113        {
114            System.out.println("Could not read map file: " + map_file);
115            System.exit(1);
116        }
117
118        try
119        {
120            processJar(client_jar_tmp, server_jar_tmp, client_jar, server_jar);
121        }
122        catch (IOException e)
123        {
124            e.printStackTrace();
125            System.exit(1);
126        }
127
128        if (!client_jar_tmp.delete())
129        {
130            System.out.println("Could not delete temp file: " + client_jar_tmp);
131        }
132
133        if (!server_jar_tmp.delete())
134        {
135            System.out.println("Could not delete temp file: " + server_jar_tmp);
136        }
137    }
138
139    private static boolean readMapFile(File mapFile)
140    {
141        try
142        {
143            FileInputStream fstream = new FileInputStream(mapFile);
144            DataInputStream in = new DataInputStream(fstream);
145            BufferedReader br = new BufferedReader(new InputStreamReader(in));
146
147            String line;
148            while ((line = br.readLine()) != null)
149            {
150                line = line.split("#")[0];
151                char cmd = line.charAt(0);
152                line = line.substring(1).trim();
153                
154                switch (cmd)
155                {
156                    case '!': dontAnnotate.add(line); break;
157                    case '<': copyToClient.add(line); break;
158                    case '>': copyToServer.add(line); break; 
159                }
160            }
161
162            in.close();
163            return true;
164        }
165        catch (Exception e)
166        {
167            System.err.println("Error: " + e.getMessage());
168            return false;
169        }
170    }
171
172    public static void processJar(File clientInFile, File serverInFile, File clientOutFile, File serverOutFile) throws IOException
173    {
174        ZipFile cInJar = null;
175        ZipFile sInJar = null;
176        ZipOutputStream cOutJar = null;
177        ZipOutputStream sOutJar = null;
178
179        try
180        {
181            try
182            {
183                cInJar = new ZipFile(clientInFile);
184                sInJar = new ZipFile(serverInFile);
185            }
186            catch (FileNotFoundException e)
187            {
188                throw new FileNotFoundException("Could not open input file: " + e.getMessage());
189            }
190            try
191            {
192                cOutJar = new ZipOutputStream(new BufferedOutputStream(new FileOutputStream(clientOutFile)));
193                sOutJar = new ZipOutputStream(new BufferedOutputStream(new FileOutputStream(serverOutFile)));
194            }
195            catch (FileNotFoundException e)
196            {
197                throw new FileNotFoundException("Could not open output file: " + e.getMessage());
198            }
199            Hashtable<String, ZipEntry> cClasses = getClassEntries(cInJar, cOutJar);
200            Hashtable<String, ZipEntry> sClasses = getClassEntries(sInJar, sOutJar);
201            HashSet<String> cAdded = new HashSet<String>();
202            HashSet<String> sAdded = new HashSet<String>();
203
204            for (Entry<String, ZipEntry> entry : cClasses.entrySet())
205            {
206                String name = entry.getKey();
207                ZipEntry cEntry = entry.getValue();
208                ZipEntry sEntry = sClasses.get(name);
209
210                if (sEntry == null)
211                {
212                    if (!copyToServer.contains(name))
213                    {
214                        copyClass(cInJar, cEntry, cOutJar, null, true);
215                        cAdded.add(name);
216                    }
217                    else
218                    {
219                        if (DEBUG)
220                        {
221                            System.out.println("Copy class c->s : " + name);
222                        }
223                        copyClass(cInJar, cEntry, cOutJar, sOutJar, true);
224                        cAdded.add(name);
225                        sAdded.add(name);
226                    }
227                    continue;
228                }
229
230                sClasses.remove(name);
231                ClassInfo info = new ClassInfo(name);
232                shared.put(name, info);
233
234                byte[] cData = readEntry(cInJar, entry.getValue());
235                byte[] sData = readEntry(sInJar, sEntry);
236                byte[] data = processClass(cData, sData, info);
237
238                ZipEntry newEntry = new ZipEntry(cEntry.getName());
239                cOutJar.putNextEntry(newEntry);
240                cOutJar.write(data);
241                sOutJar.putNextEntry(newEntry);
242                sOutJar.write(data);
243                cAdded.add(name);
244                sAdded.add(name);
245            }
246
247            for (Entry<String, ZipEntry> entry : sClasses.entrySet())
248            {
249                if (DEBUG)
250                {
251                    System.out.println("Copy class s->c : " + entry.getKey());
252                }
253                copyClass(sInJar, entry.getValue(), cOutJar, sOutJar, false);
254            }
255
256            for (String name : new String[]{SideOnly.class.getName(), Side.class.getName()})
257            {
258                String eName = name.replace(".", "/");
259                byte[] data = getClassBytes(name);
260                ZipEntry newEntry = new ZipEntry(name.replace(".", "/").concat(".class"));
261                if (!cAdded.contains(eName))
262                {
263                    cOutJar.putNextEntry(newEntry);
264                    cOutJar.write(data);
265                }
266                if (!sAdded.contains(eName))
267                {
268                    sOutJar.putNextEntry(newEntry);
269                    sOutJar.write(data);
270                }
271            }
272
273        }
274        finally
275        {
276            if (cInJar != null)
277            {
278                try { cInJar.close(); } catch (IOException e){}
279            }
280
281            if (sInJar != null)
282            {
283                try { sInJar.close(); } catch (IOException e) {}
284            }
285            if (cOutJar != null)
286            {
287                try { cOutJar.close(); } catch (IOException e){}
288            }
289
290            if (sOutJar != null)
291            {
292                try { sOutJar.close(); } catch (IOException e) {}
293            }
294        }
295    }
296
297    private static void copyClass(ZipFile inJar, ZipEntry entry, ZipOutputStream outJar, ZipOutputStream outJar2, boolean isClientOnly) throws IOException
298    {
299        ClassReader reader = new ClassReader(readEntry(inJar, entry));
300        ClassNode classNode = new ClassNode();
301
302        reader.accept(classNode, 0);
303
304        if (!dontAnnotate.contains(classNode.name))
305        {
306            if (classNode.visibleAnnotations == null) classNode.visibleAnnotations = new ArrayList<AnnotationNode>();
307            classNode.visibleAnnotations.add(getSideAnn(isClientOnly));
308        }
309
310        ClassWriter writer = new ClassWriter(ClassWriter.COMPUTE_MAXS);
311        classNode.accept(writer);
312        byte[] data = writer.toByteArray();
313
314        ZipEntry newEntry = new ZipEntry(entry.getName());
315        if (outJar != null)
316        {
317            outJar.putNextEntry(newEntry);
318            outJar.write(data);
319        }
320        if (outJar2 != null)
321        {
322            outJar2.putNextEntry(newEntry);
323            outJar2.write(data);
324        }
325    }
326
327    private static AnnotationNode getSideAnn(boolean isClientOnly)
328    {
329        AnnotationNode ann = new AnnotationNode(Type.getDescriptor(SideOnly.class));
330        ann.values = new ArrayList<Object>();
331        ann.values.add("value");
332        ann.values.add(new String[]{ Type.getDescriptor(Side.class), (isClientOnly ? "CLIENT" : "SERVER")});
333        return ann;
334    }
335
336    @SuppressWarnings("unchecked")
337    private static Hashtable<String, ZipEntry> getClassEntries(ZipFile inFile, ZipOutputStream outFile) throws IOException
338    {
339        Hashtable<String, ZipEntry> ret = new Hashtable<String, ZipEntry>();
340        for (ZipEntry entry : Collections.list((Enumeration<ZipEntry>)inFile.entries()))
341        {
342            if (entry.isDirectory())
343            {
344                outFile.putNextEntry(entry);
345                continue;
346            }
347            String entryName = entry.getName();
348            if (!entryName.endsWith(".class") || entryName.startsWith("."))
349            {
350                ZipEntry newEntry = new ZipEntry(entry.getName());
351                outFile.putNextEntry(newEntry);
352                outFile.write(readEntry(inFile, entry));
353            }
354            else
355            {
356                ret.put(entryName.replace(".class", ""), entry);
357            }
358        }
359        return ret;
360    }
361    private static byte[] readEntry(ZipFile inFile, ZipEntry entry) throws IOException
362    {
363        return readFully(inFile.getInputStream(entry));
364    }
365    private static byte[] readFully(InputStream stream) throws IOException
366    {
367        byte[] data = new byte[4096];
368        ByteArrayOutputStream entryBuffer = new ByteArrayOutputStream();
369        int len;
370        do
371        {
372            len = stream.read(data);
373            if (len > 0)
374            {
375                entryBuffer.write(data, 0, len);
376            }
377        } while (len != -1);
378
379        return entryBuffer.toByteArray();
380    }
381    private static class ClassInfo
382    {
383        public String name;
384        public ArrayList<FieldNode> cField = new ArrayList<FieldNode>();
385        public ArrayList<FieldNode> sField = new ArrayList<FieldNode>();
386        public ArrayList<MethodNode> cMethods = new ArrayList<MethodNode>();
387        public ArrayList<MethodNode> sMethods = new ArrayList<MethodNode>();
388        public ClassInfo(String name){ this.name = name; }
389        public boolean isSame() { return (cField.size() == 0 && sField.size() == 0 && cMethods.size() == 0 && sMethods.size() == 0); }
390    }
391
392    public static byte[] processClass(byte[] cIn, byte[] sIn, ClassInfo info)
393    {
394        ClassNode cClassNode = getClassNode(cIn);
395        ClassNode sClassNode = getClassNode(sIn);
396
397        processFields(cClassNode, sClassNode, info);
398        processMethods(cClassNode, sClassNode, info);
399
400        ClassWriter writer = new ClassWriter(ClassWriter.COMPUTE_MAXS);
401        cClassNode.accept(writer);
402        return writer.toByteArray();
403    }
404
405    private static ClassNode getClassNode(byte[] data)
406    {
407        ClassReader reader = new ClassReader(data);
408        ClassNode classNode = new ClassNode();
409        reader.accept(classNode, 0);
410        return classNode;
411    }
412
413    @SuppressWarnings("unchecked")
414    private static void processFields(ClassNode cClass, ClassNode sClass, ClassInfo info)
415    {
416        List<FieldNode> cFields = cClass.fields;
417        List<FieldNode> sFields = sClass.fields;
418
419        int sI = 0;
420        for (int x = 0; x < cFields.size(); x++)
421        {
422            FieldNode cF = cFields.get(x);
423            if (sI < sFields.size())
424            {
425                if (!cF.name.equals(sFields.get(sI).name))
426                {
427                    boolean serverHas = false;
428                    for (int y = sI + 1; y < sFields.size(); y++)
429                    {
430                        if (cF.name.equals(sFields.get(y).name))
431                        {
432                            serverHas = true;
433                            break;
434                        }
435                    }
436                    if (serverHas)
437                    {
438                        boolean clientHas = false;
439                        FieldNode sF = sFields.get(sI);
440                        for (int y = x + 1; y < cFields.size(); y++)
441                        {
442                            if (sF.name.equals(cFields.get(y).name))
443                            {
444                                clientHas = true;
445                                break;
446                            }
447                        }
448                        if (!clientHas)
449                        {
450                            if  (sF.visibleAnnotations == null) sF.visibleAnnotations = new ArrayList<AnnotationNode>();
451                            sF.visibleAnnotations.add(getSideAnn(false));
452                            cFields.add(x++, sF);
453                            info.sField.add(sF);
454                        }
455                    }
456                    else
457                    {
458                        if  (cF.visibleAnnotations == null) cF.visibleAnnotations = new ArrayList<AnnotationNode>();
459                        cF.visibleAnnotations.add(getSideAnn(true));
460                        sFields.add(sI, cF);
461                        info.cField.add(cF);
462                    }
463                }
464            }
465            else
466            {
467                if  (cF.visibleAnnotations == null) cF.visibleAnnotations = new ArrayList<AnnotationNode>();
468                cF.visibleAnnotations.add(getSideAnn(true));
469                sFields.add(sI, cF);
470                info.cField.add(cF);
471            }
472            sI++;
473        }
474        if (sFields.size() != cFields.size())
475        {
476            for (int x = cFields.size(); x < sFields.size(); x++)
477            {
478                FieldNode sF = sFields.get(x);
479                if  (sF.visibleAnnotations == null) sF.visibleAnnotations = new ArrayList<AnnotationNode>();
480                sF.visibleAnnotations.add(getSideAnn(true));
481                cFields.add(x++, sF);
482                info.sField.add(sF);
483            }
484        }
485    }
486
487    private static class MethodWrapper
488    {
489        private MethodNode node;
490        public boolean client;
491        public boolean server;
492        public MethodWrapper(MethodNode node)
493        {
494            this.node = node;
495        }
496        @Override
497        public boolean equals(Object obj)
498        {
499            if (obj == null || !(obj instanceof MethodWrapper)) return false;
500            MethodWrapper mw = (MethodWrapper) obj;
501            boolean eq = Objects.equal(node.name, mw.node.name) && Objects.equal(node.desc, mw.node.desc);
502            if (eq)
503            {
504                mw.client = this.client | mw.client;
505                mw.server = this.server | mw.server;
506                this.client = this.client | mw.client;
507                this.server = this.server | mw.server;
508                if (DEBUG)
509                {
510                    System.out.printf(" eq: %s %s\n", this, mw);
511                }
512            }
513            return eq;
514        }
515
516        @Override
517        public int hashCode()
518        {
519            return Objects.hashCode(node.name, node.desc);
520        }
521        @Override
522        public String toString()
523        {
524            return Objects.toStringHelper(this).add("name", node.name).add("desc",node.desc).add("server",server).add("client",client).toString();
525        }
526    }
527    @SuppressWarnings("unchecked")
528    private static void processMethods(ClassNode cClass, ClassNode sClass, ClassInfo info)
529    {
530        List<MethodNode> cMethods = (List<MethodNode>)cClass.methods;
531        List<MethodNode> sMethods = (List<MethodNode>)sClass.methods;
532        LinkedHashSet<MethodWrapper> allMethods = Sets.newLinkedHashSet();
533
534        int cPos = 0;
535        int sPos = 0;
536        int cLen = cMethods.size();
537        int sLen = sMethods.size();
538        String clientName = "";
539        String lastName = clientName;
540        String serverName = "";
541        while (cPos < cLen || sPos < sLen)
542        {
543            do
544            {
545                if (sPos>=sLen)
546                {
547                    break;
548                }
549                MethodNode sM = sMethods.get(sPos);
550                serverName = sM.name;
551                if (!serverName.equals(lastName) && cPos != cLen)
552                {
553                    if (DEBUG)
554                    {
555                        System.out.printf("Server -skip : %s %s %d (%s %d) %d [%s]\n", sClass.name, clientName, cLen - cPos, serverName, sLen - sPos, allMethods.size(), lastName);
556                    }
557                    break;
558                }
559                MethodWrapper mw = new MethodWrapper(sM);
560                mw.server = true;
561                allMethods.add(mw);
562                if (DEBUG)
563                {
564                    System.out.printf("Server *add* : %s %s %d (%s %d) %d [%s]\n", sClass.name, clientName, cLen - cPos, serverName, sLen - sPos, allMethods.size(), lastName);
565                }
566                sPos++;
567            }
568            while (sPos < sLen);
569            do
570            {
571                if (cPos>=cLen)
572                {
573                    break;
574                }
575                MethodNode cM = cMethods.get(cPos);
576                lastName = clientName;
577                clientName = cM.name;
578                if (!clientName.equals(lastName) && sPos != sLen)
579                {
580                    if (DEBUG)
581                    {
582                        System.out.printf("Client -skip : %s %s %d (%s %d) %d [%s]\n", cClass.name, clientName, cLen - cPos, serverName, sLen - sPos, allMethods.size(), lastName);
583                    }
584                    break;
585                }
586                MethodWrapper mw = new MethodWrapper(cM);
587                mw.client = true;
588                allMethods.add(mw);
589                if (DEBUG)
590                {
591                    System.out.printf("Client *add* : %s %s %d (%s %d) %d [%s]\n", cClass.name, clientName, cLen - cPos, serverName, sLen - sPos, allMethods.size(), lastName);
592                }
593                cPos++;
594            }
595            while (cPos < cLen);
596        }
597
598        cMethods.clear();
599        sMethods.clear();
600
601        for (MethodWrapper mw : allMethods)
602        {
603            if (DEBUG)
604            {
605                System.out.println(mw);
606            }
607            cMethods.add(mw.node);
608            sMethods.add(mw.node);
609            if (mw.server && mw.client)
610            {
611                // no op
612            }
613            else
614            {
615                if (mw.node.visibleAnnotations == null) mw.node.visibleAnnotations = Lists.newArrayListWithExpectedSize(1);
616                mw.node.visibleAnnotations.add(getSideAnn(mw.client));
617                if (mw.client)
618                {
619                    info.sMethods.add(mw.node);
620                }
621                else
622                {
623                    info.cMethods.add(mw.node);
624                }
625            }
626        }
627    }
628
629    public static byte[] getClassBytes(String name) throws IOException
630    {
631        InputStream classStream = null;
632        try
633        {
634            classStream = MCPMerger.class.getResourceAsStream("/" + name.replace('.', '/').concat(".class"));
635            return readFully(classStream);
636        }
637        finally
638        {
639            if (classStream != null)
640            {
641                try
642                {
643                    classStream.close();
644                }
645                catch (IOException e){}
646            }
647        }
648    }
649}