From: www Date: Sun, 29 Sep 2024 21:29:59 +0000 (+0000) Subject: Mirrored from suika.git X-Git-Url: https://git.chaotic.ninja/gitweb/yakumo_izuru/?a=commitdiff_plain;ds=sidebyside;p=suika.git Mirrored from suika.git git-svn-id: https://svn.chaotic.ninja/svn/suika-yakumo.izuru@1 dd155a68-181b-fa42-919d-ba26d1f5f0f2 --- 555d1ff2b86846ca7154687b1f70b04e1173af29 diff --git a/branches/master/.gitignore b/branches/master/.gitignore new file mode 100644 index 0000000..4aa22f5 --- /dev/null +++ b/branches/master/.gitignore @@ -0,0 +1,5 @@ +vendor +/suika +/suikadb +/suika-znc-import +/suika.db diff --git a/branches/master/LICENSE b/branches/master/LICENSE new file mode 100644 index 0000000..0ad25db --- /dev/null +++ b/branches/master/LICENSE @@ -0,0 +1,661 @@ + GNU AFFERO GENERAL PUBLIC LICENSE + Version 3, 19 November 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU Affero General Public License is a free, copyleft license for +software and other kinds of works, specifically designed to ensure +cooperation with the community in the case of network server software. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +our General Public Licenses are intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + Developers that use our General Public Licenses protect your rights +with two steps: (1) assert copyright on the software, and (2) offer +you this License which gives you legal permission to copy, distribute +and/or modify the software. + + A secondary benefit of defending all users' freedom is that +improvements made in alternate versions of the program, if they +receive widespread use, become available for other developers to +incorporate. Many developers of free software are heartened and +encouraged by the resulting cooperation. However, in the case of +software used on network servers, this result may fail to come about. +The GNU General Public License permits making a modified version and +letting the public access it on a server without ever releasing its +source code to the public. + + The GNU Affero General Public License is designed specifically to +ensure that, in such cases, the modified source code becomes available +to the community. It requires the operator of a network server to +provide the source code of the modified version running there to the +users of that server. Therefore, public use of a modified version, on +a publicly accessible server, gives the public access to the source +code of the modified version. + + An older license, called the Affero General Public License and +published by Affero, was designed to accomplish similar goals. This is +a different license, not a version of the Affero GPL, but Affero has +released a new version of the Affero GPL which permits relicensing under +this license. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU Affero General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. + + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Remote Network Interaction; Use with the GNU General Public License. + + Notwithstanding any other provision of this License, if you modify the +Program, your modified version must prominently offer all users +interacting with it remotely through a computer network (if your version +supports such interaction) an opportunity to receive the Corresponding +Source of your version by providing access to the Corresponding Source +from a network server at no charge, through some standard or customary +means of facilitating copying of software. This Corresponding Source +shall include the Corresponding Source for any work covered by version 3 +of the GNU General Public License that is incorporated pursuant to the +following paragraph. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the work with which it is combined will remain governed by version +3 of the GNU General Public License. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU Affero General Public License from time to time. Such new versions +will be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU Affero General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU Affero General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU Affero General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + + Copyright (C) + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published + by the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . + +Also add information on how to contact you by electronic and paper mail. + + If your software can interact with users remotely through a computer +network, you should also make sure that it provides a way for users to +get its source. For example, if your program is a web application, its +interface could display a "Source" link that leads users to an archive +of the code. There are many ways you could offer source, and different +solutions will be better for different programs; see section 13 for the +specific requirements. + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU AGPL, see +. diff --git a/branches/master/Makefile b/branches/master/Makefile new file mode 100644 index 0000000..a77bc97 --- /dev/null +++ b/branches/master/Makefile @@ -0,0 +1,45 @@ +GO ?= go +RM ?= rm +GOFLAGS ?= -v -ldflags "-w -X `go list`.Version=${VERSION} -X `go list`.Commit=${COMMIT} -X `go list`.Build=${BUILD}" -mod=vendor +PREFIX ?= /usr/local +BINDIR ?= bin +MANDIR ?= share/man +MKDIR ?= mkdir +CP ?= cp +SYSCONFDIR ?= /etc +ASCIIDOCTOR ?= asciidoctor + +VERSION = `git describe --abbrev=0 --tags 2>/dev/null || echo "$VERSION"` +COMMIT = `git rev-parse --short HEAD || echo "$COMMIT"` +BRANCH = `git rev-parse --abbrev-ref HEAD` +BUILD = `git show -s --pretty=format:%cI` + +GOARCH ?= amd64 +GOOS ?= linux + +all: build + +build: vendor + ${GO} build ${GOFLAGS} ./cmd/suika + ${GO} build ${GOFLAGS} ./cmd/suikadb + ${GO} build ${GOFLAGS} ./cmd/suika-znc-import +clean: + ${RM} -f suika suikadb suika-znc-import +install: + ${MKDIR} -p ${DESTDIR}${PREFIX}/${BINDIR} + ${MKDIR} -p ${DESTDIR}${PREFIX}/${MANDIR}/man1 + ${MKDIR} -p ${DESTDIR}${PREFIX}/${MANDIR}/man5 + ${MKDIR} -p ${DESTDIR}${PREFIX}/${MANDIR}/man7 + ${MKDIR} -p ${DESTDIR}${SYSCONFDIR}/suika + ${MKDIR} -p ${DESTDIR}/var/lib/suika + ${CP} -f suika suikadb suika-znc-import ${DESTDIR}${PREFIX}/${BINDIR} + ${CP} -f doc/suika.1 ${DESTDIR}${PREFIX}/${MANDIR}/man1 + ${CP} -f doc/suikadb.1 ${DESTDIR}${PREFIX}/${MANDIR}/man1 + ${CP} -f doc/suika-znc-import.1 ${DESTDIR}/${MANDIR}/man1 + ${CP} -f doc/suika-config.5 ${DESTDIR}${PREFIX}/${MANDIR}/man5 + [ -f ${DESTDIR}${SYSCONFDIR}/suika/config ] || ${CP} -f config.in ${DESTDIR}${SYSCONFDIR}/suika/config +test: + go test +vendor: + go mod vendor +.PHONY: build clean install diff --git a/branches/master/README.md b/branches/master/README.md new file mode 100644 index 0000000..2171ba5 --- /dev/null +++ b/branches/master/README.md @@ -0,0 +1,33 @@ +# suika + +[![Go Documentation](https://godocs.io/marisa.chaotic.ninja/suika?status.svg)](https://godocs.io/marisa.chaotic.ninja/suika) + +A user-friendly IRC bouncer. Hard-fork of the 0.3 series of [soju](https://soju.im), named after [Suika Ibuki](https://en.touhouwiki.net/wiki/Suika_Ibuki) from [Touhou 7.5: Immaterial and Missing Power](https://en.touhouwiki.net/wiki/Immaterial_and_Missing_Power) + +- Multi-user +- Support multiple clients for a single user, with proper backlog + synchronization +- Support connecting to multiple upstream servers via a single IRC connection + to the bouncer + +## Building and installing + +Dependencies: + +- Go +- BSD or GNU make + +For end users, a `Makefile` is provided: + + make + doas make install + +For development, you can use `go run ./cmd/suika` as usual. + +## License +AGPLv3, see [LICENSE](LICENSE). + +* Copyright (C) 2020 The soju Contributors +* Copyright (C) 2023-present Izuru Yakumo + +The code for `version.go` is stolen verbatim from one of [@prologic](https://git.mills.io/prologic)'s projects. It's probably under MIT diff --git a/branches/master/bridge.go b/branches/master/bridge.go new file mode 100644 index 0000000..7bb7aa8 --- /dev/null +++ b/branches/master/bridge.go @@ -0,0 +1,116 @@ +package suika + +import ( + "context" + "fmt" + "strconv" + "strings" + + "gopkg.in/irc.v3" +) + +func forwardChannel(ctx context.Context, dc *downstreamConn, ch *upstreamChannel) { + if !ch.complete { + panic("Tried to forward a partial channel") + } + + // RPL_NOTOPIC shouldn't be sent on JOIN + if ch.Topic != "" { + sendTopic(dc, ch) + } + + if dc.caps["soju.im/read"] { + channelCM := ch.conn.network.casemap(ch.Name) + r, err := dc.srv.db.GetReadReceipt(ctx, ch.conn.network.ID, channelCM) + if err != nil { + dc.logger.Printf("failed to get the read receipt for %q: %v", ch.Name, err) + } else { + timestampStr := "*" + if r != nil { + timestampStr = fmt.Sprintf("timestamp=%s", formatServerTime(r.Timestamp)) + } + dc.SendMessage(&irc.Message{ + Prefix: dc.prefix(), + Command: "READ", + Params: []string{dc.marshalEntity(ch.conn.network, ch.Name), timestampStr}, + }) + } + } + + sendNames(dc, ch) +} + +func sendTopic(dc *downstreamConn, ch *upstreamChannel) { + downstreamName := dc.marshalEntity(ch.conn.network, ch.Name) + + if ch.Topic != "" { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_TOPIC, + Params: []string{dc.nick, downstreamName, ch.Topic}, + }) + if ch.TopicWho != nil { + topicWho := dc.marshalUserPrefix(ch.conn.network, ch.TopicWho) + topicTime := strconv.FormatInt(ch.TopicTime.Unix(), 10) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: rpl_topicwhotime, + Params: []string{dc.nick, downstreamName, topicWho.String(), topicTime}, + }) + } + } else { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_NOTOPIC, + Params: []string{dc.nick, downstreamName, "No topic is set"}, + }) + } +} + +func sendNames(dc *downstreamConn, ch *upstreamChannel) { + downstreamName := dc.marshalEntity(ch.conn.network, ch.Name) + + emptyNameReply := &irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_NAMREPLY, + Params: []string{dc.nick, string(ch.Status), downstreamName, ""}, + } + maxLength := maxMessageLength - len(emptyNameReply.String()) + + var buf strings.Builder + for _, entry := range ch.Members.innerMap { + nick := entry.originalKey + memberships := entry.value.(*memberships) + s := memberships.Format(dc) + dc.marshalEntity(ch.conn.network, nick) + + n := buf.Len() + 1 + len(s) + if buf.Len() != 0 && n > maxLength { + // There's not enough space for the next space + nick. + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_NAMREPLY, + Params: []string{dc.nick, string(ch.Status), downstreamName, buf.String()}, + }) + buf.Reset() + } + + if buf.Len() != 0 { + buf.WriteByte(' ') + } + buf.WriteString(s) + } + + if buf.Len() != 0 { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_NAMREPLY, + Params: []string{dc.nick, string(ch.Status), downstreamName, buf.String()}, + }) + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_ENDOFNAMES, + Params: []string{dc.nick, downstreamName, "End of /NAMES list"}, + }) +} diff --git a/branches/master/certfp.go b/branches/master/certfp.go new file mode 100644 index 0000000..9180f7c --- /dev/null +++ b/branches/master/certfp.go @@ -0,0 +1,72 @@ +package suika + +import ( + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "math/big" + "time" +) + +func generateCertFP(keyType string, bits int) (privKeyBytes, certBytes []byte, err error) { + var ( + privKey crypto.PrivateKey + pubKey crypto.PublicKey + ) + switch keyType { + case "rsa": + key, err := rsa.GenerateKey(rand.Reader, bits) + if err != nil { + return nil, nil, err + } + privKey = key + pubKey = key.Public() + case "ecdsa": + key, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + if err != nil { + return nil, nil, err + } + privKey = key + pubKey = key.Public() + case "ed25519": + var err error + pubKey, privKey, err = ed25519.GenerateKey(rand.Reader) + if err != nil { + return nil, nil, err + } + } + + // Using PKCS#8 allows easier extension for new key types. + privKeyBytes, err = x509.MarshalPKCS8PrivateKey(privKey) + if err != nil { + return nil, nil, err + } + + notBefore := time.Now() + // Lets make a fair assumption nobody will use the same cert for more than 20 years... + notAfter := notBefore.Add(24 * time.Hour * 365 * 20) + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + return nil, nil, err + } + cert := &x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{CommonName: "suika auto-generated certificate"}, + NotBefore: notBefore, + NotAfter: notAfter, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + } + certBytes, err = x509.CreateCertificate(rand.Reader, cert, cert, pubKey, privKey) + if err != nil { + return nil, nil, err + } + + return privKeyBytes, certBytes, nil +} diff --git a/branches/master/cmd/suika-znc-import/main.go b/branches/master/cmd/suika-znc-import/main.go new file mode 100644 index 0000000..78fefc6 --- /dev/null +++ b/branches/master/cmd/suika-znc-import/main.go @@ -0,0 +1,474 @@ +package main + +import ( + "bufio" + "context" + "flag" + "fmt" + "io" + "log" + "net/url" + "os" + "strings" + "unicode" + + "marisa.chaotic.ninja/suika" + "marisa.chaotic.ninja/suika/config" +) + +const usage = `usage: suika-znc-import [options...] + +Imports configuration from a ZNC file. Users and networks are merged if they +already exist in the suika database. ZNC settings overwrite existing suika +settings. + +Options: + + -help Show this help message + -config Path to suika config file + -user Limit import to username (may be specified multiple times) + -network Limit import to network (may be specified multiple times) +` + +func init() { + flag.Usage = func() { + fmt.Fprintf(flag.CommandLine.Output(), usage) + } +} + +func main() { + var configPath string + users := make(map[string]bool) + networks := make(map[string]bool) + flag.StringVar(&configPath, "config", "", "path to configuration file") + flag.Var((*stringSetFlag)(&users), "user", "") + flag.Var((*stringSetFlag)(&networks), "network", "") + flag.Parse() + + zncPath := flag.Arg(0) + if zncPath == "" { + flag.Usage() + os.Exit(1) + } + + var cfg *config.Server + if configPath != "" { + var err error + cfg, err = config.Load(configPath) + if err != nil { + log.Fatalf("failed to load config file: %v", err) + } + } else { + cfg = config.Defaults() + } + + ctx := context.Background() + + db, err := suika.OpenDB(cfg.SQLDriver, cfg.SQLSource) + if err != nil { + log.Fatalf("failed to open database: %v", err) + } + defer db.Close() + + f, err := os.Open(zncPath) + if err != nil { + log.Fatalf("failed to open ZNC configuration file: %v", err) + } + defer f.Close() + + zp := zncParser{bufio.NewReader(f), 1} + root, err := zp.sectionBody("", "") + if err != nil { + log.Fatalf("failed to parse %q: line %v: %v", zncPath, zp.line, err) + } + + l, err := db.ListUsers(ctx) + if err != nil { + log.Fatalf("failed to list users in DB: %v", err) + } + existingUsers := make(map[string]*suika.User, len(l)) + for i, u := range l { + existingUsers[u.Username] = &l[i] + } + + usersCreated := 0 + usersImported := 0 + networksImported := 0 + channelsImported := 0 + root.ForEach("User", func(section *zncSection) { + username := section.Name + if len(users) > 0 && !users[username] { + return + } + usersImported++ + + u, ok := existingUsers[username] + if ok { + log.Printf("user %q: updating existing user", username) + } else { + // "!!" is an invalid crypt format, thus disables password auth + u = &suika.User{Username: username, Password: "!!"} + usersCreated++ + log.Printf("user %q: creating new user", username) + } + + u.Admin = section.Values.Get("Admin") == "true" + + if err := db.StoreUser(ctx, u); err != nil { + log.Fatalf("failed to store user %q: %v", username, err) + } + userID := u.ID + + l, err := db.ListNetworks(ctx, userID) + if err != nil { + log.Fatalf("failed to list networks for user %q: %v", username, err) + } + existingNetworks := make(map[string]*suika.Network, len(l)) + for i, n := range l { + existingNetworks[n.GetName()] = &l[i] + } + + nick := section.Values.Get("Nick") + realname := section.Values.Get("RealName") + ident := section.Values.Get("Ident") + + section.ForEach("Network", func(section *zncSection) { + netName := section.Name + if len(networks) > 0 && !networks[netName] { + return + } + networksImported++ + + logPrefix := fmt.Sprintf("user %q: network %q: ", username, netName) + logger := log.New(os.Stderr, logPrefix, log.LstdFlags|log.Lmsgprefix) + + netNick := section.Values.Get("Nick") + if netNick == "" { + netNick = nick + } + netRealname := section.Values.Get("RealName") + if netRealname == "" { + netRealname = realname + } + netIdent := section.Values.Get("Ident") + if netIdent == "" { + netIdent = ident + } + + for _, name := range section.Values["LoadModule"] { + switch name { + case "sasl": + logger.Printf("warning: SASL credentials not imported") + case "nickserv": + logger.Printf("warning: NickServ credentials not imported") + case "perform": + logger.Printf("warning: \"perform\" plugin commands not imported") + } + } + + u, pass, err := importNetworkServer(section.Values.Get("Server")) + if err != nil { + logger.Fatalf("failed to import server %q: %v", section.Values.Get("Server"), err) + } + + n, ok := existingNetworks[netName] + if ok { + logger.Printf("updating existing network") + } else { + n = &suika.Network{Name: netName} + logger.Printf("creating new network") + } + + n.Addr = u.String() + n.Nick = netNick + n.Username = netIdent + n.Realname = netRealname + n.Pass = pass + n.Enabled = section.Values.Get("IRCConnectEnabled") != "false" + + if err := db.StoreNetwork(ctx, userID, n); err != nil { + logger.Fatalf("failed to store network: %v", err) + } + + l, err := db.ListChannels(ctx, n.ID) + if err != nil { + logger.Fatalf("failed to list channels: %v", err) + } + existingChannels := make(map[string]*suika.Channel, len(l)) + for i, ch := range l { + existingChannels[ch.Name] = &l[i] + } + + section.ForEach("Chan", func(section *zncSection) { + chName := section.Name + + if section.Values.Get("Disabled") == "true" { + logger.Printf("skipping import of disabled channel %q", chName) + return + } + + channelsImported++ + + ch, ok := existingChannels[chName] + if ok { + logger.Printf("channel %q: updating existing channel", chName) + } else { + ch = &suika.Channel{Name: chName} + logger.Printf("channel %q: creating new channel", chName) + } + + ch.Key = section.Values.Get("Key") + ch.Detached = section.Values.Get("Detached") == "true" + + if err := db.StoreChannel(ctx, n.ID, ch); err != nil { + logger.Printf("channel %q: failed to store channel: %v", chName, err) + } + }) + }) + }) + + if err := db.Close(); err != nil { + log.Printf("failed to close database: %v", err) + } + + if usersCreated > 0 { + log.Printf("warning: user passwords haven't been imported, please set them with `suikactl change-password `") + } + + log.Printf("imported %v users, %v networks and %v channels", usersImported, networksImported, channelsImported) +} + +func importNetworkServer(s string) (u *url.URL, pass string, err error) { + parts := strings.Fields(s) + if len(parts) < 2 { + return nil, "", fmt.Errorf("expected space-separated host and port") + } + + scheme := "irc" + host := parts[0] + port := parts[1] + if strings.HasPrefix(port, "+") { + port = port[1:] + scheme = "ircs" + } + + if len(parts) > 2 { + pass = parts[2] + } + + u = &url.URL{ + Scheme: scheme, + Host: host + ":" + port, + } + return u, pass, nil +} + +type zncSection struct { + Type string + Name string + Values zncValues + Children []zncSection +} + +func (s *zncSection) ForEach(typ string, f func(*zncSection)) { + for _, section := range s.Children { + if section.Type == typ { + f(§ion) + } + } +} + +type zncValues map[string][]string + +func (zv zncValues) Get(k string) string { + if len(zv[k]) == 0 { + return "" + } + return zv[k][0] +} + +type zncParser struct { + br *bufio.Reader + line int +} + +func (zp *zncParser) readByte() (byte, error) { + b, err := zp.br.ReadByte() + if b == '\n' { + zp.line++ + } + return b, err +} + +func (zp *zncParser) readRune() (rune, int, error) { + r, n, err := zp.br.ReadRune() + if r == '\n' { + zp.line++ + } + return r, n, err +} + +func (zp *zncParser) sectionBody(typ, name string) (*zncSection, error) { + section := &zncSection{Type: typ, Name: name, Values: make(zncValues)} + +Loop: + for { + if err := zp.skipSpace(); err != nil { + return nil, err + } + + b, err := zp.br.Peek(2) + if err == io.EOF { + break + } else if err != nil { + return nil, err + } + + switch b[0] { + case '<': + if b[1] == '/' { + break Loop + } else { + childType, childName, err := zp.sectionHeader() + if err != nil { + return nil, err + } + child, err := zp.sectionBody(childType, childName) + if err != nil { + return nil, err + } + if footerType, err := zp.sectionFooter(); err != nil { + return nil, err + } else if footerType != childType { + return nil, fmt.Errorf("invalid section footer: expected type %q, got %q", childType, footerType) + } + section.Children = append(section.Children, *child) + } + case '/': + if b[1] == '/' { + if err := zp.skipComment(); err != nil { + return nil, err + } + break + } + fallthrough + default: + k, v, err := zp.keyValuePair() + if err != nil { + return nil, err + } + section.Values[k] = append(section.Values[k], v) + } + } + + return section, nil +} + +func (zp *zncParser) skipSpace() error { + for { + r, _, err := zp.readRune() + if err == io.EOF { + return nil + } else if err != nil { + return err + } + + if !unicode.IsSpace(r) { + zp.br.UnreadRune() + return nil + } + } +} + +func (zp *zncParser) skipComment() error { + if err := zp.expectRune('/'); err != nil { + return err + } + if err := zp.expectRune('/'); err != nil { + return err + } + + for { + b, err := zp.readByte() + if err == io.EOF { + return nil + } else if err != nil { + return err + } + + if b == '\n' { + return nil + } + } +} + +func (zp *zncParser) sectionHeader() (string, string, error) { + if err := zp.expectRune('<'); err != nil { + return "", "", err + } + typ, err := zp.readWord(' ') + if err != nil { + return "", "", err + } + name, err := zp.readWord('>') + return typ, name, err +} + +func (zp *zncParser) sectionFooter() (string, error) { + if err := zp.expectRune('<'); err != nil { + return "", err + } + if err := zp.expectRune('/'); err != nil { + return "", err + } + return zp.readWord('>') +} + +func (zp *zncParser) keyValuePair() (string, string, error) { + k, err := zp.readWord('=') + if err != nil { + return "", "", err + } + v, err := zp.readWord('\n') + return strings.TrimSpace(k), strings.TrimSpace(v), err +} + +func (zp *zncParser) expectRune(expected rune) error { + r, _, err := zp.readRune() + if err != nil { + return err + } else if r != expected { + return fmt.Errorf("expected %q, got %q", expected, r) + } + return nil +} + +func (zp *zncParser) readWord(delim byte) (string, error) { + var sb strings.Builder + for { + b, err := zp.readByte() + if err != nil { + return "", err + } + + if b == delim { + return sb.String(), nil + } + if b == '\n' { + return "", fmt.Errorf("expected %q before newline", delim) + } + + sb.WriteByte(b) + } +} + +type stringSetFlag map[string]bool + +func (v *stringSetFlag) String() string { + return fmt.Sprint(map[string]bool(*v)) +} + +func (v *stringSetFlag) Set(s string) error { + (*v)[s] = true + return nil +} diff --git a/branches/master/cmd/suika/main.go b/branches/master/cmd/suika/main.go new file mode 100644 index 0000000..9661859 --- /dev/null +++ b/branches/master/cmd/suika/main.go @@ -0,0 +1,229 @@ +package main + +import ( + "context" + "crypto/tls" + "flag" + "fmt" + "log" + "net" + "net/url" + "os" + "os/signal" + "strings" + "sync/atomic" + "syscall" + "time" + + "marisa.chaotic.ninja/suika" + "marisa.chaotic.ninja/suika/config" +) + +// TCP keep-alive interval for downstream TCP connections +const downstreamKeepAlive = 1 * time.Hour + +type stringSliceFlag []string + +func (v *stringSliceFlag) String() string { + return fmt.Sprint([]string(*v)) +} + +func (v *stringSliceFlag) Set(s string) error { + *v = append(*v, s) + return nil +} + +func bumpOpenedFileLimit() error { + var rlimit syscall.Rlimit + if err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rlimit); err != nil { + return fmt.Errorf("failed to get RLIMIT_NOFILE: %v", err) + } + rlimit.Cur = rlimit.Max + if err := syscall.Setrlimit(syscall.RLIMIT_NOFILE, &rlimit); err != nil { + return fmt.Errorf("failed to set RLIMIT_NOFILE: %v", err) + } + return nil +} + +var ( + configPath string + debug bool + + tlsCert atomic.Value // *tls.Certificate +) + +func loadConfig() (*config.Server, *suika.Config, error) { + var raw *config.Server + if configPath != "" { + var err error + raw, err = config.Load(configPath) + if err != nil { + return nil, nil, fmt.Errorf("failed to load config file: %v", err) + } + } else { + raw = config.Defaults() + } + + var motd string + if raw.MOTDPath != "" { + b, err := os.ReadFile(raw.MOTDPath) + if err != nil { + return nil, nil, fmt.Errorf("failed to load MOTD: %v", err) + } + motd = strings.TrimSuffix(string(b), "\n") + } + + if raw.TLS != nil { + cert, err := tls.LoadX509KeyPair(raw.TLS.CertPath, raw.TLS.KeyPath) + if err != nil { + return nil, nil, fmt.Errorf("failed to load TLS certificate and key: %v", err) + } + tlsCert.Store(&cert) + } + + cfg := &suika.Config{ + Hostname: raw.Hostname, + Title: raw.Title, + LogPath: raw.LogPath, + MaxUserNetworks: raw.MaxUserNetworks, + MultiUpstream: raw.MultiUpstream, + UpstreamUserIPs: raw.UpstreamUserIPs, + MOTD: motd, + } + return raw, cfg, nil +} + +func main() { + var listen []string + flag.Var((*stringSliceFlag)(&listen), "listen", "listening address") + flag.StringVar(&configPath, "config", "", "path to configuration file") + flag.BoolVar(&debug, "debug", false, "enable debug logging") + flag.Parse() + + cfg, serverCfg, err := loadConfig() + if err != nil { + log.Fatal(err) + } + + cfg.Listen = append(cfg.Listen, listen...) + if len(cfg.Listen) == 0 { + cfg.Listen = []string{":6667"} + } + + if err := bumpOpenedFileLimit(); err != nil { + log.Printf("failed to bump max number of opened files: %v", err) + } + + db, err := suika.OpenDB(cfg.SQLDriver, cfg.SQLSource) + if err != nil { + log.Fatalf("failed to open database: %v", err) + } + + var tlsCfg *tls.Config + if cfg.TLS != nil { + tlsCfg = &tls.Config{ + GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) { + return tlsCert.Load().(*tls.Certificate), nil + }, + } + } + + srv := suika.NewServer(db) + srv.SetConfig(serverCfg) + srv.Logger = suika.NewLogger(log.Writer(), debug) + + for _, listen := range cfg.Listen { + listen := listen // copy + listenURI := listen + if !strings.Contains(listenURI, ":/") { + // This is a raw domain name, make it an URL with an empty scheme + listenURI = "//" + listenURI + } + u, err := url.Parse(listenURI) + if err != nil { + log.Fatalf("failed to parse listen URI %q: %v", listen, err) + } + + switch u.Scheme { + case "ircs", "": + if tlsCfg == nil { + log.Fatalf("failed to listen on %q: missing TLS configuration", listen) + } + host := u.Host + if _, _, err := net.SplitHostPort(host); err != nil { + host = host + ":6697" + } + ircsTLSCfg := tlsCfg.Clone() + ircsTLSCfg.NextProtos = []string{"irc"} + lc := net.ListenConfig{ + KeepAlive: downstreamKeepAlive, + } + l, err := lc.Listen(context.Background(), "tcp", host) + if err != nil { + log.Fatalf("failed to start TLS listener on %q: %v", listen, err) + } + ln := tls.NewListener(l, ircsTLSCfg) + go func() { + if err := srv.Serve(ln); err != nil { + log.Printf("serving %q: %v", listen, err) + } + }() + case "irc": + host := u.Host + if _, _, err := net.SplitHostPort(host); err != nil { + host = host + ":6667" + } + lc := net.ListenConfig{ + KeepAlive: downstreamKeepAlive, + } + ln, err := lc.Listen(context.Background(), "tcp", host) + if err != nil { + log.Fatalf("failed to start listener on %q: %v", listen, err) + } + go func() { + if err := srv.Serve(ln); err != nil { + log.Printf("serving %q: %v", listen, err) + } + }() + case "unix": + ln, err := net.Listen("unix", u.Path) + if err != nil { + log.Fatalf("failed to start listener on %q: %v", listen, err) + } + go func() { + if err := srv.Serve(ln); err != nil { + log.Printf("serving %q: %v", listen, err) + } + }() + default: + log.Fatalf("failed to listen on %q: unsupported scheme", listen) + } + + log.Printf("starting suika version %v\n", suika.FullVersion()) + log.Printf("server listening on %q", listen) + } + + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP) + + if err := srv.Start(); err != nil { + log.Fatal(err) + } + + for sig := range sigCh { + switch sig { + case syscall.SIGHUP: + log.Print("reloading configuration") + _, serverCfg, err := loadConfig() + if err != nil { + log.Printf("failed to reloading configuration: %v", err) + } else { + srv.SetConfig(serverCfg) + } + case syscall.SIGINT, syscall.SIGTERM: + log.Print("shutting down server") + srv.Shutdown() + return + } + } +} diff --git a/branches/master/cmd/suikadb/main.go b/branches/master/cmd/suikadb/main.go new file mode 100644 index 0000000..4d54481 --- /dev/null +++ b/branches/master/cmd/suikadb/main.go @@ -0,0 +1,150 @@ +package main + +import ( + "bufio" + "context" + "flag" + "fmt" + "io" + "log" + "os" + + "marisa.chaotic.ninja/suika" + "marisa.chaotic.ninja/suika/config" + "golang.org/x/crypto/bcrypt" + "golang.org/x/term" +) + +const usage = `usage: suikadb [-config path] [options...] + + create-user [-admin] Create a new user + change-password Change password for a user + help Show this help message +` + +func init() { + flag.Usage = func() { + fmt.Fprintf(flag.CommandLine.Output(), usage) + } +} + +func main() { + var configPath string + flag.StringVar(&configPath, "config", "", "path to configuration file") + flag.Parse() + + var cfg *config.Server + if configPath != "" { + var err error + cfg, err = config.Load(configPath) + if err != nil { + log.Fatalf("failed to load config file: %v", err) + } + } else { + cfg = config.Defaults() + } + + db, err := suika.OpenDB(cfg.SQLDriver, cfg.SQLSource) + if err != nil { + log.Fatalf("failed to open database: %v", err) + } + + ctx := context.Background() + + switch cmd := flag.Arg(0); cmd { + case "create-user": + username := flag.Arg(1) + if username == "" { + flag.Usage() + os.Exit(1) + } + + fs := flag.NewFlagSet("", flag.ExitOnError) + admin := fs.Bool("admin", false, "make the new user admin") + fs.Parse(flag.Args()[2:]) + + password, err := readPassword() + if err != nil { + log.Fatalf("failed to read password: %v", err) + } + + hashed, err := bcrypt.GenerateFromPassword(password, bcrypt.DefaultCost) + if err != nil { + log.Fatalf("failed to hash password: %v", err) + } + + user := suika.User{ + Username: username, + Password: string(hashed), + Admin: *admin, + } + if err := db.StoreUser(ctx, &user); err != nil { + log.Fatalf("failed to create user: %v", err) + } + case "change-password": + username := flag.Arg(1) + if username == "" { + flag.Usage() + os.Exit(1) + } + + user, err := db.GetUser(ctx, username) + if err != nil { + log.Fatalf("failed to get user: %v", err) + } + + password, err := readPassword() + if err != nil { + log.Fatalf("failed to read password: %v", err) + } + + hashed, err := bcrypt.GenerateFromPassword(password, bcrypt.DefaultCost) + if err != nil { + log.Fatalf("failed to hash password: %v", err) + } + + user.Password = string(hashed) + if err := db.StoreUser(ctx, user); err != nil { + log.Fatalf("failed to update password: %v", err) + } + case "version": + fmt.Printf("%v\n", suika.FullVersion()) + default: + flag.Usage() + if cmd != "help" { + os.Exit(1) + } + } +} + +func readPassword() ([]byte, error) { + var password []byte + var err error + fd := int(os.Stdin.Fd()) + + if term.IsTerminal(fd) { + fmt.Printf("Password: ") + password, err = term.ReadPassword(int(os.Stdin.Fd())) + if err != nil { + return nil, err + } + fmt.Printf("\n") + } else { + fmt.Fprintf(os.Stderr, "Warning: Reading password from stdin.\n") + // TODO: the buffering messes up repeated calls to readPassword + scanner := bufio.NewScanner(os.Stdin) + if !scanner.Scan() { + if err := scanner.Err(); err != nil { + return nil, err + } + return nil, io.ErrUnexpectedEOF + } + password = scanner.Bytes() + + if len(password) == 0 { + return nil, fmt.Errorf("zero length password") + } + } + + return password, nil +} diff --git a/branches/master/config.in b/branches/master/config.in new file mode 100644 index 0000000..0b1ed4d --- /dev/null +++ b/branches/master/config.in @@ -0,0 +1,2 @@ +db sqlite3 /var/lib/suika/main.db +log fs /var/lib/suika/logs/ diff --git a/branches/master/config/config.go b/branches/master/config/config.go new file mode 100644 index 0000000..94d4584 --- /dev/null +++ b/branches/master/config/config.go @@ -0,0 +1,142 @@ +package config + +import ( + "fmt" + "net" + "os" + "strconv" + + "git.sr.ht/~emersion/go-scfg" +) + +type TLS struct { + CertPath, KeyPath string +} + +type Server struct { + Listen []string + TLS *TLS + Hostname string + Title string + MOTDPath string + + SQLDriver string + SQLSource string + LogPath string + + MaxUserNetworks int + MultiUpstream bool + UpstreamUserIPs []*net.IPNet +} + +func Defaults() *Server { + hostname, err := os.Hostname() + if err != nil { + hostname = "localhost" + } + return &Server{ + Hostname: hostname, + SQLDriver: "sqlite3", + SQLSource: "suika.db", + MaxUserNetworks: -1, + MultiUpstream: true, + } +} + +func Load(path string) (*Server, error) { + cfg, err := scfg.Load(path) + if err != nil { + return nil, err + } + return parse(cfg) +} + +func parse(cfg scfg.Block) (*Server, error) { + srv := Defaults() + for _, d := range cfg { + switch d.Name { + case "listen": + var uri string + if err := d.ParseParams(&uri); err != nil { + return nil, err + } + srv.Listen = append(srv.Listen, uri) + case "hostname": + if err := d.ParseParams(&srv.Hostname); err != nil { + return nil, err + } + case "title": + if err := d.ParseParams(&srv.Title); err != nil { + return nil, err + } + case "motd": + if err := d.ParseParams(&srv.MOTDPath); err != nil { + return nil, err + } + case "tls": + tls := &TLS{} + if err := d.ParseParams(&tls.CertPath, &tls.KeyPath); err != nil { + return nil, err + } + srv.TLS = tls + case "db": + if err := d.ParseParams(&srv.SQLDriver, &srv.SQLSource); err != nil { + return nil, err + } + case "log": + var driver string + if err := d.ParseParams(&driver, &srv.LogPath); err != nil { + return nil, err + } + if driver != "fs" { + return nil, fmt.Errorf("directive %q: unknown driver %q", d.Name, driver) + } + case "max-user-networks": + var max string + if err := d.ParseParams(&max); err != nil { + return nil, err + } + var err error + if srv.MaxUserNetworks, err = strconv.Atoi(max); err != nil { + return nil, fmt.Errorf("directive %q: %v", d.Name, err) + } + case "multi-upstream-mode": + var str string + if err := d.ParseParams(&str); err != nil { + return nil, err + } + v, err := strconv.ParseBool(str) + if err != nil { + return nil, fmt.Errorf("directive %q: %v", d.Name, err) + } + srv.MultiUpstream = v + case "upstream-user-ip": + if len(srv.UpstreamUserIPs) > 0 { + return nil, fmt.Errorf("directive %q: can only be specified once", d.Name) + } + var hasIPv4, hasIPv6 bool + for _, s := range d.Params { + _, n, err := net.ParseCIDR(s) + if err != nil { + return nil, fmt.Errorf("directive %q: failed to parse CIDR: %v", d.Name, err) + } + if n.IP.To4() == nil { + if hasIPv6 { + return nil, fmt.Errorf("directive %q: found two IPv6 CIDRs", d.Name) + } + hasIPv6 = true + } else { + if hasIPv4 { + return nil, fmt.Errorf("directive %q: found two IPv4 CIDRs", d.Name) + } + hasIPv4 = true + } + srv.UpstreamUserIPs = append(srv.UpstreamUserIPs, n) + } + default: + return nil, fmt.Errorf("unknown directive %q", d.Name) + } + } + + return srv, nil +} diff --git a/branches/master/conn.go b/branches/master/conn.go new file mode 100644 index 0000000..bb95e25 --- /dev/null +++ b/branches/master/conn.go @@ -0,0 +1,180 @@ +package suika + +import ( + "context" + "fmt" + "io" + "net" + "sync" + "time" + + "golang.org/x/time/rate" + "gopkg.in/irc.v3" +) + +// ircConn is a generic IRC connection. It's similar to net.Conn but focuses on +// reading and writing IRC messages. +type ircConn interface { + ReadMessage() (*irc.Message, error) + WriteMessage(*irc.Message) error + Close() error + SetReadDeadline(time.Time) error + SetWriteDeadline(time.Time) error + RemoteAddr() net.Addr + LocalAddr() net.Addr +} + +func newNetIRCConn(c net.Conn) ircConn { + type netConn net.Conn + return struct { + *irc.Conn + netConn + }{irc.NewConn(c), c} +} + +type connOptions struct { + Logger Logger + RateLimitDelay time.Duration + RateLimitBurst int +} + +type conn struct { + conn ircConn + srv *Server + logger Logger + + lock sync.Mutex + outgoing chan<- *irc.Message + closed bool + closedCh chan struct{} +} + +func newConn(srv *Server, ic ircConn, options *connOptions) *conn { + outgoing := make(chan *irc.Message, 64) + c := &conn{ + conn: ic, + srv: srv, + outgoing: outgoing, + logger: options.Logger, + closedCh: make(chan struct{}), + } + + go func() { + ctx, cancel := c.NewContext(context.Background()) + defer cancel() + + rl := rate.NewLimiter(rate.Every(options.RateLimitDelay), options.RateLimitBurst) + for msg := range outgoing { + if err := rl.Wait(ctx); err != nil { + break + } + + c.logger.Debugf("sent: %v", msg) + c.conn.SetWriteDeadline(time.Now().Add(writeTimeout)) + if err := c.conn.WriteMessage(msg); err != nil { + c.logger.Printf("failed to write message: %v", err) + break + } + } + if err := c.conn.Close(); err != nil && !isErrClosed(err) { + c.logger.Printf("failed to close connection: %v", err) + } else { + c.logger.Debugf("connection closed") + } + // Drain the outgoing channel to prevent SendMessage from blocking + for range outgoing { + // This space is intentionally left blank + } + }() + + c.logger.Debugf("new connection") + return c +} + +func (c *conn) isClosed() bool { + c.lock.Lock() + defer c.lock.Unlock() + return c.closed +} + +// Close closes the connection. It is safe to call from any goroutine. +func (c *conn) Close() error { + c.lock.Lock() + defer c.lock.Unlock() + + if c.closed { + return fmt.Errorf("connection already closed") + } + + err := c.conn.Close() + c.closed = true + close(c.outgoing) + close(c.closedCh) + return err +} + +func (c *conn) ReadMessage() (*irc.Message, error) { + msg, err := c.conn.ReadMessage() + if isErrClosed(err) { + return nil, io.EOF + } else if err != nil { + return nil, err + } + + c.logger.Debugf("received: %v", msg) + return msg, nil +} + +// SendMessage queues a new outgoing message. It is safe to call from any +// goroutine. +// +// If the connection is closed before the message is sent, SendMessage silently +// drops the message. +func (c *conn) SendMessage(ctx context.Context, msg *irc.Message) { + c.lock.Lock() + defer c.lock.Unlock() + + if c.closed { + return + } + + select { + case c.outgoing <- msg: + // Success + case <-ctx.Done(): + c.logger.Printf("failed to send message: %v", ctx.Err()) + } +} + +func (c *conn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +func (c *conn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +// NewContext returns a copy of the parent context with a new Done channel. The +// returned context's Done channel is closed when the connection is closed, +// when the returned cancel function is called, or when the parent context's +// Done channel is closed, whichever happens first. +// +// Canceling this context releases resources associated with it, so code should +// call cancel as soon as the operations running in this Context complete. +func (c *conn) NewContext(parent context.Context) (context.Context, context.CancelFunc) { + ctx, cancel := context.WithCancel(parent) + + go func() { + defer cancel() + + select { + case <-ctx.Done(): + // The parent context has been cancelled, or the caller has called + // cancel() + case <-c.closedCh: + // The connection has been closed + } + }() + + return ctx, cancel +} diff --git a/branches/master/contrib/casemap-logs.sh b/branches/master/contrib/casemap-logs.sh new file mode 100644 index 0000000..709d581 --- /dev/null +++ b/branches/master/contrib/casemap-logs.sh @@ -0,0 +1,42 @@ +#!/bin/sh -eu + +# Converts a log dir to its case-mapped form. +# +# suika needs to be stopped for this script to work properly. The script may +# re-order messages that happened within the same second interval if merging +# two daily log files is necessary. +# +# usage: casemap-logs.sh + +root="$1" + +for net_dir in "$root"/*/*; do + for chan in $(ls "$net_dir"); do + cm_chan="$(echo $chan | tr '[:upper:]' '[:lower:]')" + if [ "$chan" = "$cm_chan" ]; then + continue + fi + + if ! [ -d "$net_dir/$cm_chan" ]; then + echo >&2 "Moving case-mapped channel dir: '$net_dir/$chan' -> '$cm_chan'" + mv "$net_dir/$chan" "$net_dir/$cm_chan" + continue + fi + + echo "Merging case-mapped channel dir: '$net_dir/$chan' -> '$cm_chan'" + for day in $(ls "$net_dir/$chan"); do + if ! [ -e "$net_dir/$cm_chan/$day" ]; then + echo >&2 " Moving log file: '$day'" + mv "$net_dir/$chan/$day" "$net_dir/$cm_chan/$day" + continue + fi + + echo >&2 " Merging log file: '$day'" + sort "$net_dir/$chan/$day" "$net_dir/$cm_chan/$day" >"$net_dir/$cm_chan/$day.new" + mv "$net_dir/$cm_chan/$day.new" "$net_dir/$cm_chan/$day" + rm "$net_dir/$chan/$day" + done + + rmdir "$net_dir/$chan" + done +done diff --git a/branches/master/contrib/clients.md b/branches/master/contrib/clients.md new file mode 100644 index 0000000..1f8881d --- /dev/null +++ b/branches/master/contrib/clients.md @@ -0,0 +1,100 @@ +# Clients + +This page describes how to configure IRC clients to better integrate with soju. + +Also see the [IRCv3 support tables] for a more general list of clients. + +# catgirl + +catgirl doesn't properly implement cap-3.2, so many capabilities will be +disabled. catgirl developers have publicly stated that supporting bouncers such +as soju is a non-goal. + +# [Emacs] + +There are two clients provided with Emacs. They require some setup to work +properly. + +## Erc + +You need to explicitly set the username, which is the defcustom +`erc-email-userid`. + +```elisp +(setq erc-email-userid "/irc.libera.chat") ;; Example with Libera.Chat +(defun run-erc () + (interactive) + (erc-tls :server "" + :port 6697 + :nick "" + :password "")) +``` + +Then run `M-x run-erc`. + +## Rcirc + +The only thing needed here is the general config: + +```elisp +(setq rcirc-server-alist + '(("" + :port 6697 + :encryption tls + :nick "" + :user-name "/irc.libera.chat" ;; Example with Libera.Chat + :password ""))) +``` + +Then run `M-x irc`. + +# [gamja] + +gamja has been designed together with soju, so should have excellent +integration. gamja supports many IRCv3 features including chat history. +gamja also provides UI to manage soju networks via the +`soju.im/bouncer-networks` extension. + +# [goguma] + +Much like gamja, goguma has been designed together with soju, so should have +excellent integration. goguma supports many IRCv3 features including chat +history. goguma should seamlessly connect to all networks configured in soju via +the `soju.im/bouncer-networks` extension. + +# [Hexchat] + +Hexchat has support for a small set of IRCv3 capabilities. To prevent +automatically reconnecting to channels parted from soju, and prevent buffering +outgoing messages: + + /set irc_reconnect_rejoin off + /set net_throttle off + +# [senpai] + +senpai is being developed with soju in mind, so should have excellent +integration. senpai supports many IRCv3 features including chat history. + +# [Weechat] + +A [Weechat script] is available to provide better integration with soju. +The script will automatically connect to all of your networks once a +single connection to soju is set up in Weechat. + +On WeeChat 3.2-, no IRCv3 capabilities are enabled by default. To enable them: + + /set irc.server_default.capabilities account-notify,away-notify,cap-notify,chghost,extended-join,invite-notify,multi-prefix,server-time,userhost-in-names + /save + /reconnect -all + +See `/help cap` for more information. + +[IRCv3 support tables]: https://ircv3.net/software/clients +[gamja]: https://sr.ht/~emersion/gamja/ +[goguma]: https://sr.ht/~emersion/goguma/ +[senpai]: https://sr.ht/~taiite/senpai/ +[Weechat]: https://weechat.org/ +[Weechat script]: https://github.com/weechat/scripts/blob/master/python/soju.py +[Hexchat]: https://hexchat.github.io/ +[Emacs]: https://www.gnu.org/software/emacs/ diff --git a/branches/master/db.go b/branches/master/db.go new file mode 100644 index 0000000..95ac964 --- /dev/null +++ b/branches/master/db.go @@ -0,0 +1,184 @@ +package suika + +import ( + "context" + "fmt" + "net/url" + "strings" + "time" +) + +type Database interface { + Close() error + Stats(ctx context.Context) (*DatabaseStats, error) + + ListUsers(ctx context.Context) ([]User, error) + GetUser(ctx context.Context, username string) (*User, error) + StoreUser(ctx context.Context, user *User) error + DeleteUser(ctx context.Context, id int64) error + + ListNetworks(ctx context.Context, userID int64) ([]Network, error) + StoreNetwork(ctx context.Context, userID int64, network *Network) error + DeleteNetwork(ctx context.Context, id int64) error + ListChannels(ctx context.Context, networkID int64) ([]Channel, error) + StoreChannel(ctx context.Context, networKID int64, ch *Channel) error + DeleteChannel(ctx context.Context, id int64) error + + ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error) + StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error + + GetReadReceipt(ctx context.Context, networkID int64, name string) (*ReadReceipt, error) + StoreReadReceipt(ctx context.Context, networkID int64, receipt *ReadReceipt) error +} + +func OpenDB(driver, source string) (Database, error) { + switch driver { + case "sqlite3": + return OpenSqliteDB(source) + case "postgres": + return OpenPostgresDB(source) + default: + return nil, fmt.Errorf("unsupported database driver: %q", driver) + } +} + +type DatabaseStats struct { + Users int64 + Networks int64 + Channels int64 +} + +type User struct { + ID int64 + Username string + Password string // hashed + Realname string + Admin bool +} + +type SASL struct { + Mechanism string + + Plain struct { + Username string + Password string + } + + // TLS client certificate authentication. + External struct { + // X.509 certificate in DER form. + CertBlob []byte + // PKCS#8 private key in DER form. + PrivKeyBlob []byte + } +} + +type Network struct { + ID int64 + Name string + Addr string + Nick string + Username string + Realname string + Pass string + ConnectCommands []string + SASL SASL + Enabled bool +} + +func (net *Network) GetName() string { + if net.Name != "" { + return net.Name + } + return net.Addr +} + +func (net *Network) URL() (*url.URL, error) { + s := net.Addr + if !strings.Contains(s, "://") { + // This is a raw domain name, make it an URL with the default scheme + s = "ircs://" + s + } + + u, err := url.Parse(s) + if err != nil { + return nil, fmt.Errorf("failed to parse upstream server URL: %v", err) + } + + return u, nil +} + +func GetNick(user *User, net *Network) string { + if net.Nick != "" { + return net.Nick + } + return user.Username +} + +func GetUsername(user *User, net *Network) string { + if net.Username != "" { + return net.Username + } + return GetNick(user, net) +} + +func GetRealname(user *User, net *Network) string { + if net.Realname != "" { + return net.Realname + } + if user.Realname != "" { + return user.Realname + } + return GetNick(user, net) +} + +type MessageFilter int + +const ( + // TODO: use customizable user defaults for FilterDefault + FilterDefault MessageFilter = iota + FilterNone + FilterHighlight + FilterMessage +) + +func parseFilter(filter string) (MessageFilter, error) { + switch filter { + case "default": + return FilterDefault, nil + case "none": + return FilterNone, nil + case "highlight": + return FilterHighlight, nil + case "message": + return FilterMessage, nil + } + return 0, fmt.Errorf("unknown filter: %q", filter) +} + +type Channel struct { + ID int64 + Name string + Key string + + Detached bool + DetachedInternalMsgID string + + RelayDetached MessageFilter + ReattachOn MessageFilter + DetachAfter time.Duration + DetachOn MessageFilter +} + +type DeliveryReceipt struct { + ID int64 + Target string // channel or nick + Client string + InternalMsgID string +} + +type ReadReceipt struct { + ID int64 + Target string // channel or nick + Timestamp time.Time +} diff --git a/branches/master/db_postgres.go b/branches/master/db_postgres.go new file mode 100644 index 0000000..0849338 --- /dev/null +++ b/branches/master/db_postgres.go @@ -0,0 +1,494 @@ +package suika + +import ( + "context" + "database/sql" + _ "embed" + "errors" + "fmt" + "math" + "strings" + "time" + + _ "github.com/lib/pq" +) + +const postgresQueryTimeout = 5 * time.Second + +const postgresConfigSchema = ` +CREATE TABLE IF NOT EXISTS "Config" ( + id SMALLINT PRIMARY KEY, + version INTEGER NOT NULL, + CHECK(id = 1) +); +` +//go:embed suika_psql_schema.sql +var postgresSchema string + +var postgresMigrations = []string{ + "", // migration #0 is reserved for schema initialization + `ALTER TABLE "Network" ALTER COLUMN nick DROP NOT NULL`, + ` + CREATE TYPE sasl_mechanism AS ENUM ('PLAIN', 'EXTERNAL'); + ALTER TABLE "Network" + ALTER COLUMN sasl_mechanism + TYPE sasl_mechanism + USING sasl_mechanism::sasl_mechanism; + `, + ` + CREATE TABLE IF NOT EXISTS "ReadReceipt" ( + id SERIAL PRIMARY KEY, + network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE, + target VARCHAR(255) NOT NULL, + timestamp TIMESTAMP WITH TIME ZONE NOT NULL, + UNIQUE(network, target) + ); + `, +} + +type PostgresDB struct { + db *sql.DB +} + +func OpenPostgresDB(source string) (Database, error) { + sqlPostgresDB, err := sql.Open("postgres", source) + if err != nil { + return nil, err + } + + db := &PostgresDB{db: sqlPostgresDB} + if err := db.upgrade(); err != nil { + sqlPostgresDB.Close() + return nil, err + } + + return db, nil +} + +func (db *PostgresDB) upgrade() error { + tx, err := db.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + if _, err := tx.Exec(postgresConfigSchema); err != nil { + return fmt.Errorf("failed to create Config table: %s", err) + } + + var version int + err = tx.QueryRow(`SELECT version FROM "Config"`).Scan(&version) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return fmt.Errorf("failed to query schema version: %s", err) + } + + if version == len(postgresMigrations) { + return nil + } + if version > len(postgresMigrations) { + return fmt.Errorf("suika (version %d) older than schema (version %d)", len(postgresMigrations), version) + } + + if version == 0 { + if _, err := tx.Exec(postgresSchema); err != nil { + return fmt.Errorf("failed to initialize schema: %s", err) + } + } else { + for i := version; i < len(postgresMigrations); i++ { + if _, err := tx.Exec(postgresMigrations[i]); err != nil { + return fmt.Errorf("failed to execute migration #%v: %v", i, err) + } + } + } + + _, err = tx.Exec(`INSERT INTO "Config" (id, version) VALUES (1, $1) + ON CONFLICT (id) DO UPDATE SET version = $1`, len(postgresMigrations)) + if err != nil { + return fmt.Errorf("failed to bump schema version: %v", err) + } + + return tx.Commit() +} + +func (db *PostgresDB) Close() error { + return db.db.Close() +} + +func (db *PostgresDB) Stats(ctx context.Context) (*DatabaseStats, error) { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + var stats DatabaseStats + row := db.db.QueryRowContext(ctx, `SELECT + (SELECT COUNT(*) FROM "User") AS users, + (SELECT COUNT(*) FROM "Network") AS networks, + (SELECT COUNT(*) FROM "Channel") AS channels`) + if err := row.Scan(&stats.Users, &stats.Networks, &stats.Channels); err != nil { + return nil, err + } + + return &stats, nil +} + +func (db *PostgresDB) ListUsers(ctx context.Context) ([]User, error) { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + rows, err := db.db.QueryContext(ctx, + `SELECT id, username, password, admin, realname FROM "User"`) + if err != nil { + return nil, err + } + defer rows.Close() + + var users []User + for rows.Next() { + var user User + var password, realname sql.NullString + if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &realname); err != nil { + return nil, err + } + user.Password = password.String + user.Realname = realname.String + users = append(users, user) + } + if err := rows.Err(); err != nil { + return nil, err + } + + return users, nil +} + +func (db *PostgresDB) GetUser(ctx context.Context, username string) (*User, error) { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + user := &User{Username: username} + + var password, realname sql.NullString + row := db.db.QueryRowContext(ctx, + `SELECT id, password, admin, realname FROM "User" WHERE username = $1`, + username) + if err := row.Scan(&user.ID, &password, &user.Admin, &realname); err != nil { + return nil, err + } + user.Password = password.String + user.Realname = realname.String + return user, nil +} + +func (db *PostgresDB) StoreUser(ctx context.Context, user *User) error { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + password := toNullString(user.Password) + realname := toNullString(user.Realname) + + var err error + if user.ID == 0 { + err = db.db.QueryRowContext(ctx, ` + INSERT INTO "User" (username, password, admin, realname) + VALUES ($1, $2, $3, $4) + RETURNING id`, + user.Username, password, user.Admin, realname).Scan(&user.ID) + } else { + _, err = db.db.ExecContext(ctx, ` + UPDATE "User" + SET password = $1, admin = $2, realname = $3 + WHERE id = $4`, + password, user.Admin, realname, user.ID) + } + return err +} + +func (db *PostgresDB) DeleteUser(ctx context.Context, id int64) error { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + _, err := db.db.ExecContext(ctx, `DELETE FROM "User" WHERE id = $1`, id) + return err +} + +func (db *PostgresDB) ListNetworks(ctx context.Context, userID int64) ([]Network, error) { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + rows, err := db.db.QueryContext(ctx, ` + SELECT id, name, addr, nick, username, realname, pass, connect_commands, sasl_mechanism, + sasl_plain_username, sasl_plain_password, sasl_external_cert, sasl_external_key, enabled + FROM "Network" + WHERE "user" = $1`, userID) + if err != nil { + return nil, err + } + defer rows.Close() + + var networks []Network + for rows.Next() { + var net Network + var name, nick, username, realname, pass, connectCommands sql.NullString + var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString + err := rows.Scan(&net.ID, &name, &net.Addr, &nick, &username, &realname, + &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword, + &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob, &net.Enabled) + if err != nil { + return nil, err + } + net.Name = name.String + net.Nick = nick.String + net.Username = username.String + net.Realname = realname.String + net.Pass = pass.String + if connectCommands.Valid { + net.ConnectCommands = strings.Split(connectCommands.String, "\r\n") + } + net.SASL.Mechanism = saslMechanism.String + net.SASL.Plain.Username = saslPlainUsername.String + net.SASL.Plain.Password = saslPlainPassword.String + networks = append(networks, net) + } + if err := rows.Err(); err != nil { + return nil, err + } + + return networks, nil +} + +func (db *PostgresDB) StoreNetwork(ctx context.Context, userID int64, network *Network) error { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + netName := toNullString(network.Name) + nick := toNullString(network.Nick) + netUsername := toNullString(network.Username) + realname := toNullString(network.Realname) + pass := toNullString(network.Pass) + connectCommands := toNullString(strings.Join(network.ConnectCommands, "\r\n")) + + var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString + if network.SASL.Mechanism != "" { + saslMechanism = toNullString(network.SASL.Mechanism) + switch network.SASL.Mechanism { + case "PLAIN": + saslPlainUsername = toNullString(network.SASL.Plain.Username) + saslPlainPassword = toNullString(network.SASL.Plain.Password) + network.SASL.External.CertBlob = nil + network.SASL.External.PrivKeyBlob = nil + case "EXTERNAL": + // keep saslPlain* nil + default: + return fmt.Errorf("suika: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism) + } + } + + var err error + if network.ID == 0 { + err = db.db.QueryRowContext(ctx, ` + INSERT INTO "Network" ("user", name, addr, nick, username, realname, pass, connect_commands, + sasl_mechanism, sasl_plain_username, sasl_plain_password, sasl_external_cert, + sasl_external_key, enabled) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) + RETURNING id`, + userID, netName, network.Addr, nick, netUsername, realname, pass, connectCommands, + saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob, + network.SASL.External.PrivKeyBlob, network.Enabled).Scan(&network.ID) + } else { + _, err = db.db.ExecContext(ctx, ` + UPDATE "Network" + SET name = $2, addr = $3, nick = $4, username = $5, realname = $6, pass = $7, + connect_commands = $8, sasl_mechanism = $9, sasl_plain_username = $10, + sasl_plain_password = $11, sasl_external_cert = $12, sasl_external_key = $13, + enabled = $14 + WHERE id = $1`, + network.ID, netName, network.Addr, nick, netUsername, realname, pass, connectCommands, + saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob, + network.SASL.External.PrivKeyBlob, network.Enabled) + } + return err +} + +func (db *PostgresDB) DeleteNetwork(ctx context.Context, id int64) error { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + _, err := db.db.ExecContext(ctx, `DELETE FROM "Network" WHERE id = $1`, id) + return err +} + +func (db *PostgresDB) ListChannels(ctx context.Context, networkID int64) ([]Channel, error) { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + rows, err := db.db.QueryContext(ctx, ` + SELECT id, name, key, detached, detached_internal_msgid, relay_detached, reattach_on, detach_after, + detach_on + FROM "Channel" + WHERE network = $1`, networkID) + if err != nil { + return nil, err + } + defer rows.Close() + + var channels []Channel + for rows.Next() { + var ch Channel + var key, detachedInternalMsgID sql.NullString + var detachAfter int64 + if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached, &detachedInternalMsgID, &ch.RelayDetached, &ch.ReattachOn, &detachAfter, &ch.DetachOn); err != nil { + return nil, err + } + ch.Key = key.String + ch.DetachedInternalMsgID = detachedInternalMsgID.String + ch.DetachAfter = time.Duration(detachAfter) * time.Second + channels = append(channels, ch) + } + if err := rows.Err(); err != nil { + return nil, err + } + + return channels, nil +} + +func (db *PostgresDB) StoreChannel(ctx context.Context, networkID int64, ch *Channel) error { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + key := toNullString(ch.Key) + detachAfter := int64(math.Ceil(ch.DetachAfter.Seconds())) + + var err error + if ch.ID == 0 { + err = db.db.QueryRowContext(ctx, ` + INSERT INTO "Channel" (network, name, key, detached, detached_internal_msgid, relay_detached, reattach_on, + detach_after, detach_on) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) + RETURNING id`, + networkID, ch.Name, key, ch.Detached, toNullString(ch.DetachedInternalMsgID), + ch.RelayDetached, ch.ReattachOn, detachAfter, ch.DetachOn).Scan(&ch.ID) + } else { + _, err = db.db.ExecContext(ctx, ` + UPDATE "Channel" + SET name = $2, key = $3, detached = $4, detached_internal_msgid = $5, + relay_detached = $6, reattach_on = $7, detach_after = $8, detach_on = $9 + WHERE id = $1`, + ch.ID, ch.Name, key, ch.Detached, toNullString(ch.DetachedInternalMsgID), + ch.RelayDetached, ch.ReattachOn, detachAfter, ch.DetachOn) + } + return err +} + +func (db *PostgresDB) DeleteChannel(ctx context.Context, id int64) error { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + _, err := db.db.ExecContext(ctx, `DELETE FROM "Channel" WHERE id = $1`, id) + return err +} + +func (db *PostgresDB) ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error) { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + rows, err := db.db.QueryContext(ctx, ` + SELECT id, target, client, internal_msgid + FROM "DeliveryReceipt" + WHERE network = $1`, networkID) + if err != nil { + return nil, err + } + defer rows.Close() + + var receipts []DeliveryReceipt + for rows.Next() { + var rcpt DeliveryReceipt + if err := rows.Scan(&rcpt.ID, &rcpt.Target, &rcpt.Client, &rcpt.InternalMsgID); err != nil { + return nil, err + } + receipts = append(receipts, rcpt) + } + if err := rows.Err(); err != nil { + return nil, err + } + + return receipts, nil +} + +func (db *PostgresDB) StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + tx, err := db.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + _, err = tx.ExecContext(ctx, + `DELETE FROM "DeliveryReceipt" WHERE network = $1 AND client = $2`, + networkID, client) + if err != nil { + return err + } + + stmt, err := tx.PrepareContext(ctx, ` + INSERT INTO "DeliveryReceipt" (network, target, client, internal_msgid) + VALUES ($1, $2, $3, $4) + RETURNING id`) + if err != nil { + return err + } + defer stmt.Close() + + for i := range receipts { + rcpt := &receipts[i] + err := stmt. + QueryRowContext(ctx, networkID, rcpt.Target, client, rcpt.InternalMsgID). + Scan(&rcpt.ID) + if err != nil { + return err + } + } + + return tx.Commit() +} + +func (db *PostgresDB) GetReadReceipt(ctx context.Context, networkID int64, name string) (*ReadReceipt, error) { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + receipt := &ReadReceipt{ + Target: name, + } + + row := db.db.QueryRowContext(ctx, + `SELECT id, timestamp FROM "ReadReceipt" WHERE network = $1 AND target = $2`, + networkID, name) + if err := row.Scan(&receipt.ID, &receipt.Timestamp); err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, err + } + return receipt, nil +} + +func (db *PostgresDB) StoreReadReceipt(ctx context.Context, networkID int64, receipt *ReadReceipt) error { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + var err error + if receipt.ID != 0 { + _, err = db.db.ExecContext(ctx, ` + UPDATE "ReadReceipt" + SET timestamp = $1 + WHERE id = $2`, + receipt.Timestamp, receipt.ID) + } else { + err = db.db.QueryRowContext(ctx, ` + INSERT INTO "ReadReceipt" (network, target, timestamp) + VALUES ($1, $2, $3) + RETURNING id`, + networkID, receipt.Target, receipt.Timestamp).Scan(&receipt.ID) + } + return err +} diff --git a/branches/master/db_postgres_test.go b/branches/master/db_postgres_test.go new file mode 100644 index 0000000..7692f31 --- /dev/null +++ b/branches/master/db_postgres_test.go @@ -0,0 +1,104 @@ +package suika + +import ( + "database/sql" + "os" + "testing" +) + +// PostgreSQL version 0 schema. DO NOT EDIT. +const postgresV0Schema = ` +CREATE TABLE "Config" ( + id SMALLINT PRIMARY KEY, + version INTEGER NOT NULL, + CHECK(id = 1) +); + +INSERT INTO "Config" (id, version) VALUES (1, 1); + +CREATE TABLE "User" ( + id SERIAL PRIMARY KEY, + username VARCHAR(255) NOT NULL UNIQUE, + password VARCHAR(255), + admin BOOLEAN NOT NULL DEFAULT FALSE, + realname VARCHAR(255) +); + +CREATE TABLE "Network" ( + id SERIAL PRIMARY KEY, + name VARCHAR(255), + "user" INTEGER NOT NULL REFERENCES "User"(id) ON DELETE CASCADE, + addr VARCHAR(255) NOT NULL, + nick VARCHAR(255) NOT NULL, + username VARCHAR(255), + realname VARCHAR(255), + pass VARCHAR(255), + connect_commands VARCHAR(1023), + sasl_mechanism VARCHAR(255), + sasl_plain_username VARCHAR(255), + sasl_plain_password VARCHAR(255), + sasl_external_cert BYTEA DEFAULT NULL, + sasl_external_key BYTEA DEFAULT NULL, + enabled BOOLEAN NOT NULL DEFAULT TRUE, + UNIQUE("user", addr, nick), + UNIQUE("user", name) +); + +CREATE TABLE "Channel" ( + id SERIAL PRIMARY KEY, + network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE, + name VARCHAR(255) NOT NULL, + key VARCHAR(255), + detached BOOLEAN NOT NULL DEFAULT FALSE, + detached_internal_msgid VARCHAR(255), + relay_detached INTEGER NOT NULL DEFAULT 0, + reattach_on INTEGER NOT NULL DEFAULT 0, + detach_after INTEGER NOT NULL DEFAULT 0, + detach_on INTEGER NOT NULL DEFAULT 0, + UNIQUE(network, name) +); + +CREATE TABLE "DeliveryReceipt" ( + id SERIAL PRIMARY KEY, + network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE, + target VARCHAR(255) NOT NULL, + client VARCHAR(255) NOT NULL DEFAULT '', + internal_msgid VARCHAR(255) NOT NULL, + UNIQUE(network, target, client) +); +` + +func openTempPostgresDB(t *testing.T) *sql.DB { + source, ok := os.LookupEnv("SOJU_TEST_POSTGRES") + if !ok { + t.Skip("set SOJU_TEST_POSTGRES to a connection string to execute PostgreSQL tests") + } + + db, err := sql.Open("postgres", source) + if err != nil { + t.Fatalf("failed to connect to PostgreSQL: %v", err) + } + + // Store all tables in a temporary schema which will be dropped when the + // connection to PostgreSQL is closed. + db.SetMaxOpenConns(1) + if _, err := db.Exec("SET search_path TO pg_temp"); err != nil { + t.Fatalf("failed to set PostgreSQL search_path: %v", err) + } + + return db +} + +func TestPostgresMigrations(t *testing.T) { + sqlDB := openTempPostgresDB(t) + if _, err := sqlDB.Exec(postgresV0Schema); err != nil { + t.Fatalf("DB.Exec() failed for v0 schema: %v", err) + } + + db := &PostgresDB{db: sqlDB} + defer db.Close() + + if err := db.upgrade(); err != nil { + t.Fatalf("PostgresDB.Upgrade() failed: %v", err) + } +} diff --git a/branches/master/db_sqlite.go b/branches/master/db_sqlite.go new file mode 100644 index 0000000..7ebebb0 --- /dev/null +++ b/branches/master/db_sqlite.go @@ -0,0 +1,755 @@ +package suika + +import ( + "context" + "database/sql" + _ "embed" + "fmt" + "math" + "strings" + "sync" + "time" + + _ "modernc.org/sqlite" +) + +const sqliteQueryTimeout = 5 * time.Second + +//go:embed suika_sqlite_schema.sql +var sqliteSchema string + +var sqliteMigrations = []string{ + "", // migration #0 is reserved for schema initialization + "ALTER TABLE Network ADD COLUMN connect_commands VARCHAR(1023)", + "ALTER TABLE Channel ADD COLUMN detached INTEGER NOT NULL DEFAULT 0", + "ALTER TABLE Network ADD COLUMN sasl_external_cert BLOB DEFAULT NULL", + "ALTER TABLE Network ADD COLUMN sasl_external_key BLOB DEFAULT NULL", + "ALTER TABLE User ADD COLUMN admin INTEGER NOT NULL DEFAULT 0", + ` + CREATE TABLE IF NOT EXISTS UserNew ( + id INTEGER PRIMARY KEY, + username VARCHAR(255) NOT NULL UNIQUE, + password VARCHAR(255), + admin INTEGER NOT NULL DEFAULT 0 + ); + INSERT INTO UserNew SELECT rowid, username, password, admin FROM User; + DROP TABLE User; + ALTER TABLE UserNew RENAME TO User; + `, + ` + CREATE TABLE IF NOT EXISTS NetworkNew ( + id INTEGER PRIMARY KEY, + name VARCHAR(255), + user INTEGER NOT NULL, + addr VARCHAR(255) NOT NULL, + nick VARCHAR(255) NOT NULL, + username VARCHAR(255), + realname VARCHAR(255), + pass VARCHAR(255), + connect_commands VARCHAR(1023), + sasl_mechanism VARCHAR(255), + sasl_plain_username VARCHAR(255), + sasl_plain_password VARCHAR(255), + sasl_external_cert BLOB DEFAULT NULL, + sasl_external_key BLOB DEFAULT NULL, + FOREIGN KEY(user) REFERENCES User(id), + UNIQUE(user, addr, nick), + UNIQUE(user, name) + ); + INSERT INTO NetworkNew + SELECT Network.id, name, User.id as user, addr, nick, + Network.username, realname, pass, connect_commands, + sasl_mechanism, sasl_plain_username, sasl_plain_password, + sasl_external_cert, sasl_external_key + FROM Network + JOIN User ON Network.user = User.username; + DROP TABLE Network; + ALTER TABLE NetworkNew RENAME TO Network; + `, + ` + ALTER TABLE Channel ADD COLUMN relay_detached INTEGER NOT NULL DEFAULT 0; + ALTER TABLE Channel ADD COLUMN reattach_on INTEGER NOT NULL DEFAULT 0; + ALTER TABLE Channel ADD COLUMN detach_after INTEGER NOT NULL DEFAULT 0; + ALTER TABLE Channel ADD COLUMN detach_on INTEGER NOT NULL DEFAULT 0; + `, + ` + CREATE TABLE IF NOT EXISTS DeliveryReceipt ( + id INTEGER PRIMARY KEY, + network INTEGER NOT NULL, + target VARCHAR(255) NOT NULL, + client VARCHAR(255), + internal_msgid VARCHAR(255) NOT NULL, + FOREIGN KEY(network) REFERENCES Network(id), + UNIQUE(network, target, client) + ); + `, + "ALTER TABLE Channel ADD COLUMN detached_internal_msgid VARCHAR(255)", + "ALTER TABLE Network ADD COLUMN enabled INTEGER NOT NULL DEFAULT 1", + "ALTER TABLE User ADD COLUMN realname VARCHAR(255)", + ` + CREATE TABLE IF NOT EXISTS NetworkNew ( + id INTEGER PRIMARY KEY, + name TEXT, + user INTEGER NOT NULL, + addr TEXT NOT NULL, + nick TEXT, + username TEXT, + realname TEXT, + pass TEXT, + connect_commands TEXT, + sasl_mechanism TEXT, + sasl_plain_username TEXT, + sasl_plain_password TEXT, + sasl_external_cert BLOB, + sasl_external_key BLOB, + enabled INTEGER NOT NULL DEFAULT 1, + FOREIGN KEY(user) REFERENCES User(id), + UNIQUE(user, addr, nick), + UNIQUE(user, name) + ); + INSERT INTO NetworkNew + SELECT id, name, user, addr, nick, username, realname, pass, + connect_commands, sasl_mechanism, sasl_plain_username, + sasl_plain_password, sasl_external_cert, sasl_external_key, + enabled + FROM Network; + DROP TABLE Network; + ALTER TABLE NetworkNew RENAME TO Network; + `, + ` + CREATE TABLE IF NOT EXISTS ReadReceipt ( + id INTEGER PRIMARY KEY, + network INTEGER NOT NULL, + target TEXT NOT NULL, + timestamp TEXT NOT NULL, + FOREIGN KEY(network) REFERENCES Network(id), + UNIQUE(network, target) + ); + `, +} + +type SqliteDB struct { + lock sync.RWMutex + db *sql.DB +} + +func OpenSqliteDB(source string) (Database, error) { + sqlSqliteDB, err := sql.Open("sqlite", source) + if err != nil { + return nil, err + } + + db := &SqliteDB{db: sqlSqliteDB} + if err := db.upgrade(); err != nil { + sqlSqliteDB.Close() + return nil, err + } + + return db, nil +} + +func (db *SqliteDB) Close() error { + db.lock.Lock() + defer db.lock.Unlock() + return db.db.Close() +} + +func (db *SqliteDB) upgrade() error { + db.lock.Lock() + defer db.lock.Unlock() + + var version int + if err := db.db.QueryRow("PRAGMA user_version").Scan(&version); err != nil { + return fmt.Errorf("failed to query schema version: %v", err) + } + + if version == len(sqliteMigrations) { + return nil + } else if version > len(sqliteMigrations) { + return fmt.Errorf("suika (version %d) older than schema (version %d)", len(sqliteMigrations), version) + } + + tx, err := db.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + if version == 0 { + if _, err := tx.Exec(sqliteSchema); err != nil { + return fmt.Errorf("failed to initialize schema: %v", err) + } + } else { + for i := version; i < len(sqliteMigrations); i++ { + if _, err := tx.Exec(sqliteMigrations[i]); err != nil { + return fmt.Errorf("failed to execute migration #%v: %v", i, err) + } + } + } + + // For some reason prepared statements don't work here + _, err = tx.Exec(fmt.Sprintf("PRAGMA user_version = %d", len(sqliteMigrations))) + if err != nil { + return fmt.Errorf("failed to bump schema version: %v", err) + } + + return tx.Commit() +} + +func (db *SqliteDB) Stats(ctx context.Context) (*DatabaseStats, error) { + db.lock.RLock() + defer db.lock.RUnlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + var stats DatabaseStats + row := db.db.QueryRowContext(ctx, `SELECT + (SELECT COUNT(*) FROM User) AS users, + (SELECT COUNT(*) FROM Network) AS networks, + (SELECT COUNT(*) FROM Channel) AS channels`) + if err := row.Scan(&stats.Users, &stats.Networks, &stats.Channels); err != nil { + return nil, err + } + + return &stats, nil +} + +func toNullString(s string) sql.NullString { + return sql.NullString{ + String: s, + Valid: s != "", + } +} + +func (db *SqliteDB) ListUsers(ctx context.Context) ([]User, error) { + db.lock.RLock() + defer db.lock.RUnlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + rows, err := db.db.QueryContext(ctx, + "SELECT id, username, password, admin, realname FROM User") + if err != nil { + return nil, err + } + defer rows.Close() + + var users []User + for rows.Next() { + var user User + var password, realname sql.NullString + if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &realname); err != nil { + return nil, err + } + user.Password = password.String + user.Realname = realname.String + users = append(users, user) + } + if err := rows.Err(); err != nil { + return nil, err + } + + return users, nil +} + +func (db *SqliteDB) GetUser(ctx context.Context, username string) (*User, error) { + db.lock.RLock() + defer db.lock.RUnlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + user := &User{Username: username} + + var password, realname sql.NullString + row := db.db.QueryRowContext(ctx, + "SELECT id, password, admin, realname FROM User WHERE username = ?", + username) + if err := row.Scan(&user.ID, &password, &user.Admin, &realname); err != nil { + return nil, err + } + user.Password = password.String + user.Realname = realname.String + return user, nil +} + +func (db *SqliteDB) StoreUser(ctx context.Context, user *User) error { + db.lock.Lock() + defer db.lock.Unlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + args := []interface{}{ + sql.Named("username", user.Username), + sql.Named("password", toNullString(user.Password)), + sql.Named("admin", user.Admin), + sql.Named("realname", toNullString(user.Realname)), + } + + var err error + if user.ID != 0 { + _, err = db.db.ExecContext(ctx, ` + UPDATE User SET password = :password, admin = :admin, + realname = :realname WHERE username = :username`, + args...) + } else { + var res sql.Result + res, err = db.db.ExecContext(ctx, ` + INSERT INTO + User(username, password, admin, realname) + VALUES (:username, :password, :admin, :realname)`, + args...) + if err != nil { + return err + } + user.ID, err = res.LastInsertId() + } + + return err +} + +func (db *SqliteDB) DeleteUser(ctx context.Context, id int64) error { + db.lock.Lock() + defer db.lock.Unlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + tx, err := db.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + _, err = tx.ExecContext(ctx, `DELETE FROM DeliveryReceipt + WHERE id IN ( + SELECT DeliveryReceipt.id + FROM DeliveryReceipt + JOIN Network ON DeliveryReceipt.network = Network.id + WHERE Network.user = ? + )`, id) + if err != nil { + return err + } + + _, err = tx.ExecContext(ctx, `DELETE FROM ReadReceipt + WHERE id IN ( + SELECT ReadReceipt.id + FROM ReadReceipt + JOIN Network ON ReadReceipt.network = Network.id + WHERE Network.user = ? + )`, id) + if err != nil { + return err + } + + _, err = tx.ExecContext(ctx, `DELETE FROM Channel + WHERE id IN ( + SELECT Channel.id + FROM Channel + JOIN Network ON Channel.network = Network.id + WHERE Network.user = ? + )`, id) + if err != nil { + return err + } + + _, err = tx.ExecContext(ctx, "DELETE FROM Network WHERE user = ?", id) + if err != nil { + return err + } + + _, err = tx.ExecContext(ctx, "DELETE FROM User WHERE id = ?", id) + if err != nil { + return err + } + + return tx.Commit() +} + +func (db *SqliteDB) ListNetworks(ctx context.Context, userID int64) ([]Network, error) { + db.lock.RLock() + defer db.lock.RUnlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + rows, err := db.db.QueryContext(ctx, ` + SELECT id, name, addr, nick, username, realname, pass, + connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password, + sasl_external_cert, sasl_external_key, enabled + FROM Network + WHERE user = ?`, + userID) + if err != nil { + return nil, err + } + defer rows.Close() + + var networks []Network + for rows.Next() { + var net Network + var name, nick, username, realname, pass, connectCommands sql.NullString + var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString + err := rows.Scan(&net.ID, &name, &net.Addr, &nick, &username, &realname, + &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword, + &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob, &net.Enabled) + if err != nil { + return nil, err + } + net.Name = name.String + net.Nick = nick.String + net.Username = username.String + net.Realname = realname.String + net.Pass = pass.String + if connectCommands.Valid { + net.ConnectCommands = strings.Split(connectCommands.String, "\r\n") + } + net.SASL.Mechanism = saslMechanism.String + net.SASL.Plain.Username = saslPlainUsername.String + net.SASL.Plain.Password = saslPlainPassword.String + networks = append(networks, net) + } + if err := rows.Err(); err != nil { + return nil, err + } + + return networks, nil +} + +func (db *SqliteDB) StoreNetwork(ctx context.Context, userID int64, network *Network) error { + db.lock.Lock() + defer db.lock.Unlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString + if network.SASL.Mechanism != "" { + saslMechanism = toNullString(network.SASL.Mechanism) + switch network.SASL.Mechanism { + case "PLAIN": + saslPlainUsername = toNullString(network.SASL.Plain.Username) + saslPlainPassword = toNullString(network.SASL.Plain.Password) + network.SASL.External.CertBlob = nil + network.SASL.External.PrivKeyBlob = nil + case "EXTERNAL": + // keep saslPlain* nil + default: + return fmt.Errorf("suika: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism) + } + } + + args := []interface{}{ + sql.Named("name", toNullString(network.Name)), + sql.Named("addr", network.Addr), + sql.Named("nick", toNullString(network.Nick)), + sql.Named("username", toNullString(network.Username)), + sql.Named("realname", toNullString(network.Realname)), + sql.Named("pass", toNullString(network.Pass)), + sql.Named("connect_commands", toNullString(strings.Join(network.ConnectCommands, "\r\n"))), + sql.Named("sasl_mechanism", saslMechanism), + sql.Named("sasl_plain_username", saslPlainUsername), + sql.Named("sasl_plain_password", saslPlainPassword), + sql.Named("sasl_external_cert", network.SASL.External.CertBlob), + sql.Named("sasl_external_key", network.SASL.External.PrivKeyBlob), + sql.Named("enabled", network.Enabled), + + sql.Named("id", network.ID), // only for UPDATE + sql.Named("user", userID), // only for INSERT + } + + var err error + if network.ID != 0 { + _, err = db.db.ExecContext(ctx, ` + UPDATE Network + SET name = :name, addr = :addr, nick = :nick, username = :username, + realname = :realname, pass = :pass, connect_commands = :connect_commands, + sasl_mechanism = :sasl_mechanism, sasl_plain_username = :sasl_plain_username, sasl_plain_password = :sasl_plain_password, + sasl_external_cert = :sasl_external_cert, sasl_external_key = :sasl_external_key, + enabled = :enabled + WHERE id = :id`, args...) + } else { + var res sql.Result + res, err = db.db.ExecContext(ctx, ` + INSERT INTO Network(user, name, addr, nick, username, realname, pass, + connect_commands, sasl_mechanism, sasl_plain_username, + sasl_plain_password, sasl_external_cert, sasl_external_key, enabled) + VALUES (:user, :name, :addr, :nick, :username, :realname, :pass, + :connect_commands, :sasl_mechanism, :sasl_plain_username, + :sasl_plain_password, :sasl_external_cert, :sasl_external_key, :enabled)`, + args...) + if err != nil { + return err + } + network.ID, err = res.LastInsertId() + } + return err +} + +func (db *SqliteDB) DeleteNetwork(ctx context.Context, id int64) error { + db.lock.Lock() + defer db.lock.Unlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + tx, err := db.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + _, err = tx.ExecContext(ctx, "DELETE FROM DeliveryReceipt WHERE network = ?", id) + if err != nil { + return err + } + + _, err = tx.ExecContext(ctx, "DELETE FROM ReadReceipt WHERE network = ?", id) + if err != nil { + return err + } + + _, err = tx.ExecContext(ctx, "DELETE FROM Channel WHERE network = ?", id) + if err != nil { + return err + } + + _, err = tx.ExecContext(ctx, "DELETE FROM Network WHERE id = ?", id) + if err != nil { + return err + } + + return tx.Commit() +} + +func (db *SqliteDB) ListChannels(ctx context.Context, networkID int64) ([]Channel, error) { + db.lock.RLock() + defer db.lock.RUnlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + rows, err := db.db.QueryContext(ctx, `SELECT + id, name, key, detached, detached_internal_msgid, + relay_detached, reattach_on, detach_after, detach_on + FROM Channel + WHERE network = ?`, networkID) + if err != nil { + return nil, err + } + defer rows.Close() + + var channels []Channel + for rows.Next() { + var ch Channel + var key, detachedInternalMsgID sql.NullString + var detachAfter int64 + if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached, &detachedInternalMsgID, &ch.RelayDetached, &ch.ReattachOn, &detachAfter, &ch.DetachOn); err != nil { + return nil, err + } + ch.Key = key.String + ch.DetachedInternalMsgID = detachedInternalMsgID.String + ch.DetachAfter = time.Duration(detachAfter) * time.Second + channels = append(channels, ch) + } + if err := rows.Err(); err != nil { + return nil, err + } + + return channels, nil +} + +func (db *SqliteDB) StoreChannel(ctx context.Context, networkID int64, ch *Channel) error { + db.lock.Lock() + defer db.lock.Unlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + args := []interface{}{ + sql.Named("network", networkID), + sql.Named("name", ch.Name), + sql.Named("key", toNullString(ch.Key)), + sql.Named("detached", ch.Detached), + sql.Named("detached_internal_msgid", toNullString(ch.DetachedInternalMsgID)), + sql.Named("relay_detached", ch.RelayDetached), + sql.Named("reattach_on", ch.ReattachOn), + sql.Named("detach_after", int64(math.Ceil(ch.DetachAfter.Seconds()))), + sql.Named("detach_on", ch.DetachOn), + + sql.Named("id", ch.ID), // only for UPDATE + } + + var err error + if ch.ID != 0 { + _, err = db.db.ExecContext(ctx, `UPDATE Channel + SET network = :network, name = :name, key = :key, detached = :detached, + detached_internal_msgid = :detached_internal_msgid, relay_detached = :relay_detached, + reattach_on = :reattach_on, detach_after = :detach_after, detach_on = :detach_on + WHERE id = :id`, args...) + } else { + var res sql.Result + res, err = db.db.ExecContext(ctx, `INSERT INTO Channel(network, name, key, detached, detached_internal_msgid, relay_detached, reattach_on, detach_after, detach_on) + VALUES (:network, :name, :key, :detached, :detached_internal_msgid, :relay_detached, :reattach_on, :detach_after, :detach_on)`, args...) + if err != nil { + return err + } + ch.ID, err = res.LastInsertId() + } + return err +} + +func (db *SqliteDB) DeleteChannel(ctx context.Context, id int64) error { + db.lock.Lock() + defer db.lock.Unlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + _, err := db.db.ExecContext(ctx, "DELETE FROM Channel WHERE id = ?", id) + return err +} + +func (db *SqliteDB) ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error) { + db.lock.RLock() + defer db.lock.RUnlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + rows, err := db.db.QueryContext(ctx, ` + SELECT id, target, client, internal_msgid + FROM DeliveryReceipt + WHERE network = ?`, networkID) + if err != nil { + return nil, err + } + defer rows.Close() + + var receipts []DeliveryReceipt + for rows.Next() { + var rcpt DeliveryReceipt + var client sql.NullString + if err := rows.Scan(&rcpt.ID, &rcpt.Target, &client, &rcpt.InternalMsgID); err != nil { + return nil, err + } + rcpt.Client = client.String + receipts = append(receipts, rcpt) + } + if err := rows.Err(); err != nil { + return nil, err + } + + return receipts, nil +} + +func (db *SqliteDB) StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error { + db.lock.Lock() + defer db.lock.Unlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + tx, err := db.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + _, err = tx.ExecContext(ctx, "DELETE FROM DeliveryReceipt WHERE network = ? AND client IS ?", + networkID, toNullString(client)) + if err != nil { + return err + } + + for i := range receipts { + rcpt := &receipts[i] + + res, err := tx.ExecContext(ctx, ` + INSERT INTO DeliveryReceipt(network, target, client, internal_msgid) + VALUES (:network, :target, :client, :internal_msgid)`, + sql.Named("network", networkID), + sql.Named("target", rcpt.Target), + sql.Named("client", toNullString(client)), + sql.Named("internal_msgid", rcpt.InternalMsgID)) + if err != nil { + return err + } + rcpt.ID, err = res.LastInsertId() + if err != nil { + return err + } + } + + return tx.Commit() +} + +func (db *SqliteDB) GetReadReceipt(ctx context.Context, networkID int64, name string) (*ReadReceipt, error) { + db.lock.RLock() + defer db.lock.RUnlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + receipt := &ReadReceipt{ + Target: name, + } + + row := db.db.QueryRowContext(ctx, ` + SELECT id, timestamp FROM ReadReceipt WHERE network = :network AND target = :target`, + sql.Named("network", networkID), + sql.Named("target", name), + ) + var timestamp string + if err := row.Scan(&receipt.ID, ×tamp); err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, err + } + if t, err := time.Parse(serverTimeLayout, timestamp); err != nil { + return nil, err + } else { + receipt.Timestamp = t + } + return receipt, nil +} + +func (db *SqliteDB) StoreReadReceipt(ctx context.Context, networkID int64, receipt *ReadReceipt) error { + db.lock.Lock() + defer db.lock.Unlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + args := []interface{}{ + sql.Named("id", receipt.ID), + sql.Named("timestamp", formatServerTime(receipt.Timestamp)), + sql.Named("network", networkID), + sql.Named("target", receipt.Target), + } + + var err error + if receipt.ID != 0 { + _, err = db.db.ExecContext(ctx, ` + UPDATE ReadReceipt SET timestamp = :timestamp WHERE id = :id`, + args...) + } else { + var res sql.Result + res, err = db.db.ExecContext(ctx, ` + INSERT INTO + ReadReceipt(network, target, timestamp) + VALUES (:network, :target, :timestamp)`, + args...) + if err != nil { + return err + } + receipt.ID, err = res.LastInsertId() + } + + return err +} diff --git a/branches/master/db_sqlite_test.go b/branches/master/db_sqlite_test.go new file mode 100644 index 0000000..a7c9794 --- /dev/null +++ b/branches/master/db_sqlite_test.go @@ -0,0 +1,59 @@ +package suika + +import ( + "database/sql" + "testing" +) + +// SQLite version 0 schema. DO NOT EDIT. +const sqliteV0Schema = ` +CREATE TABLE User ( + username VARCHAR(255) NOT NULL UNIQUE, + password VARCHAR(255) +); + +CREATE TABLE Network ( + id INTEGER PRIMARY KEY, + name VARCHAR(255), + user VARCHAR(255) NOT NULL, + addr VARCHAR(255) NOT NULL, + nick VARCHAR(255) NOT NULL, + username VARCHAR(255), + realname VARCHAR(255), + pass VARCHAR(255), + sasl_mechanism VARCHAR(255), + sasl_plain_username VARCHAR(255), + sasl_plain_password VARCHAR(255), + UNIQUE(user, addr, nick), + UNIQUE(user, name) +); + +CREATE TABLE Channel ( + id INTEGER PRIMARY KEY, + network INTEGER NOT NULL, + name VARCHAR(255) NOT NULL, + key VARCHAR(255), + FOREIGN KEY(network) REFERENCES Network(id), + UNIQUE(network, name) +); + +PRAGMA user_version = 1; +` + +func TestSqliteMigrations(t *testing.T) { + sqlDB, err := sql.Open("sqlite", ":memory:") + if err != nil { + t.Fatalf("failed to create temporary SQLite database: %v", err) + } + + if _, err := sqlDB.Exec(sqliteV0Schema); err != nil { + t.Fatalf("DB.Exec() failed for v0 schema: %v", err) + } + + db := &SqliteDB{db: sqlDB} + defer db.Close() + + if err := db.upgrade(); err != nil { + t.Fatalf("SqliteDB.Upgrade() failed: %v", err) + } +} diff --git a/branches/master/doc.go b/branches/master/doc.go new file mode 100644 index 0000000..3f7af51 --- /dev/null +++ b/branches/master/doc.go @@ -0,0 +1,20 @@ +// Package suika is a hard-fork of the 0.3 series of soju, an user-friendly IRC bouncer in Go. +// +// # Copyright (C) 2020 The soju Contributors +// # Copyright (C) 2023-present Izuru Yakumo et al. +// +// suika is covered by the AGPLv3 license: +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . +package suika diff --git a/branches/master/doc/suika-config.5 b/branches/master/doc/suika-config.5 new file mode 100644 index 0000000..ac0e9ef --- /dev/null +++ b/branches/master/doc/suika-config.5 @@ -0,0 +1,70 @@ +.Dd $Mdocdate$ +.Dt SUIKA-CONFIG 5 +.Os +.Sh NAME +.Nm suika-config +.Nd Configuration file for the IRC bouncer +.Sh SYNOPSIS +.Bk -words +listen ircs:// +.Pp +tls cert.pem key.pem +.Pp +hostname example.org +.Ek +.Sh DESCRIPTION +This document describes the format of the configuration +file used by +.Xr suika 1 +.Sh OPTIONS +.Bl -tag -width Ds +.It listen Ar uri +With this you can control on what +ports/protocols +.Xr suika 1 +listens on, it supports +irc (cleartext IRC), ircs (IRC with TLS), and unix +(IRC over Unix domain sockets) +.It hostname Ar hostname +Server hostname, if unset, the system one is used. +.It title Ar title +Server title, this will be sent as the ISUPPORT NETWORK value when +clients don't select a specific network. +.It tls Ar cert Ar key +Enable TLS support, the certificate and key files must be +PEM-encoded. +.It db Ar driver Ar path +Set the database driver for user, network and channel storage. +By default a SQLite 3 database is opened in +.Pa ./suika.db +Supported drivers are sqlite and postgres, the former +expects a path to the database file, and the latter +a space-separated list of key=value parameters, +e.g. host=localhost dbname=suika +.It log fs Ar path +Path to the bouncer logs directory, or empty to disable +logging. +By default, logging is disabled. +.It max-user-networks Ar limit +Maximum number of networks per user, by default +there is no limit. +.It motd Ar path +Path to the MOTD file, its contents are sent to clients +which aren't bound to a particular network. +By default, no MOTD is sent. +.It multi-upstream-mode Ar bool +Globally enable or disable multi-upstream mode. +By default, it is enabled. +.It upstream-user-ip Ar cidr +Enable per-user-IP addresses. +One IPv4 and/or one IPv6 range can be specified in CIDR notation. +One IP address per range will be assigned to each user as the +source address when connecting to an upstream network. +This can be useful to avoid having the whole bouncer banned from +an upstream network because of one malicious user. +.El +.Sh AUTHORS +.An Simon Ser Aq Mt contact@emersion.fr +.An The soju Contributors +.Sh MAINTAINERS +.An Izuru Yakumo Aq Mt yakumo.izuru@chaotic.ninja diff --git a/branches/master/doc/suika-znc-import.1 b/branches/master/doc/suika-znc-import.1 new file mode 100644 index 0000000..7acc618 --- /dev/null +++ b/branches/master/doc/suika-znc-import.1 @@ -0,0 +1,34 @@ +.Dd $Mdocdate$ +.Dt SUIKA-ZNC-IMPORT 1 +.Os +.Sh NAME +.Nm suika-znc-import +.Nd Migration utility for moving from ZNC +.Sh SYNOPSIS +.Nm +.Op Fl config Ar suika config file +.Op Fl user Ar username +.Op Fl network Ar name +.Sh DESCRIPTION +Imports configuration from a ZNC file. +Users and networks are merged if they already exist in the +.Xr suika 1 +database. +ZNC settings overwrite existing +.Xr suika 1 +settings +.Sh OPTIONS +.Bl -tag -width Ds +.It config Ar suika config file +Path to +.Xr suika-config 5 +.It user Ar username +Limit import to username, may be specified multiple times. +It network Ar name +Limit import to network, may be specified multiple times. +.El +.Sh AUTHORS +.An Simon Ser Aq Mt contact@emersion.fr +.An The soju Contributors +.Sh MAINTAINERS +.An Izuru Yakumo Aq Mt yakumo.izuru@chaotic.ninja diff --git a/branches/master/doc/suika.1 b/branches/master/doc/suika.1 new file mode 100644 index 0000000..afe6607 --- /dev/null +++ b/branches/master/doc/suika.1 @@ -0,0 +1,68 @@ +.Dd $Mdocdate$ +.Dt SUIKA 1 +.Os +.Sh NAME +.Nm suika +.Nd A drunk as hell IRC bouncer, named after Suika Ibuki from Touhou Project +.Sh SYNOPSIS +.Nm +.Op Fl config Ar path +.Op Fl debug +.Op Fl listen Ar uri +.Sh DESCRIPTION +.Nm +is an user-friendly IRC bouncer, it connects to upstream +IRC servers on behalf of the user to provide extra features. +.Bl -tag -width 6n +.It Multiple separate users sharing the same bouncer +.It Clients connecting to multiple upstream servers (via a single connection) +.It Sending the backlog with per-client buffers +.El +.Pp +When joining a channel, the channel will be saved +and automatically joined on the next connection. +When registering or authenticating with NickServ, the credentials will be saved +and automatically used on the next connection if the server supports SASL. +When parting a channel with the reason "detach", the channel will be +detached instead of being left. +When all clients are disconnected from the bouncer, +the user is automatically marked as away. +.Pp +.Nm +supports two connection modes: +.Bl -tag -width 6n +.It Single upstream mode +One downstream connection maps to one upstream connection +.Pp +To enable this mode, connect to the bouncer +with the username "/". +.Pp +If the bouncer isn't connected to the upstream server, +it will get automatically added. +.Pp +Then channels can be joined and parted as if +you were directly connected to the upstream server. +.It Multiple upstream mode +One downstream connection maps to multiple upstream connections. +Channels and nicks are suffixed with the network name. +To join a channel, you need to use the suffix too: /join #channel/network. +Same applies to messages sent to users. +.El +.Pp +For per-client history to work, clients need to indicate their name. +This can be done by adding a "@" suffix to the username. +.Pp +.Nm +will reload the configuration file, the TLS certificate/key and +the MOTD file when it receives the HUP signal. +The configuration options listen, db and log cannot be reloaded. +.Pp +Administrators can broadcast a message to all bouncer users via +/notice $ , or via /notice $ in multi-upstream mode. +All currently connected bouncer users will receive the message +from the special BouncerServ service. +.Sh AUTHORS +.An Simon Ser Aq Mt contact@emersion.fr +.An The soju Contributors +.Sh MAINTAINERS +.An Izuru Yakumo Aq Mt yakumo.izuru@chaotic.ninja diff --git a/branches/master/doc/suikadb.1 b/branches/master/doc/suikadb.1 new file mode 100644 index 0000000..5ff4769 --- /dev/null +++ b/branches/master/doc/suikadb.1 @@ -0,0 +1,16 @@ +.Dd $Mdocdate$ +.Dt SUIKADB 1 +.Os +.Sh NAME +.Nm suikadb +.Nd Basic user manipulation for +.Xr suika 1 +.Sh SYNOPSIS +.Nm +.Op create-user +.Op change-password +.Sh AUTHORS +.An Simon Ser Aq Mt contact@emersion.fr +.An The soju Contributors +.Sh MAINTAINERS +.An Izuru Yakumo Aq Mt yakumo.izuru@chaotic.ninja diff --git a/branches/master/downstream.go b/branches/master/downstream.go new file mode 100644 index 0000000..3b9adac --- /dev/null +++ b/branches/master/downstream.go @@ -0,0 +1,3047 @@ +package suika + +import ( + "bytes" + "context" + "crypto/tls" + "encoding/base64" + "errors" + "fmt" + "io" + "net" + "strconv" + "strings" + "time" + + "github.com/emersion/go-sasl" + "golang.org/x/crypto/bcrypt" + "gopkg.in/irc.v3" +) + +type ircError struct { + Message *irc.Message +} + +func (err ircError) Error() string { + return err.Message.String() +} + +func newUnknownCommandError(cmd string) ircError { + return ircError{&irc.Message{ + Command: irc.ERR_UNKNOWNCOMMAND, + Params: []string{ + "*", + cmd, + "Unknown command", + }, + }} +} + +func newNeedMoreParamsError(cmd string) ircError { + return ircError{&irc.Message{ + Command: irc.ERR_NEEDMOREPARAMS, + Params: []string{ + "*", + cmd, + "Not enough parameters", + }, + }} +} + +func newChatHistoryError(subcommand string, target string) ircError { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"CHATHISTORY", "MESSAGE_ERROR", subcommand, target, "Messages could not be retrieved"}, + }} +} + +// authError is an authentication error. +type authError struct { + // Internal error cause. This will not be revealed to the user. + err error + // Error cause which can safely be sent to the user without compromising + // security. + reason string +} + +func (err *authError) Error() string { + return err.err.Error() +} + +func (err *authError) Unwrap() error { + return err.err +} + +// authErrorReason returns the user-friendly reason of an authentication +// failure. +func authErrorReason(err error) string { + if authErr, ok := err.(*authError); ok { + return authErr.reason + } else { + return "Authentication failed" + } +} + +func newInvalidUsernameOrPasswordError(err error) error { + return &authError{ + err: err, + reason: "Invalid username or password", + } +} + +func parseBouncerNetID(subcommand, s string) (int64, error) { + id, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return 0, ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"BOUNCER", "INVALID_NETID", subcommand, s, "Invalid network ID"}, + }} + } + return id, nil +} + +func fillNetworkAddrAttrs(attrs irc.Tags, network *Network) { + u, err := network.URL() + if err != nil { + return + } + + hasHostPort := true + switch u.Scheme { + case "ircs": + attrs["tls"] = irc.TagValue("1") + case "irc": + attrs["tls"] = irc.TagValue("0") + default: // e.g. unix:// + hasHostPort = false + } + if host, port, err := net.SplitHostPort(u.Host); err == nil && hasHostPort { + attrs["host"] = irc.TagValue(host) + attrs["port"] = irc.TagValue(port) + } else if hasHostPort { + attrs["host"] = irc.TagValue(u.Host) + } +} + +func getNetworkAttrs(network *network) irc.Tags { + state := "disconnected" + if uc := network.conn; uc != nil { + state = "connected" + } + + attrs := irc.Tags{ + "name": irc.TagValue(network.GetName()), + "state": irc.TagValue(state), + "nickname": irc.TagValue(GetNick(&network.user.User, &network.Network)), + } + + if network.Username != "" { + attrs["username"] = irc.TagValue(network.Username) + } + if realname := GetRealname(&network.user.User, &network.Network); realname != "" { + attrs["realname"] = irc.TagValue(realname) + } + + fillNetworkAddrAttrs(attrs, &network.Network) + + return attrs +} + +func networkAddrFromAttrs(attrs irc.Tags) string { + host, ok := attrs.GetTag("host") + if !ok { + return "" + } + + addr := host + if port, ok := attrs.GetTag("port"); ok { + addr += ":" + port + } + + if tlsStr, ok := attrs.GetTag("tls"); ok && tlsStr == "0" { + addr = "irc://" + tlsStr + } + + return addr +} + +func updateNetworkAttrs(record *Network, attrs irc.Tags, subcommand string) error { + addrAttrs := irc.Tags{} + fillNetworkAddrAttrs(addrAttrs, record) + + updateAddr := false + for k, v := range attrs { + s := string(v) + switch k { + case "host", "port", "tls": + updateAddr = true + addrAttrs[k] = v + case "name": + record.Name = s + case "nickname": + record.Nick = s + case "username": + record.Username = s + case "realname": + record.Realname = s + case "pass": + record.Pass = s + default: + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"BOUNCER", "UNKNOWN_ATTRIBUTE", subcommand, k, "Unknown attribute"}, + }} + } + } + + if updateAddr { + record.Addr = networkAddrFromAttrs(addrAttrs) + if record.Addr == "" { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"BOUNCER", "NEED_ATTRIBUTE", subcommand, "host", "Missing required host attribute"}, + }} + } + } + + return nil +} + +// illegalNickChars is the list of characters forbidden in a nickname. +// +// ' ' and ':' break the IRC message wire format +// '@' and '!' break prefixes +// '*' breaks masks and is the reserved nickname for registration +// '?' breaks masks +// '$' breaks server masks in PRIVMSG/NOTICE +// ',' breaks lists +// '.' is reserved for server names +const illegalNickChars = " :@!*?$,." + +// permanentDownstreamCaps is the list of always-supported downstream +// capabilities. +var permanentDownstreamCaps = map[string]string{ + "batch": "", + "cap-notify": "", + "echo-message": "", + "invite-notify": "", + "message-tags": "", + "server-time": "", + "setname": "", + + "soju.im/bouncer-networks": "", + "soju.im/bouncer-networks-notify": "", + "soju.im/read": "", +} + +// needAllDownstreamCaps is the list of downstream capabilities that +// require support from all upstreams to be enabled +var needAllDownstreamCaps = map[string]string{ + "account-notify": "", + "account-tag": "", + "away-notify": "", + "extended-join": "", + "multi-prefix": "", + + "draft/extended-monitor": "", +} + +// passthroughIsupport is the set of ISUPPORT tokens that are directly passed +// through from the upstream server to downstream clients. +// +// This is only effective in single-upstream mode. +var passthroughIsupport = map[string]bool{ + "AWAYLEN": true, + "BOT": true, + "CHANLIMIT": true, + "CHANMODES": true, + "CHANNELLEN": true, + "CHANTYPES": true, + "CLIENTTAGDENY": true, + "ELIST": true, + "EXCEPTS": true, + "EXTBAN": true, + "HOSTLEN": true, + "INVEX": true, + "KICKLEN": true, + "MAXLIST": true, + "MAXTARGETS": true, + "MODES": true, + "MONITOR": true, + "NAMELEN": true, + "NETWORK": true, + "NICKLEN": true, + "PREFIX": true, + "SAFELIST": true, + "TARGMAX": true, + "TOPICLEN": true, + "USERLEN": true, + "UTF8ONLY": true, + "WHOX": true, +} + +type downstreamSASL struct { + server sasl.Server + plainUsername, plainPassword string + pendingResp bytes.Buffer +} + +type downstreamConn struct { + conn + + id uint64 + + registered bool + user *user + nick string + nickCM string + rawUsername string + networkName string + clientName string + realname string + hostname string + account string // RPL_LOGGEDIN/OUT state + password string // empty after authentication + network *network // can be nil + isMultiUpstream bool + + negotiatingCaps bool + capVersion int + supportedCaps map[string]string + caps map[string]bool + sasl *downstreamSASL + + lastBatchRef uint64 + + monitored casemapMap +} + +func newDownstreamConn(srv *Server, ic ircConn, id uint64) *downstreamConn { + remoteAddr := ic.RemoteAddr().String() + logger := &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", remoteAddr)} + options := connOptions{Logger: logger} + dc := &downstreamConn{ + conn: *newConn(srv, ic, &options), + id: id, + nick: "*", + nickCM: "*", + supportedCaps: make(map[string]string), + caps: make(map[string]bool), + monitored: newCasemapMap(0), + } + dc.hostname = remoteAddr + if host, _, err := net.SplitHostPort(dc.hostname); err == nil { + dc.hostname = host + } + for k, v := range permanentDownstreamCaps { + dc.supportedCaps[k] = v + } + dc.supportedCaps["sasl"] = "PLAIN" + // TODO: this is racy, we should only enable chathistory after + // authentication and then check that user.msgStore implements + // chatHistoryMessageStore + if srv.Config().LogPath != "" { + dc.supportedCaps["draft/chathistory"] = "" + } + return dc +} + +func (dc *downstreamConn) prefix() *irc.Prefix { + return &irc.Prefix{ + Name: dc.nick, + User: dc.user.Username, + Host: dc.hostname, + } +} + +func (dc *downstreamConn) forEachNetwork(f func(*network)) { + if dc.network != nil { + f(dc.network) + } else if dc.isMultiUpstream { + for _, network := range dc.user.networks { + f(network) + } + } +} + +func (dc *downstreamConn) forEachUpstream(f func(*upstreamConn)) { + if dc.network == nil && !dc.isMultiUpstream { + return + } + dc.user.forEachUpstream(func(uc *upstreamConn) { + if dc.network != nil && uc.network != dc.network { + return + } + f(uc) + }) +} + +// upstream returns the upstream connection, if any. If there are zero or if +// there are multiple upstream connections, it returns nil. +func (dc *downstreamConn) upstream() *upstreamConn { + if dc.network == nil { + return nil + } + return dc.network.conn +} + +func isOurNick(net *network, nick string) bool { + // TODO: this doesn't account for nick changes + if net.conn != nil { + return net.casemap(nick) == net.conn.nickCM + } + // We're not currently connected to the upstream connection, so we don't + // know whether this name is our nickname. Best-effort: use the network's + // configured nickname and hope it was the one being used when we were + // connected. + return net.casemap(nick) == net.casemap(GetNick(&net.user.User, &net.Network)) +} + +// marshalEntity converts an upstream entity name (ie. channel or nick) into a +// downstream entity name. +// +// This involves adding a "/" suffix if the entity isn't the current +// user. +func (dc *downstreamConn) marshalEntity(net *network, name string) string { + if isOurNick(net, name) { + return dc.nick + } + name = partialCasemap(net.casemap, name) + if dc.network != nil { + if dc.network != net { + panic("suika: tried to marshal an entity for another network") + } + return name + } + return name + "/" + net.GetName() +} + +func (dc *downstreamConn) marshalUserPrefix(net *network, prefix *irc.Prefix) *irc.Prefix { + if isOurNick(net, prefix.Name) { + return dc.prefix() + } + prefix.Name = partialCasemap(net.casemap, prefix.Name) + if dc.network != nil { + if dc.network != net { + panic("suika: tried to marshal a user prefix for another network") + } + return prefix + } + return &irc.Prefix{ + Name: prefix.Name + "/" + net.GetName(), + User: prefix.User, + Host: prefix.Host, + } +} + +// unmarshalEntityNetwork converts a downstream entity name (ie. channel or +// nick) into an upstream entity name. +// +// This involves removing the "/" suffix. +func (dc *downstreamConn) unmarshalEntityNetwork(name string) (*network, string, error) { + if dc.network != nil { + return dc.network, name, nil + } + if !dc.isMultiUpstream { + return nil, "", ircError{&irc.Message{ + Command: irc.ERR_NOSUCHCHANNEL, + Params: []string{dc.nick, name, "Cannot interact with channels and users on the bouncer connection. Did you mean to use a specific network?"}, + }} + } + + var net *network + if i := strings.LastIndexByte(name, '/'); i >= 0 { + network := name[i+1:] + name = name[:i] + + for _, n := range dc.user.networks { + if network == n.GetName() { + net = n + break + } + } + } + + if net == nil { + return nil, "", ircError{&irc.Message{ + Command: irc.ERR_NOSUCHCHANNEL, + Params: []string{dc.nick, name, "Missing network suffix in name"}, + }} + } + + return net, name, nil +} + +// unmarshalEntity is the same as unmarshalEntityNetwork, but returns the +// upstream connection and fails if the upstream is disconnected. +func (dc *downstreamConn) unmarshalEntity(name string) (*upstreamConn, string, error) { + net, name, err := dc.unmarshalEntityNetwork(name) + if err != nil { + return nil, "", err + } + + if net.conn == nil { + return nil, "", ircError{&irc.Message{ + Command: irc.ERR_NOSUCHCHANNEL, + Params: []string{dc.nick, name, "Disconnected from upstream network"}, + }} + } + + return net.conn, name, nil +} + +func (dc *downstreamConn) unmarshalText(uc *upstreamConn, text string) string { + if dc.upstream() != nil { + return text + } + // TODO: smarter parsing that ignores URLs + return strings.ReplaceAll(text, "/"+uc.network.GetName(), "") +} + +func (dc *downstreamConn) ReadMessage() (*irc.Message, error) { + msg, err := dc.conn.ReadMessage() + if err != nil { + return nil, err + } + return msg, nil +} + +func (dc *downstreamConn) readMessages(ch chan<- event) error { + for { + msg, err := dc.ReadMessage() + if errors.Is(err, io.EOF) { + break + } else if err != nil { + return fmt.Errorf("failed to read IRC command: %v", err) + } + + ch <- eventDownstreamMessage{msg, dc} + } + + return nil +} + +// SendMessage sends an outgoing message. +// +// This can only called from the user goroutine. +func (dc *downstreamConn) SendMessage(msg *irc.Message) { + if !dc.caps["message-tags"] { + if msg.Command == "TAGMSG" { + return + } + msg = msg.Copy() + for name := range msg.Tags { + supported := false + switch name { + case "time": + supported = dc.caps["server-time"] + case "account": + supported = dc.caps["account"] + } + if !supported { + delete(msg.Tags, name) + } + } + } + if !dc.caps["batch"] && msg.Tags["batch"] != "" { + msg = msg.Copy() + delete(msg.Tags, "batch") + } + if msg.Command == "JOIN" && !dc.caps["extended-join"] { + msg.Params = msg.Params[:1] + } + if msg.Command == "SETNAME" && !dc.caps["setname"] { + return + } + if msg.Command == "AWAY" && !dc.caps["away-notify"] { + return + } + if msg.Command == "ACCOUNT" && !dc.caps["account-notify"] { + return + } + if msg.Command == "READ" && !dc.caps["soju.im/read"] { + return + } + + dc.conn.SendMessage(context.TODO(), msg) +} + +func (dc *downstreamConn) SendBatch(typ string, params []string, tags irc.Tags, f func(batchRef irc.TagValue)) { + dc.lastBatchRef++ + ref := fmt.Sprintf("%v", dc.lastBatchRef) + + if dc.caps["batch"] { + dc.SendMessage(&irc.Message{ + Tags: tags, + Prefix: dc.srv.prefix(), + Command: "BATCH", + Params: append([]string{"+" + ref, typ}, params...), + }) + } + + f(irc.TagValue(ref)) + + if dc.caps["batch"] { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "BATCH", + Params: []string{"-" + ref}, + }) + } +} + +// sendMessageWithID sends an outgoing message with the specified internal ID. +func (dc *downstreamConn) sendMessageWithID(msg *irc.Message, id string) { + dc.SendMessage(msg) + + if id == "" || !dc.messageSupportsBacklog(msg) { + return + } + + dc.sendPing(id) +} + +// advanceMessageWithID advances history to the specified message ID without +// sending a message. This is useful e.g. for self-messages when echo-message +// isn't enabled. +func (dc *downstreamConn) advanceMessageWithID(msg *irc.Message, id string) { + if id == "" || !dc.messageSupportsBacklog(msg) { + return + } + + dc.sendPing(id) +} + +// ackMsgID acknowledges that a message has been received. +func (dc *downstreamConn) ackMsgID(id string) { + netID, entity, err := parseMsgID(id, nil) + if err != nil { + dc.logger.Printf("failed to ACK message ID %q: %v", id, err) + return + } + + network := dc.user.getNetworkByID(netID) + if network == nil { + return + } + + network.delivered.StoreID(entity, dc.clientName, id) +} + +func (dc *downstreamConn) sendPing(msgID string) { + token := "suika-msgid-" + msgID + dc.SendMessage(&irc.Message{ + Command: "PING", + Params: []string{token}, + }) +} + +func (dc *downstreamConn) handlePong(token string) { + if !strings.HasPrefix(token, "suika-msgid-") { + dc.logger.Printf("received unrecognized PONG token %q", token) + return + } + msgID := strings.TrimPrefix(token, "suika-msgid-") + dc.ackMsgID(msgID) +} + +// marshalMessage re-formats a message coming from an upstream connection so +// that it's suitable for being sent on this downstream connection. Only +// messages that may appear in logs are supported, except MODE messages which +// may only appear in single-upstream mode. +func (dc *downstreamConn) marshalMessage(msg *irc.Message, net *network) *irc.Message { + msg = msg.Copy() + msg.Prefix = dc.marshalUserPrefix(net, msg.Prefix) + + if dc.network != nil { + return msg + } + + switch msg.Command { + case "PRIVMSG", "NOTICE", "TAGMSG": + msg.Params[0] = dc.marshalEntity(net, msg.Params[0]) + case "NICK": + // Nick change for another user + msg.Params[0] = dc.marshalEntity(net, msg.Params[0]) + case "JOIN", "PART": + msg.Params[0] = dc.marshalEntity(net, msg.Params[0]) + case "KICK": + msg.Params[0] = dc.marshalEntity(net, msg.Params[0]) + msg.Params[1] = dc.marshalEntity(net, msg.Params[1]) + case "TOPIC": + msg.Params[0] = dc.marshalEntity(net, msg.Params[0]) + case "QUIT", "SETNAME": + // This space is intentionally left blank + default: + panic(fmt.Sprintf("unexpected %q message", msg.Command)) + } + + return msg +} + +func (dc *downstreamConn) handleMessage(ctx context.Context, msg *irc.Message) error { + ctx, cancel := dc.conn.NewContext(ctx) + defer cancel() + + ctx, cancel = context.WithTimeout(ctx, handleDownstreamMessageTimeout) + defer cancel() + + switch msg.Command { + case "QUIT": + return dc.Close() + default: + if dc.registered { + return dc.handleMessageRegistered(ctx, msg) + } else { + return dc.handleMessageUnregistered(ctx, msg) + } + } +} + +func (dc *downstreamConn) handleMessageUnregistered(ctx context.Context, msg *irc.Message) error { + switch msg.Command { + case "NICK": + var nick string + if err := parseMessageParams(msg, &nick); err != nil { + return err + } + if nick == "" || strings.ContainsAny(nick, illegalNickChars) { + return ircError{&irc.Message{ + Command: irc.ERR_ERRONEUSNICKNAME, + Params: []string{dc.nick, nick, "contains illegal characters"}, + }} + } + nickCM := casemapASCII(nick) + if nickCM == serviceNickCM { + return ircError{&irc.Message{ + Command: irc.ERR_NICKNAMEINUSE, + Params: []string{dc.nick, nick, "Nickname reserved for bouncer service"}, + }} + } + dc.nick = nick + dc.nickCM = nickCM + case "USER": + if err := parseMessageParams(msg, &dc.rawUsername, nil, nil, &dc.realname); err != nil { + return err + } + case "PASS": + if err := parseMessageParams(msg, &dc.password); err != nil { + return err + } + case "CAP": + var subCmd string + if err := parseMessageParams(msg, &subCmd); err != nil { + return err + } + if err := dc.handleCapCommand(subCmd, msg.Params[1:]); err != nil { + return err + } + case "AUTHENTICATE": + credentials, err := dc.handleAuthenticateCommand(msg) + if err != nil { + return err + } else if credentials == nil { + break + } + + if err := dc.authenticate(ctx, credentials.plainUsername, credentials.plainPassword); err != nil { + dc.logger.Printf("SASL authentication error for user %q: %v", credentials.plainUsername, err) + dc.endSASL(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_SASLFAIL, + Params: []string{dc.nick, authErrorReason(err)}, + }) + break + } + + // Technically we should send RPL_LOGGEDIN here. However we use + // RPL_LOGGEDIN to mirror the upstream connection status. Let's + // see how many clients that breaks. See: + // https://github.com/ircv3/ircv3-specifications/pull/476 + dc.endSASL(nil) + case "BOUNCER": + var subcommand string + if err := parseMessageParams(msg, &subcommand); err != nil { + return err + } + + switch strings.ToUpper(subcommand) { + case "BIND": + var idStr string + if err := parseMessageParams(msg, nil, &idStr); err != nil { + return err + } + + if dc.user == nil { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"BOUNCER", "ACCOUNT_REQUIRED", "BIND", "Authentication needed to bind to bouncer network"}, + }} + } + + id, err := parseBouncerNetID(subcommand, idStr) + if err != nil { + return err + } + + var match *network + for _, net := range dc.user.networks { + if net.ID == id { + match = net + break + } + } + if match == nil { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"BOUNCER", "INVALID_NETID", idStr, "Unknown network ID"}, + }} + } + + dc.networkName = match.GetName() + } + default: + dc.logger.Printf("unhandled message: %v", msg) + return newUnknownCommandError(msg.Command) + } + if dc.rawUsername != "" && dc.nick != "*" && !dc.negotiatingCaps { + return dc.register(ctx) + } + return nil +} + +func (dc *downstreamConn) handleCapCommand(cmd string, args []string) error { + cmd = strings.ToUpper(cmd) + + switch cmd { + case "LS": + if len(args) > 0 { + var err error + if dc.capVersion, err = strconv.Atoi(args[0]); err != nil { + return err + } + } + if !dc.registered && dc.capVersion >= 302 { + // Let downstream show everything it supports, and trim + // down the available capabilities when upstreams are + // known. + for k, v := range needAllDownstreamCaps { + dc.supportedCaps[k] = v + } + } + + caps := make([]string, 0, len(dc.supportedCaps)) + for k, v := range dc.supportedCaps { + if dc.capVersion >= 302 && v != "" { + caps = append(caps, k+"="+v) + } else { + caps = append(caps, k) + } + } + + // TODO: multi-line replies + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "CAP", + Params: []string{dc.nick, "LS", strings.Join(caps, " ")}, + }) + + if dc.capVersion >= 302 { + // CAP version 302 implicitly enables cap-notify + dc.caps["cap-notify"] = true + } + + if !dc.registered { + dc.negotiatingCaps = true + } + case "LIST": + var caps []string + for name, enabled := range dc.caps { + if enabled { + caps = append(caps, name) + } + } + + // TODO: multi-line replies + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "CAP", + Params: []string{dc.nick, "LIST", strings.Join(caps, " ")}, + }) + case "REQ": + if len(args) == 0 { + return ircError{&irc.Message{ + Command: err_invalidcapcmd, + Params: []string{dc.nick, cmd, "Missing argument in CAP REQ command"}, + }} + } + + // TODO: atomically ack/nak the whole capability set + caps := strings.Fields(args[0]) + ack := true + for _, name := range caps { + name = strings.ToLower(name) + enable := !strings.HasPrefix(name, "-") + if !enable { + name = strings.TrimPrefix(name, "-") + } + + if enable == dc.caps[name] { + continue + } + + _, ok := dc.supportedCaps[name] + if !ok { + ack = false + break + } + + if name == "cap-notify" && dc.capVersion >= 302 && !enable { + // cap-notify cannot be disabled with CAP version 302 + ack = false + break + } + + dc.caps[name] = enable + } + + reply := "NAK" + if ack { + reply = "ACK" + } + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "CAP", + Params: []string{dc.nick, reply, args[0]}, + }) + + if !dc.registered { + dc.negotiatingCaps = true + } + case "END": + dc.negotiatingCaps = false + default: + return ircError{&irc.Message{ + Command: err_invalidcapcmd, + Params: []string{dc.nick, cmd, "Unknown CAP command"}, + }} + } + return nil +} + +func (dc *downstreamConn) handleAuthenticateCommand(msg *irc.Message) (result *downstreamSASL, err error) { + defer func() { + if err != nil { + dc.sasl = nil + } + }() + + if !dc.caps["sasl"] { + return nil, ircError{&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_SASLFAIL, + Params: []string{dc.nick, "AUTHENTICATE requires the \"sasl\" capability to be enabled"}, + }} + } + if len(msg.Params) == 0 { + return nil, ircError{&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_SASLFAIL, + Params: []string{dc.nick, "Missing AUTHENTICATE argument"}, + }} + } + if msg.Params[0] == "*" { + return nil, ircError{&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_SASLABORTED, + Params: []string{dc.nick, "SASL authentication aborted"}, + }} + } + + var resp []byte + if dc.sasl == nil { + mech := strings.ToUpper(msg.Params[0]) + var server sasl.Server + switch mech { + case "PLAIN": + server = sasl.NewPlainServer(sasl.PlainAuthenticator(func(identity, username, password string) error { + dc.sasl.plainUsername = username + dc.sasl.plainPassword = password + return nil + })) + default: + return nil, ircError{&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_SASLFAIL, + Params: []string{dc.nick, fmt.Sprintf("Unsupported SASL mechanism %q", mech)}, + }} + } + + dc.sasl = &downstreamSASL{server: server} + } else { + chunk := msg.Params[0] + if chunk == "+" { + chunk = "" + } + + if dc.sasl.pendingResp.Len()+len(chunk) > 10*1024 { + return nil, ircError{&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_SASLFAIL, + Params: []string{dc.nick, "Response too long"}, + }} + } + + dc.sasl.pendingResp.WriteString(chunk) + + if len(chunk) == maxSASLLength { + return nil, nil // Multi-line response, wait for the next command + } + + resp, err = base64.StdEncoding.DecodeString(dc.sasl.pendingResp.String()) + if err != nil { + return nil, ircError{&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_SASLFAIL, + Params: []string{dc.nick, "Invalid base64-encoded response"}, + }} + } + + dc.sasl.pendingResp.Reset() + } + + challenge, done, err := dc.sasl.server.Next(resp) + if err != nil { + return nil, err + } else if done { + return dc.sasl, nil + } else { + challengeStr := "+" + if len(challenge) > 0 { + challengeStr = base64.StdEncoding.EncodeToString(challenge) + } + + // TODO: multi-line messages + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "AUTHENTICATE", + Params: []string{challengeStr}, + }) + return nil, nil + } +} + +func (dc *downstreamConn) endSASL(msg *irc.Message) { + if dc.sasl == nil { + return + } + + dc.sasl = nil + + if msg != nil { + dc.SendMessage(msg) + } else { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_SASLSUCCESS, + Params: []string{dc.nick, "SASL authentication successful"}, + }) + } +} + +func (dc *downstreamConn) setSupportedCap(name, value string) { + prevValue, hasPrev := dc.supportedCaps[name] + changed := !hasPrev || prevValue != value + dc.supportedCaps[name] = value + + if !dc.caps["cap-notify"] || !changed { + return + } + + cap := name + if value != "" && dc.capVersion >= 302 { + cap = name + "=" + value + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "CAP", + Params: []string{dc.nick, "NEW", cap}, + }) +} + +func (dc *downstreamConn) unsetSupportedCap(name string) { + _, hasPrev := dc.supportedCaps[name] + delete(dc.supportedCaps, name) + delete(dc.caps, name) + + if !dc.caps["cap-notify"] || !hasPrev { + return + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "CAP", + Params: []string{dc.nick, "DEL", name}, + }) +} + +func (dc *downstreamConn) updateSupportedCaps() { + supportedCaps := make(map[string]bool) + for cap := range needAllDownstreamCaps { + supportedCaps[cap] = true + } + dc.forEachUpstream(func(uc *upstreamConn) { + for cap, supported := range supportedCaps { + supportedCaps[cap] = supported && uc.caps[cap] + } + }) + + for cap, supported := range supportedCaps { + if supported { + dc.setSupportedCap(cap, needAllDownstreamCaps[cap]) + } else { + dc.unsetSupportedCap(cap) + } + } + + if uc := dc.upstream(); uc != nil && uc.supportsSASL("PLAIN") { + dc.setSupportedCap("sasl", "PLAIN") + } else if dc.network != nil { + dc.unsetSupportedCap("sasl") + } + + if uc := dc.upstream(); uc != nil && uc.caps["draft/account-registration"] { + // Strip "before-connect", because we require downstreams to be fully + // connected before attempting account registration. + values := strings.Split(uc.supportedCaps["draft/account-registration"], ",") + for i, v := range values { + if v == "before-connect" { + values = append(values[:i], values[i+1:]...) + break + } + } + dc.setSupportedCap("draft/account-registration", strings.Join(values, ",")) + } else { + dc.unsetSupportedCap("draft/account-registration") + } + + if _, ok := dc.user.msgStore.(chatHistoryMessageStore); ok && dc.network != nil { + dc.setSupportedCap("draft/event-playback", "") + } else { + dc.unsetSupportedCap("draft/event-playback") + } +} + +func (dc *downstreamConn) updateNick() { + if uc := dc.upstream(); uc != nil && uc.nick != dc.nick { + dc.SendMessage(&irc.Message{ + Prefix: dc.prefix(), + Command: "NICK", + Params: []string{uc.nick}, + }) + dc.nick = uc.nick + dc.nickCM = casemapASCII(dc.nick) + } +} + +func (dc *downstreamConn) updateRealname() { + if uc := dc.upstream(); uc != nil && uc.realname != dc.realname && dc.caps["setname"] { + dc.SendMessage(&irc.Message{ + Prefix: dc.prefix(), + Command: "SETNAME", + Params: []string{uc.realname}, + }) + dc.realname = uc.realname + } +} + +func (dc *downstreamConn) updateAccount() { + var account string + if dc.network == nil { + account = dc.user.Username + } else if uc := dc.upstream(); uc != nil { + account = uc.account + } else { + return + } + + if dc.account == account || !dc.caps["sasl"] { + return + } + + if account != "" { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_LOGGEDIN, + Params: []string{dc.nick, dc.prefix().String(), account, "You are logged in as " + account}, + }) + } else { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_LOGGEDOUT, + Params: []string{dc.nick, dc.prefix().String(), "You are logged out"}, + }) + } + + dc.account = account +} + +func sanityCheckServer(ctx context.Context, addr string) error { + ctx, cancel := context.WithTimeout(ctx, 15*time.Second) + defer cancel() + + conn, err := new(tls.Dialer).DialContext(ctx, "tcp", addr) + if err != nil { + return err + } + + return conn.Close() +} + +func unmarshalUsername(rawUsername string) (username, client, network string) { + username = rawUsername + + i := strings.IndexAny(username, "/@") + j := strings.LastIndexAny(username, "/@") + if i >= 0 { + username = rawUsername[:i] + } + if j >= 0 { + if rawUsername[j] == '@' { + client = rawUsername[j+1:] + } else { + network = rawUsername[j+1:] + } + } + if i >= 0 && j >= 0 && i < j { + if rawUsername[i] == '@' { + client = rawUsername[i+1 : j] + } else { + network = rawUsername[i+1 : j] + } + } + + return username, client, network +} + +func (dc *downstreamConn) authenticate(ctx context.Context, username, password string) error { + username, clientName, networkName := unmarshalUsername(username) + + u, err := dc.srv.db.GetUser(ctx, username) + if err != nil { + return newInvalidUsernameOrPasswordError(fmt.Errorf("user not found: %w", err)) + } + + // Password auth disabled + if u.Password == "" { + return newInvalidUsernameOrPasswordError(fmt.Errorf("password auth disabled")) + } + + err = bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(password)) + if err != nil { + return newInvalidUsernameOrPasswordError(fmt.Errorf("wrong password")) + } + + dc.user = dc.srv.getUser(username) + if dc.user == nil { + return fmt.Errorf("user not active") + } + dc.clientName = clientName + dc.networkName = networkName + return nil +} + +func (dc *downstreamConn) register(ctx context.Context) error { + if dc.registered { + panic("tried to register twice") + } + + if dc.sasl != nil { + dc.endSASL(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_SASLABORTED, + Params: []string{dc.nick, "SASL authentication aborted"}, + }) + } + + password := dc.password + dc.password = "" + if dc.user == nil { + if password == "" { + if dc.caps["sasl"] { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"*", "ACCOUNT_REQUIRED", "Authentication required"}, + }} + } else { + return ircError{&irc.Message{ + Command: irc.ERR_PASSWDMISMATCH, + Params: []string{dc.nick, "Authentication required"}, + }} + } + } + + if err := dc.authenticate(ctx, dc.rawUsername, password); err != nil { + dc.logger.Printf("PASS authentication error for user %q: %v", dc.rawUsername, err) + return ircError{&irc.Message{ + Command: irc.ERR_PASSWDMISMATCH, + Params: []string{dc.nick, authErrorReason(err)}, + }} + } + } + + _, fallbackClientName, fallbackNetworkName := unmarshalUsername(dc.rawUsername) + if dc.clientName == "" { + dc.clientName = fallbackClientName + } else if fallbackClientName != "" && dc.clientName != fallbackClientName { + return ircError{&irc.Message{ + Command: irc.ERR_ERRONEUSNICKNAME, + Params: []string{dc.nick, "Client name mismatch in usernames"}, + }} + } + if dc.networkName == "" { + dc.networkName = fallbackNetworkName + } else if fallbackNetworkName != "" && dc.networkName != fallbackNetworkName { + return ircError{&irc.Message{ + Command: irc.ERR_ERRONEUSNICKNAME, + Params: []string{dc.nick, "Network name mismatch in usernames"}, + }} + } + + dc.registered = true + dc.logger.Printf("registration complete for user %q", dc.user.Username) + return nil +} + +func (dc *downstreamConn) loadNetwork(ctx context.Context) error { + if dc.networkName == "" { + return nil + } + + network := dc.user.getNetwork(dc.networkName) + if network == nil { + addr := dc.networkName + if !strings.ContainsRune(addr, ':') { + addr = addr + ":6697" + } + + dc.logger.Printf("trying to connect to new network %q", addr) + if err := sanityCheckServer(ctx, addr); err != nil { + dc.logger.Printf("failed to connect to %q: %v", addr, err) + return ircError{&irc.Message{ + Command: irc.ERR_PASSWDMISMATCH, + Params: []string{dc.nick, fmt.Sprintf("Failed to connect to %q", dc.networkName)}, + }} + } + + // Some clients only allow specifying the nickname (and use the + // nickname as a username too). Strip the network name from the + // nickname when auto-saving networks. + nick, _, _ := unmarshalUsername(dc.nick) + + dc.logger.Printf("auto-saving network %q", dc.networkName) + var err error + network, err = dc.user.createNetwork(ctx, &Network{ + Addr: dc.networkName, + Nick: nick, + Enabled: true, + }) + if err != nil { + return err + } + } + + dc.network = network + return nil +} + +func (dc *downstreamConn) welcome(ctx context.Context) error { + if dc.user == nil || !dc.registered { + panic("tried to welcome an unregistered connection") + } + + remoteAddr := dc.conn.RemoteAddr().String() + dc.logger = &prefixLogger{dc.srv.Logger, fmt.Sprintf("user %q: downstream %q: ", dc.user.Username, remoteAddr)} + + // TODO: doing this might take some time. We should do it in dc.register + // instead, but we'll potentially be adding a new network and this must be + // done in the user goroutine. + if err := dc.loadNetwork(ctx); err != nil { + return err + } + + if dc.network == nil && !dc.caps["soju.im/bouncer-networks"] && dc.srv.Config().MultiUpstream { + dc.isMultiUpstream = true + } + + dc.updateSupportedCaps() + + isupport := []string{ + fmt.Sprintf("CHATHISTORY=%v", chatHistoryLimit), + "CASEMAPPING=ascii", + } + + if dc.network != nil { + isupport = append(isupport, fmt.Sprintf("BOUNCER_NETID=%v", dc.network.ID)) + } + if title := dc.srv.Config().Title; dc.network == nil && title != "" { + isupport = append(isupport, "NETWORK="+encodeISUPPORT(title)) + } + if dc.network == nil && !dc.isMultiUpstream { + isupport = append(isupport, "WHOX") + } + + if uc := dc.upstream(); uc != nil { + for k := range passthroughIsupport { + v, ok := uc.isupport[k] + if !ok { + continue + } + if v != nil { + isupport = append(isupport, fmt.Sprintf("%v=%v", k, *v)) + } else { + isupport = append(isupport, k) + } + } + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_WELCOME, + Params: []string{dc.nick, "Welcome to suika, " + dc.nick}, + }) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_YOURHOST, + Params: []string{dc.nick, "Your host is " + dc.srv.Config().Hostname}, + }) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_MYINFO, + Params: []string{dc.nick, dc.srv.Config().Hostname, "suika", "aiwroO", "OovaimnqpsrtklbeI"}, + }) + for _, msg := range generateIsupport(dc.srv.prefix(), dc.nick, isupport) { + dc.SendMessage(msg) + } + if uc := dc.upstream(); uc != nil { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_UMODEIS, + Params: []string{dc.nick, "+" + string(uc.modes)}, + }) + } + if dc.network == nil && !dc.isMultiUpstream && dc.user.Admin { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_UMODEIS, + Params: []string{dc.nick, "+o"}, + }) + } + + dc.updateNick() + dc.updateRealname() + dc.updateAccount() + + if motd := dc.user.srv.Config().MOTD; motd != "" && dc.network == nil { + for _, msg := range generateMOTD(dc.srv.prefix(), dc.nick, motd) { + dc.SendMessage(msg) + } + } else { + motdHint := "No MOTD" + if dc.network != nil { + motdHint = "Use /motd to read the message of the day" + } + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_NOMOTD, + Params: []string{dc.nick, motdHint}, + }) + } + + if dc.caps["soju.im/bouncer-networks-notify"] { + dc.SendBatch("soju.im/bouncer-networks", nil, nil, func(batchRef irc.TagValue) { + for _, network := range dc.user.networks { + idStr := fmt.Sprintf("%v", network.ID) + attrs := getNetworkAttrs(network) + dc.SendMessage(&irc.Message{ + Tags: irc.Tags{"batch": batchRef}, + Prefix: dc.srv.prefix(), + Command: "BOUNCER", + Params: []string{"NETWORK", idStr, attrs.String()}, + }) + } + }) + } + + dc.forEachUpstream(func(uc *upstreamConn) { + for _, entry := range uc.channels.innerMap { + ch := entry.value.(*upstreamChannel) + if !ch.complete { + continue + } + record := uc.network.channels.Value(ch.Name) + if record != nil && record.Detached { + continue + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.prefix(), + Command: "JOIN", + Params: []string{dc.marshalEntity(ch.conn.network, ch.Name)}, + }) + + forwardChannel(ctx, dc, ch) + } + }) + + dc.forEachNetwork(func(net *network) { + if dc.caps["draft/chathistory"] || dc.user.msgStore == nil { + return + } + + // Only send history if we're the first connected client with that name + // for the network + firstClient := true + dc.user.forEachDownstream(func(c *downstreamConn) { + if c != dc && c.clientName == dc.clientName && c.network == dc.network { + firstClient = false + } + }) + if firstClient { + net.delivered.ForEachTarget(func(target string) { + lastDelivered := net.delivered.LoadID(target, dc.clientName) + if lastDelivered == "" { + return + } + + dc.sendTargetBacklog(ctx, net, target, lastDelivered) + + // Fast-forward history to last message + targetCM := net.casemap(target) + lastID, err := dc.user.msgStore.LastMsgID(&net.Network, targetCM, time.Now()) + if err != nil { + dc.logger.Printf("failed to get last message ID: %v", err) + return + } + net.delivered.StoreID(target, dc.clientName, lastID) + }) + } + }) + + return nil +} + +// messageSupportsBacklog checks whether the provided message can be sent as +// part of an history batch. +func (dc *downstreamConn) messageSupportsBacklog(msg *irc.Message) bool { + // Don't replay all messages, because that would mess up client + // state. For instance we just sent the list of users, sending + // PART messages for one of these users would be incorrect. + switch msg.Command { + case "PRIVMSG", "NOTICE": + return true + } + return false +} + +func (dc *downstreamConn) sendTargetBacklog(ctx context.Context, net *network, target, msgID string) { + if dc.caps["draft/chathistory"] || dc.user.msgStore == nil { + return + } + + ch := net.channels.Value(target) + + ctx, cancel := context.WithTimeout(ctx, backlogTimeout) + defer cancel() + + targetCM := net.casemap(target) + history, err := dc.user.msgStore.LoadLatestID(ctx, &net.Network, targetCM, msgID, backlogLimit) + if err != nil { + dc.logger.Printf("failed to send backlog for %q: %v", target, err) + return + } + + dc.SendBatch("chathistory", []string{dc.marshalEntity(net, target)}, nil, func(batchRef irc.TagValue) { + for _, msg := range history { + if ch != nil && ch.Detached { + if net.detachedMessageNeedsRelay(ch, msg) { + dc.relayDetachedMessage(net, msg) + } + } else { + msg.Tags["batch"] = batchRef + dc.SendMessage(dc.marshalMessage(msg, net)) + } + } + }) +} + +func (dc *downstreamConn) relayDetachedMessage(net *network, msg *irc.Message) { + if msg.Command != "PRIVMSG" && msg.Command != "NOTICE" { + return + } + + sender := msg.Prefix.Name + target, text := msg.Params[0], msg.Params[1] + if net.isHighlight(msg) { + sendServiceNOTICE(dc, fmt.Sprintf("highlight in %v: <%v> %v", dc.marshalEntity(net, target), sender, text)) + } else { + sendServiceNOTICE(dc, fmt.Sprintf("message in %v: <%v> %v", dc.marshalEntity(net, target), sender, text)) + } +} + +func (dc *downstreamConn) runUntilRegistered() error { + ctx, cancel := context.WithTimeout(context.TODO(), downstreamRegisterTimeout) + defer cancel() + + // Close the connection with an error if the deadline is exceeded + go func() { + <-ctx.Done() + if err := ctx.Err(); err == context.DeadlineExceeded { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "ERROR", + Params: []string{"Connection registration timed out"}, + }) + dc.Close() + } + }() + + for !dc.registered { + msg, err := dc.ReadMessage() + if err != nil { + return fmt.Errorf("failed to read IRC command: %w", err) + } + + err = dc.handleMessage(ctx, msg) + if ircErr, ok := err.(ircError); ok { + ircErr.Message.Prefix = dc.srv.prefix() + dc.SendMessage(ircErr.Message) + } else if err != nil { + return fmt.Errorf("failed to handle IRC command %q: %v", msg, err) + } + } + + return nil +} + +func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.Message) error { + switch msg.Command { + case "CAP": + var subCmd string + if err := parseMessageParams(msg, &subCmd); err != nil { + return err + } + if err := dc.handleCapCommand(subCmd, msg.Params[1:]); err != nil { + return err + } + case "PING": + var source, destination string + if err := parseMessageParams(msg, &source); err != nil { + return err + } + if len(msg.Params) > 1 { + destination = msg.Params[1] + } + hostname := dc.srv.Config().Hostname + if destination != "" && destination != hostname { + return ircError{&irc.Message{ + Command: irc.ERR_NOSUCHSERVER, + Params: []string{dc.nick, destination, "No such server"}, + }} + } + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "PONG", + Params: []string{hostname, source}, + }) + return nil + case "PONG": + if len(msg.Params) == 0 { + return newNeedMoreParamsError(msg.Command) + } + token := msg.Params[len(msg.Params)-1] + dc.handlePong(token) + case "USER": + return ircError{&irc.Message{ + Command: irc.ERR_ALREADYREGISTERED, + Params: []string{dc.nick, "You may not reregister"}, + }} + case "NICK": + var rawNick string + if err := parseMessageParams(msg, &rawNick); err != nil { + return err + } + + nick := rawNick + var upstream *upstreamConn + if dc.upstream() == nil { + uc, unmarshaledNick, err := dc.unmarshalEntity(nick) + if err == nil { // NICK nick/network: NICK only on a specific upstream + upstream = uc + nick = unmarshaledNick + } + } + + if nick == "" || strings.ContainsAny(nick, illegalNickChars) { + return ircError{&irc.Message{ + Command: irc.ERR_ERRONEUSNICKNAME, + Params: []string{dc.nick, rawNick, "contains illegal characters"}, + }} + } + if casemapASCII(nick) == serviceNickCM { + return ircError{&irc.Message{ + Command: irc.ERR_NICKNAMEINUSE, + Params: []string{dc.nick, rawNick, "Nickname reserved for bouncer service"}, + }} + } + + var err error + dc.forEachNetwork(func(n *network) { + if err != nil || (upstream != nil && upstream.network != n) { + return + } + n.Nick = nick + err = dc.srv.db.StoreNetwork(ctx, dc.user.ID, &n.Network) + }) + if err != nil { + return err + } + + dc.forEachUpstream(func(uc *upstreamConn) { + if upstream != nil && upstream != uc { + return + } + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + Command: "NICK", + Params: []string{nick}, + }) + }) + + if dc.upstream() == nil && upstream == nil && dc.nick != nick { + dc.SendMessage(&irc.Message{ + Prefix: dc.prefix(), + Command: "NICK", + Params: []string{nick}, + }) + dc.nick = nick + dc.nickCM = casemapASCII(dc.nick) + } + case "SETNAME": + var realname string + if err := parseMessageParams(msg, &realname); err != nil { + return err + } + + // If the client just resets to the default, just wipe the per-network + // preference + storeRealname := realname + if realname == dc.user.Realname { + storeRealname = "" + } + + var storeErr error + var needUpdate []Network + dc.forEachNetwork(func(n *network) { + // We only need to call updateNetwork for upstreams that don't + // support setname + if uc := n.conn; uc != nil && uc.caps["setname"] { + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + Command: "SETNAME", + Params: []string{realname}, + }) + + n.Realname = storeRealname + if err := dc.srv.db.StoreNetwork(ctx, dc.user.ID, &n.Network); err != nil { + dc.logger.Printf("failed to store network realname: %v", err) + storeErr = err + } + return + } + + record := n.Network // copy network record because we'll mutate it + record.Realname = storeRealname + needUpdate = append(needUpdate, record) + }) + + // Walk the network list as a second step, because updateNetwork + // mutates the original list + for _, record := range needUpdate { + if _, err := dc.user.updateNetwork(ctx, &record); err != nil { + dc.logger.Printf("failed to update network realname: %v", err) + storeErr = err + } + } + if storeErr != nil { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"SETNAME", "CANNOT_CHANGE_REALNAME", "Failed to update realname"}, + }} + } + + if dc.upstream() == nil { + dc.SendMessage(&irc.Message{ + Prefix: dc.prefix(), + Command: "SETNAME", + Params: []string{realname}, + }) + } + case "JOIN": + var namesStr string + if err := parseMessageParams(msg, &namesStr); err != nil { + return err + } + + var keys []string + if len(msg.Params) > 1 { + keys = strings.Split(msg.Params[1], ",") + } + + for i, name := range strings.Split(namesStr, ",") { + uc, upstreamName, err := dc.unmarshalEntity(name) + if err != nil { + return err + } + + var key string + if len(keys) > i { + key = keys[i] + } + + if !uc.isChannel(upstreamName) { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_NOSUCHCHANNEL, + Params: []string{name, "Not a channel name"}, + }) + continue + } + + // Most servers ignore duplicate JOIN messages. We ignore them here + // because some clients automatically send JOIN messages in bulk + // when reconnecting to the bouncer. We don't want to flood the + // upstream connection with these. + if !uc.channels.Has(upstreamName) { + params := []string{upstreamName} + if key != "" { + params = append(params, key) + } + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + Command: "JOIN", + Params: params, + }) + } + + ch := uc.network.channels.Value(upstreamName) + if ch != nil { + // Don't clear the channel key if there's one set + // TODO: add a way to unset the channel key + if key != "" { + ch.Key = key + } + uc.network.attach(ctx, ch) + } else { + ch = &Channel{ + Name: upstreamName, + Key: key, + } + uc.network.channels.SetValue(upstreamName, ch) + } + if err := dc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil { + dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err) + } + } + case "PART": + var namesStr string + if err := parseMessageParams(msg, &namesStr); err != nil { + return err + } + + var reason string + if len(msg.Params) > 1 { + reason = msg.Params[1] + } + + for _, name := range strings.Split(namesStr, ",") { + uc, upstreamName, err := dc.unmarshalEntity(name) + if err != nil { + return err + } + + if strings.EqualFold(reason, "detach") { + ch := uc.network.channels.Value(upstreamName) + if ch != nil { + uc.network.detach(ch) + } else { + ch = &Channel{ + Name: name, + Detached: true, + } + uc.network.channels.SetValue(upstreamName, ch) + } + if err := dc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil { + dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err) + } + } else { + params := []string{upstreamName} + if reason != "" { + params = append(params, reason) + } + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + Command: "PART", + Params: params, + }) + + if err := uc.network.deleteChannel(ctx, upstreamName); err != nil { + dc.logger.Printf("failed to delete channel %q: %v", upstreamName, err) + } + } + } + case "KICK": + var channelStr, userStr string + if err := parseMessageParams(msg, &channelStr, &userStr); err != nil { + return err + } + + channels := strings.Split(channelStr, ",") + users := strings.Split(userStr, ",") + + var reason string + if len(msg.Params) > 2 { + reason = msg.Params[2] + } + + if len(channels) != 1 && len(channels) != len(users) { + return ircError{&irc.Message{ + Command: irc.ERR_BADCHANMASK, + Params: []string{dc.nick, channelStr, "Bad channel mask"}, + }} + } + + for i, user := range users { + var channel string + if len(channels) == 1 { + channel = channels[0] + } else { + channel = channels[i] + } + + ucChannel, upstreamChannel, err := dc.unmarshalEntity(channel) + if err != nil { + return err + } + + ucUser, upstreamUser, err := dc.unmarshalEntity(user) + if err != nil { + return err + } + + if ucChannel != ucUser { + return ircError{&irc.Message{ + Command: irc.ERR_USERNOTINCHANNEL, + Params: []string{dc.nick, user, channel, "They are on another network"}, + }} + } + uc := ucChannel + + params := []string{upstreamChannel, upstreamUser} + if reason != "" { + params = append(params, reason) + } + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + Command: "KICK", + Params: params, + }) + } + case "MODE": + var name string + if err := parseMessageParams(msg, &name); err != nil { + return err + } + + var modeStr string + if len(msg.Params) > 1 { + modeStr = msg.Params[1] + } + + if casemapASCII(name) == dc.nickCM { + if modeStr != "" { + if uc := dc.upstream(); uc != nil { + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + Command: "MODE", + Params: []string{uc.nick, modeStr}, + }) + } else { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_UMODEUNKNOWNFLAG, + Params: []string{dc.nick, "Cannot change user mode in multi-upstream mode"}, + }) + } + } else { + var userMode string + if uc := dc.upstream(); uc != nil { + userMode = string(uc.modes) + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_UMODEIS, + Params: []string{dc.nick, "+" + userMode}, + }) + } + return nil + } + + uc, upstreamName, err := dc.unmarshalEntity(name) + if err != nil { + return err + } + + if !uc.isChannel(upstreamName) { + return ircError{&irc.Message{ + Command: irc.ERR_USERSDONTMATCH, + Params: []string{dc.nick, "Cannot change mode for other users"}, + }} + } + + if modeStr != "" { + params := []string{upstreamName, modeStr} + params = append(params, msg.Params[2:]...) + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + Command: "MODE", + Params: params, + }) + } else { + ch := uc.channels.Value(upstreamName) + if ch == nil { + return ircError{&irc.Message{ + Command: irc.ERR_NOSUCHCHANNEL, + Params: []string{dc.nick, name, "No such channel"}, + }} + } + + if ch.modes == nil { + // we haven't received the initial RPL_CHANNELMODEIS yet + // ignore the request, we will broadcast the modes later when we receive RPL_CHANNELMODEIS + return nil + } + + modeStr, modeParams := ch.modes.Format() + params := []string{dc.nick, name, modeStr} + params = append(params, modeParams...) + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_CHANNELMODEIS, + Params: params, + }) + if ch.creationTime != "" { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: rpl_creationtime, + Params: []string{dc.nick, name, ch.creationTime}, + }) + } + } + case "TOPIC": + var channel string + if err := parseMessageParams(msg, &channel); err != nil { + return err + } + + uc, upstreamName, err := dc.unmarshalEntity(channel) + if err != nil { + return err + } + + if len(msg.Params) > 1 { // setting topic + topic := msg.Params[1] + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + Command: "TOPIC", + Params: []string{upstreamName, topic}, + }) + } else { // getting topic + ch := uc.channels.Value(upstreamName) + if ch == nil { + return ircError{&irc.Message{ + Command: irc.ERR_NOSUCHCHANNEL, + Params: []string{dc.nick, upstreamName, "No such channel"}, + }} + } + sendTopic(dc, ch) + } + case "LIST": + network := dc.network + if network == nil && len(msg.Params) > 0 { + var err error + network, msg.Params[0], err = dc.unmarshalEntityNetwork(msg.Params[0]) + if err != nil { + return err + } + } + if network == nil { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_LISTEND, + Params: []string{dc.nick, "LIST without a network suffix is not supported in multi-upstream mode"}, + }) + return nil + } + + uc := network.conn + if uc == nil { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_LISTEND, + Params: []string{dc.nick, "Disconnected from upstream server"}, + }) + return nil + } + + uc.enqueueCommand(dc, msg) + case "NAMES": + if len(msg.Params) == 0 { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_ENDOFNAMES, + Params: []string{dc.nick, "*", "End of /NAMES list"}, + }) + return nil + } + + channels := strings.Split(msg.Params[0], ",") + for _, channel := range channels { + uc, upstreamName, err := dc.unmarshalEntity(channel) + if err != nil { + return err + } + + ch := uc.channels.Value(upstreamName) + if ch != nil { + sendNames(dc, ch) + } else { + // NAMES on a channel we have not joined, ask upstream + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + Command: "NAMES", + Params: []string{upstreamName}, + }) + } + } + // For WHOX docs, see: + // - http://faerion.sourceforge.net/doc/irc/whox.var + // - https://github.com/quakenet/snircd/blob/master/doc/readme.who + // Note, many features aren't widely implemented, such as flags and mask2 + case "WHO": + if len(msg.Params) == 0 { + // TODO: support WHO without parameters + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_ENDOFWHO, + Params: []string{dc.nick, "*", "End of /WHO list"}, + }) + return nil + } + + // Clients will use the first mask to match RPL_ENDOFWHO + endOfWhoToken := msg.Params[0] + + // TODO: add support for WHOX mask2 + mask := msg.Params[0] + var options string + if len(msg.Params) > 1 { + options = msg.Params[1] + } + + optionsParts := strings.SplitN(options, "%", 2) + // TODO: add support for WHOX flags in optionsParts[0] + var fields, whoxToken string + if len(optionsParts) == 2 { + optionsParts := strings.SplitN(optionsParts[1], ",", 2) + fields = strings.ToLower(optionsParts[0]) + if len(optionsParts) == 2 && strings.Contains(fields, "t") { + whoxToken = optionsParts[1] + } + } + + // TODO: support mixed bouncer/upstream WHO queries + maskCM := casemapASCII(mask) + if dc.network == nil && maskCM == dc.nickCM { + // TODO: support AWAY (H/G) in self WHO reply + flags := "H" + if dc.user.Admin { + flags += "*" + } + info := whoxInfo{ + Token: whoxToken, + Username: dc.user.Username, + Hostname: dc.hostname, + Server: dc.srv.Config().Hostname, + Nickname: dc.nick, + Flags: flags, + Account: dc.user.Username, + Realname: dc.realname, + } + dc.SendMessage(generateWHOXReply(dc.srv.prefix(), dc.nick, fields, &info)) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_ENDOFWHO, + Params: []string{dc.nick, endOfWhoToken, "End of /WHO list"}, + }) + return nil + } + if maskCM == serviceNickCM { + info := whoxInfo{ + Token: whoxToken, + Username: servicePrefix.User, + Hostname: servicePrefix.Host, + Server: dc.srv.Config().Hostname, + Nickname: serviceNick, + Flags: "H*", + Account: serviceNick, + Realname: serviceRealname, + } + dc.SendMessage(generateWHOXReply(dc.srv.prefix(), dc.nick, fields, &info)) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_ENDOFWHO, + Params: []string{dc.nick, endOfWhoToken, "End of /WHO list"}, + }) + return nil + } + + // TODO: properly support WHO masks + uc, upstreamMask, err := dc.unmarshalEntity(mask) + if err != nil { + return err + } + + params := []string{upstreamMask} + if options != "" { + params = append(params, options) + } + + uc.enqueueCommand(dc, &irc.Message{ + Command: "WHO", + Params: params, + }) + case "WHOIS": + if len(msg.Params) == 0 { + return ircError{&irc.Message{ + Command: irc.ERR_NONICKNAMEGIVEN, + Params: []string{dc.nick, "No nickname given"}, + }} + } + + var target, mask string + if len(msg.Params) == 1 { + target = "" + mask = msg.Params[0] + } else { + target = msg.Params[0] + mask = msg.Params[1] + } + // TODO: support multiple WHOIS users + if i := strings.IndexByte(mask, ','); i >= 0 { + mask = mask[:i] + } + + if dc.network == nil && casemapASCII(mask) == dc.nickCM { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_WHOISUSER, + Params: []string{dc.nick, dc.nick, dc.user.Username, dc.hostname, "*", dc.realname}, + }) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_WHOISSERVER, + Params: []string{dc.nick, dc.nick, dc.srv.Config().Hostname, "suika"}, + }) + if dc.user.Admin { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_WHOISOPERATOR, + Params: []string{dc.nick, dc.nick, "is a bouncer administrator"}, + }) + } + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: rpl_whoisaccount, + Params: []string{dc.nick, dc.nick, dc.user.Username, "is logged in as"}, + }) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_ENDOFWHOIS, + Params: []string{dc.nick, dc.nick, "End of /WHOIS list"}, + }) + return nil + } + if casemapASCII(mask) == serviceNickCM { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_WHOISUSER, + Params: []string{dc.nick, serviceNick, servicePrefix.User, servicePrefix.Host, "*", serviceRealname}, + }) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_WHOISSERVER, + Params: []string{dc.nick, serviceNick, dc.srv.Config().Hostname, "suika"}, + }) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_WHOISOPERATOR, + Params: []string{dc.nick, serviceNick, "is the bouncer service"}, + }) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: rpl_whoisaccount, + Params: []string{dc.nick, serviceNick, serviceNick, "is logged in as"}, + }) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_ENDOFWHOIS, + Params: []string{dc.nick, serviceNick, "End of /WHOIS list"}, + }) + return nil + } + + // TODO: support WHOIS masks + uc, upstreamNick, err := dc.unmarshalEntity(mask) + if err != nil { + return err + } + + var params []string + if target != "" { + if target == mask { // WHOIS nick nick + params = []string{upstreamNick, upstreamNick} + } else { + params = []string{target, upstreamNick} + } + } else { + params = []string{upstreamNick} + } + + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + Command: "WHOIS", + Params: params, + }) + case "PRIVMSG", "NOTICE": + var targetsStr, text string + if err := parseMessageParams(msg, &targetsStr, &text); err != nil { + return err + } + tags := copyClientTags(msg.Tags) + + for _, name := range strings.Split(targetsStr, ",") { + if name == "$"+dc.srv.Config().Hostname || (name == "$*" && dc.network == nil) { + // "$" means a server mask follows. If it's the bouncer's + // hostname, broadcast the message to all bouncer users. + if !dc.user.Admin { + return ircError{&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_BADMASK, + Params: []string{dc.nick, name, "Permission denied to broadcast message to all bouncer users"}, + }} + } + + dc.logger.Printf("broadcasting bouncer-wide %v: %v", msg.Command, text) + + broadcastTags := tags.Copy() + broadcastTags["time"] = irc.TagValue(formatServerTime(time.Now())) + broadcastMsg := &irc.Message{ + Tags: broadcastTags, + Prefix: servicePrefix, + Command: msg.Command, + Params: []string{name, text}, + } + dc.srv.forEachUser(func(u *user) { + u.events <- eventBroadcast{broadcastMsg} + }) + continue + } + + if dc.network == nil && casemapASCII(name) == dc.nickCM { + dc.SendMessage(&irc.Message{ + Tags: msg.Tags.Copy(), + Prefix: dc.prefix(), + Command: msg.Command, + Params: []string{name, text}, + }) + continue + } + + if msg.Command == "PRIVMSG" && casemapASCII(name) == serviceNickCM { + if dc.caps["echo-message"] { + echoTags := tags.Copy() + echoTags["time"] = irc.TagValue(formatServerTime(time.Now())) + dc.SendMessage(&irc.Message{ + Tags: echoTags, + Prefix: dc.prefix(), + Command: msg.Command, + Params: []string{name, text}, + }) + } + handleServicePRIVMSG(ctx, dc, text) + continue + } + + uc, upstreamName, err := dc.unmarshalEntity(name) + if err != nil { + return err + } + + if msg.Command == "PRIVMSG" && uc.network.casemap(upstreamName) == "nickserv" { + dc.handleNickServPRIVMSG(ctx, uc, text) + } + + unmarshaledText := text + if uc.isChannel(upstreamName) { + unmarshaledText = dc.unmarshalText(uc, text) + } + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + Tags: tags, + Command: msg.Command, + Params: []string{upstreamName, unmarshaledText}, + }) + + echoTags := tags.Copy() + echoTags["time"] = irc.TagValue(formatServerTime(time.Now())) + if uc.account != "" { + echoTags["account"] = irc.TagValue(uc.account) + } + echoMsg := &irc.Message{ + Tags: echoTags, + Prefix: &irc.Prefix{Name: uc.nick}, + Command: msg.Command, + Params: []string{upstreamName, text}, + } + uc.produce(upstreamName, echoMsg, dc) + + uc.updateChannelAutoDetach(upstreamName) + } + case "TAGMSG": + var targetsStr string + if err := parseMessageParams(msg, &targetsStr); err != nil { + return err + } + tags := copyClientTags(msg.Tags) + + for _, name := range strings.Split(targetsStr, ",") { + if dc.network == nil && casemapASCII(name) == dc.nickCM { + dc.SendMessage(&irc.Message{ + Tags: msg.Tags.Copy(), + Prefix: dc.prefix(), + Command: "TAGMSG", + Params: []string{name}, + }) + continue + } + + if casemapASCII(name) == serviceNickCM { + continue + } + + uc, upstreamName, err := dc.unmarshalEntity(name) + if err != nil { + return err + } + if _, ok := uc.caps["message-tags"]; !ok { + continue + } + + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + Tags: tags, + Command: "TAGMSG", + Params: []string{upstreamName}, + }) + + echoTags := tags.Copy() + echoTags["time"] = irc.TagValue(formatServerTime(time.Now())) + if uc.account != "" { + echoTags["account"] = irc.TagValue(uc.account) + } + echoMsg := &irc.Message{ + Tags: echoTags, + Prefix: &irc.Prefix{Name: uc.nick}, + Command: "TAGMSG", + Params: []string{upstreamName}, + } + uc.produce(upstreamName, echoMsg, dc) + + uc.updateChannelAutoDetach(upstreamName) + } + case "INVITE": + var user, channel string + if err := parseMessageParams(msg, &user, &channel); err != nil { + return err + } + + ucChannel, upstreamChannel, err := dc.unmarshalEntity(channel) + if err != nil { + return err + } + + ucUser, upstreamUser, err := dc.unmarshalEntity(user) + if err != nil { + return err + } + + if ucChannel != ucUser { + return ircError{&irc.Message{ + Command: irc.ERR_USERNOTINCHANNEL, + Params: []string{dc.nick, user, channel, "They are on another network"}, + }} + } + uc := ucChannel + + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + Command: "INVITE", + Params: []string{upstreamUser, upstreamChannel}, + }) + case "AUTHENTICATE": + // Post-connection-registration AUTHENTICATE is unsupported in + // multi-upstream mode, or if the upstream doesn't support SASL + uc := dc.upstream() + if uc == nil || !uc.caps["sasl"] { + return ircError{&irc.Message{ + Command: irc.ERR_SASLFAIL, + Params: []string{dc.nick, "Upstream network authentication not supported"}, + }} + } + + credentials, err := dc.handleAuthenticateCommand(msg) + if err != nil { + return err + } + + if credentials != nil { + if uc.saslClient != nil { + dc.endSASL(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_SASLFAIL, + Params: []string{dc.nick, "Another authentication attempt is already in progress"}, + }) + return nil + } + + uc.logger.Printf("starting post-registration SASL PLAIN authentication with username %q", credentials.plainUsername) + uc.saslClient = sasl.NewPlainClient("", credentials.plainUsername, credentials.plainPassword) + uc.enqueueCommand(dc, &irc.Message{ + Command: "AUTHENTICATE", + Params: []string{"PLAIN"}, + }) + } + case "REGISTER", "VERIFY": + // Check number of params here, since we'll use that to save the + // credentials on command success + if (msg.Command == "REGISTER" && len(msg.Params) < 3) || (msg.Command == "VERIFY" && len(msg.Params) < 2) { + return newNeedMoreParamsError(msg.Command) + } + + uc := dc.upstream() + if uc == nil || !uc.caps["draft/account-registration"] { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{msg.Command, "TEMPORARILY_UNAVAILABLE", "*", "Upstream network account registration not supported"}, + }} + } + + uc.logger.Printf("starting %v with account name %v", msg.Command, msg.Params[0]) + uc.enqueueCommand(dc, msg) + case "MONITOR": + // MONITOR is unsupported in multi-upstream mode + uc := dc.upstream() + if uc == nil { + return newUnknownCommandError(msg.Command) + } + if _, ok := uc.isupport["MONITOR"]; !ok { + return newUnknownCommandError(msg.Command) + } + + var subcommand string + if err := parseMessageParams(msg, &subcommand); err != nil { + return err + } + + switch strings.ToUpper(subcommand) { + case "+", "-": + var targets string + if err := parseMessageParams(msg, nil, &targets); err != nil { + return err + } + for _, target := range strings.Split(targets, ",") { + if subcommand == "+" { + // Hard limit, just to avoid having downstreams fill our map + if len(dc.monitored.innerMap) >= 1000 { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_MONLISTFULL, + Params: []string{dc.nick, "1000", target, "Bouncer monitor list is full"}, + }) + continue + } + + dc.monitored.SetValue(target, nil) + + if uc.monitored.Has(target) { + cmd := irc.RPL_MONOFFLINE + if online := uc.monitored.Value(target); online { + cmd = irc.RPL_MONONLINE + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: cmd, + Params: []string{dc.nick, target}, + }) + } + } else { + dc.monitored.Delete(target) + } + } + uc.updateMonitor() + case "C": // clear + dc.monitored = newCasemapMap(0) + uc.updateMonitor() + case "L": // list + // TODO: be less lazy and pack the list + for _, entry := range dc.monitored.innerMap { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_MONLIST, + Params: []string{dc.nick, entry.originalKey}, + }) + } + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_ENDOFMONLIST, + Params: []string{dc.nick, "End of MONITOR list"}, + }) + case "S": // status + // TODO: be less lazy and pack the lists + for _, entry := range dc.monitored.innerMap { + target := entry.originalKey + + cmd := irc.RPL_MONOFFLINE + if online := uc.monitored.Value(target); online { + cmd = irc.RPL_MONONLINE + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: cmd, + Params: []string{dc.nick, target}, + }) + } + } + case "CHATHISTORY": + var subcommand string + if err := parseMessageParams(msg, &subcommand); err != nil { + return err + } + var target, limitStr string + var boundsStr [2]string + switch subcommand { + case "AFTER", "BEFORE", "LATEST": + if err := parseMessageParams(msg, nil, &target, &boundsStr[0], &limitStr); err != nil { + return err + } + case "BETWEEN": + if err := parseMessageParams(msg, nil, &target, &boundsStr[0], &boundsStr[1], &limitStr); err != nil { + return err + } + case "TARGETS": + if dc.network == nil { + // Either an unbound bouncer network, in which case we should return no targets, + // or a multi-upstream downstream, but we don't support CHATHISTORY TARGETS for those yet. + dc.SendBatch("draft/chathistory-targets", nil, nil, func(batchRef irc.TagValue) {}) + return nil + } + if err := parseMessageParams(msg, nil, &boundsStr[0], &boundsStr[1], &limitStr); err != nil { + return err + } + default: + // TODO: support AROUND + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"CHATHISTORY", "INVALID_PARAMS", subcommand, "Unknown command"}, + }} + } + + // We don't save history for our service + if casemapASCII(target) == serviceNickCM { + dc.SendBatch("chathistory", []string{target}, nil, func(batchRef irc.TagValue) {}) + return nil + } + + store, ok := dc.user.msgStore.(chatHistoryMessageStore) + if !ok { + return ircError{&irc.Message{ + Command: irc.ERR_UNKNOWNCOMMAND, + Params: []string{dc.nick, "CHATHISTORY", "Unknown command"}, + }} + } + + network, entity, err := dc.unmarshalEntityNetwork(target) + if err != nil { + return err + } + entity = network.casemap(entity) + + // TODO: support msgid criteria + var bounds [2]time.Time + bounds[0] = parseChatHistoryBound(boundsStr[0]) + if subcommand == "LATEST" && boundsStr[0] == "*" { + bounds[0] = time.Now() + } else if bounds[0].IsZero() { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"CHATHISTORY", "INVALID_PARAMS", subcommand, boundsStr[0], "Invalid first bound"}, + }} + } + + if boundsStr[1] != "" { + bounds[1] = parseChatHistoryBound(boundsStr[1]) + if bounds[1].IsZero() { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"CHATHISTORY", "INVALID_PARAMS", subcommand, boundsStr[1], "Invalid second bound"}, + }} + } + } + + limit, err := strconv.Atoi(limitStr) + if err != nil || limit < 0 || limit > chatHistoryLimit { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"CHATHISTORY", "INVALID_PARAMS", subcommand, limitStr, "Invalid limit"}, + }} + } + + eventPlayback := dc.caps["draft/event-playback"] + + var history []*irc.Message + switch subcommand { + case "BEFORE", "LATEST": + history, err = store.LoadBeforeTime(ctx, &network.Network, entity, bounds[0], time.Time{}, limit, eventPlayback) + case "AFTER": + history, err = store.LoadAfterTime(ctx, &network.Network, entity, bounds[0], time.Now(), limit, eventPlayback) + case "BETWEEN": + if bounds[0].Before(bounds[1]) { + history, err = store.LoadAfterTime(ctx, &network.Network, entity, bounds[0], bounds[1], limit, eventPlayback) + } else { + history, err = store.LoadBeforeTime(ctx, &network.Network, entity, bounds[0], bounds[1], limit, eventPlayback) + } + case "TARGETS": + // TODO: support TARGETS in multi-upstream mode + targets, err := store.ListTargets(ctx, &network.Network, bounds[0], bounds[1], limit, eventPlayback) + if err != nil { + dc.logger.Printf("failed fetching targets for chathistory: %v", err) + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"CHATHISTORY", "MESSAGE_ERROR", subcommand, "Failed to retrieve targets"}, + }} + } + + dc.SendBatch("draft/chathistory-targets", nil, nil, func(batchRef irc.TagValue) { + for _, target := range targets { + if ch := network.channels.Value(target.Name); ch != nil && ch.Detached { + continue + } + + dc.SendMessage(&irc.Message{ + Tags: irc.Tags{"batch": batchRef}, + Prefix: dc.srv.prefix(), + Command: "CHATHISTORY", + Params: []string{"TARGETS", target.Name, formatServerTime(target.LatestMessage)}, + }) + } + }) + + return nil + } + if err != nil { + dc.logger.Printf("failed fetching %q messages for chathistory: %v", target, err) + return newChatHistoryError(subcommand, target) + } + + dc.SendBatch("chathistory", []string{target}, nil, func(batchRef irc.TagValue) { + for _, msg := range history { + msg.Tags["batch"] = batchRef + dc.SendMessage(dc.marshalMessage(msg, network)) + } + }) + case "READ": + var target, criteria string + if err := parseMessageParams(msg, &target); err != nil { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"READ", "NEED_MORE_PARAMS", "Missing parameters"}, + }} + } + if len(msg.Params) > 1 { + criteria = msg.Params[1] + } + + // We don't save read receipts for our service + if casemapASCII(target) == serviceNickCM { + dc.SendMessage(&irc.Message{ + Prefix: dc.prefix(), + Command: "READ", + Params: []string{target, "*"}, + }) + return nil + } + + uc, entity, err := dc.unmarshalEntity(target) + if err != nil { + return err + } + entityCM := uc.network.casemap(entity) + + r, err := dc.srv.db.GetReadReceipt(ctx, uc.network.ID, entityCM) + if err != nil { + dc.logger.Printf("failed to get the read receipt for %q: %v", entity, err) + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"READ", "INTERNAL_ERROR", target, "Internal error"}, + }} + } else if r == nil { + r = &ReadReceipt{ + Target: entityCM, + } + } + + broadcast := false + if len(criteria) > 0 { + // TODO: support msgid criteria + criteriaParts := strings.SplitN(criteria, "=", 2) + if len(criteriaParts) != 2 || criteriaParts[0] != "timestamp" { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"READ", "INVALID_PARAMS", criteria, "Unknown criteria"}, + }} + } + + timestamp, err := time.Parse(serverTimeLayout, criteriaParts[1]) + if err != nil { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"READ", "INVALID_PARAMS", criteria, "Invalid criteria"}, + }} + } + now := time.Now() + if timestamp.After(now) { + timestamp = now + } + if r.Timestamp.Before(timestamp) { + r.Timestamp = timestamp + if err := dc.srv.db.StoreReadReceipt(ctx, uc.network.ID, r); err != nil { + dc.logger.Printf("failed to store receipt for %q: %v", entity, err) + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"READ", "INTERNAL_ERROR", target, "Internal error"}, + }} + } + broadcast = true + } + } + + timestampStr := "*" + if !r.Timestamp.IsZero() { + timestampStr = fmt.Sprintf("timestamp=%s", formatServerTime(r.Timestamp)) + } + uc.forEachDownstream(func(d *downstreamConn) { + if broadcast || dc.id == d.id { + d.SendMessage(&irc.Message{ + Prefix: d.prefix(), + Command: "READ", + Params: []string{d.marshalEntity(uc.network, entity), timestampStr}, + }) + } + }) + case "BOUNCER": + var subcommand string + if err := parseMessageParams(msg, &subcommand); err != nil { + return err + } + + switch strings.ToUpper(subcommand) { + case "BIND": + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"BOUNCER", "REGISTRATION_IS_COMPLETED", "BIND", "Cannot bind to a network after registration"}, + }} + case "LISTNETWORKS": + dc.SendBatch("soju.im/bouncer-networks", nil, nil, func(batchRef irc.TagValue) { + for _, network := range dc.user.networks { + idStr := fmt.Sprintf("%v", network.ID) + attrs := getNetworkAttrs(network) + dc.SendMessage(&irc.Message{ + Tags: irc.Tags{"batch": batchRef}, + Prefix: dc.srv.prefix(), + Command: "BOUNCER", + Params: []string{"NETWORK", idStr, attrs.String()}, + }) + } + }) + case "ADDNETWORK": + var attrsStr string + if err := parseMessageParams(msg, nil, &attrsStr); err != nil { + return err + } + attrs := irc.ParseTags(attrsStr) + + record := &Network{Nick: dc.nick, Enabled: true} + if err := updateNetworkAttrs(record, attrs, subcommand); err != nil { + return err + } + + if record.Nick == dc.user.Username { + record.Nick = "" + } + if record.Realname == dc.user.Realname { + record.Realname = "" + } + + network, err := dc.user.createNetwork(ctx, record) + if err != nil { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"BOUNCER", "UNKNOWN_ERROR", subcommand, fmt.Sprintf("Failed to create network: %v", err)}, + }} + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "BOUNCER", + Params: []string{"ADDNETWORK", fmt.Sprintf("%v", network.ID)}, + }) + case "CHANGENETWORK": + var idStr, attrsStr string + if err := parseMessageParams(msg, nil, &idStr, &attrsStr); err != nil { + return err + } + id, err := parseBouncerNetID(subcommand, idStr) + if err != nil { + return err + } + attrs := irc.ParseTags(attrsStr) + + net := dc.user.getNetworkByID(id) + if net == nil { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"BOUNCER", "INVALID_NETID", subcommand, idStr, "Invalid network ID"}, + }} + } + + record := net.Network // copy network record because we'll mutate it + if err := updateNetworkAttrs(&record, attrs, subcommand); err != nil { + return err + } + + if record.Nick == dc.user.Username { + record.Nick = "" + } + if record.Realname == dc.user.Realname { + record.Realname = "" + } + + _, err = dc.user.updateNetwork(ctx, &record) + if err != nil { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"BOUNCER", "UNKNOWN_ERROR", subcommand, fmt.Sprintf("Failed to update network: %v", err)}, + }} + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "BOUNCER", + Params: []string{"CHANGENETWORK", idStr}, + }) + case "DELNETWORK": + var idStr string + if err := parseMessageParams(msg, nil, &idStr); err != nil { + return err + } + id, err := parseBouncerNetID(subcommand, idStr) + if err != nil { + return err + } + + net := dc.user.getNetworkByID(id) + if net == nil { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"BOUNCER", "INVALID_NETID", subcommand, idStr, "Invalid network ID"}, + }} + } + + if err := dc.user.deleteNetwork(ctx, net.ID); err != nil { + return err + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "BOUNCER", + Params: []string{"DELNETWORK", idStr}, + }) + default: + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"BOUNCER", "UNKNOWN_COMMAND", subcommand, "Unknown subcommand"}, + }} + } + default: + dc.logger.Printf("unhandled message: %v", msg) + + // Only forward unknown commands in single-upstream mode + uc := dc.upstream() + if uc == nil { + return newUnknownCommandError(msg.Command) + } + + uc.SendMessageLabeled(ctx, dc.id, msg) + } + return nil +} + +func (dc *downstreamConn) handleNickServPRIVMSG(ctx context.Context, uc *upstreamConn, text string) { + username, password, ok := parseNickServCredentials(text, uc.nick) + if ok { + uc.network.autoSaveSASLPlain(ctx, username, password) + } +} + +func parseNickServCredentials(text, nick string) (username, password string, ok bool) { + fields := strings.Fields(text) + if len(fields) < 2 { + return "", "", false + } + cmd := strings.ToUpper(fields[0]) + params := fields[1:] + switch cmd { + case "REGISTER": + username = nick + password = params[0] + case "IDENTIFY": + if len(params) == 1 { + username = nick + password = params[0] + } else { + username = params[0] + password = params[1] + } + case "SET": + if len(params) == 2 && strings.EqualFold(params[0], "PASSWORD") { + username = nick + password = params[1] + } + default: + return "", "", false + } + return username, password, true +} diff --git a/branches/master/go.mod b/branches/master/go.mod new file mode 100644 index 0000000..cc08167 --- /dev/null +++ b/branches/master/go.mod @@ -0,0 +1,39 @@ +module marisa.chaotic.ninja/suika + +go 1.20 + +require ( + git.sr.ht/~emersion/go-scfg v0.0.0-20211215104734-c2c7a15d6c99 + git.sr.ht/~sircmpwn/go-bare v0.0.0-20210406120253-ab86bc2846d9 + github.com/emersion/go-sasl v0.0.0-20220912192320-0145f2c60ead + github.com/lib/pq v1.10.7 + golang.org/x/crypto v0.7.0 + golang.org/x/term v0.6.0 + golang.org/x/time v0.3.0 + gopkg.in/irc.v3 v3.1.4 + modernc.org/sqlite v1.21.0 +) + +require ( + github.com/dustin/go-humanize v1.0.0 // indirect + github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect + github.com/google/uuid v1.3.0 // indirect + github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect + github.com/mattn/go-isatty v0.0.16 // indirect + github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + github.com/stretchr/testify v1.8.0 // indirect + golang.org/x/mod v0.3.0 // indirect + golang.org/x/sys v0.6.0 // indirect + golang.org/x/tools v0.0.0-20201124115921-2c860bdd6e78 // indirect + golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect + lukechampine.com/uint128 v1.2.0 // indirect + modernc.org/cc/v3 v3.40.0 // indirect + modernc.org/ccgo/v3 v3.16.13 // indirect + modernc.org/libc v1.22.3 // indirect + modernc.org/mathutil v1.5.0 // indirect + modernc.org/memory v1.5.0 // indirect + modernc.org/opt v0.1.3 // indirect + modernc.org/strutil v1.1.3 // indirect + modernc.org/token v1.0.1 // indirect +) diff --git a/branches/master/go.sum b/branches/master/go.sum new file mode 100644 index 0000000..090246d --- /dev/null +++ b/branches/master/go.sum @@ -0,0 +1,105 @@ +git.sr.ht/~emersion/go-scfg v0.0.0-20211215104734-c2c7a15d6c99 h1:1s8n5uisqkR+BzPgaum6xxIjKmzGrTykJdh+Y3f5Xao= +git.sr.ht/~emersion/go-scfg v0.0.0-20211215104734-c2c7a15d6c99/go.mod h1:t+Ww6SR24yYnXzEWiNlOY0AFo5E9B73X++10lrSpp4U= +git.sr.ht/~sircmpwn/getopt v0.0.0-20191230200459-23622cc906b3/go.mod h1:wMEGFFFNuPos7vHmWXfszqImLppbc0wEhh6JBfJIUgw= +git.sr.ht/~sircmpwn/go-bare v0.0.0-20210406120253-ab86bc2846d9 h1:Ahny8Ud1LjVMMAlt8utUFKhhxJtwBAualvsbc/Sk7cE= +git.sr.ht/~sircmpwn/go-bare v0.0.0-20210406120253-ab86bc2846d9/go.mod h1:BVJwbDfVjCjoFiKrhkei6NdGcZYpkDkdyCdg1ukytRA= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo= +github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= +github.com/emersion/go-sasl v0.0.0-20220912192320-0145f2c60ead h1:fI1Jck0vUrXT8bnphprS1EoVRe2Q5CKCX8iDlpqjQ/Y= +github.com/emersion/go-sasl v0.0.0-20220912192320-0145f2c60ead/go.mod h1:iL2twTeMvZnrg54ZoPDNfJaJaqy0xIQFuBdrLsmspwQ= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ= +github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4= +github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs= +github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= +github.com/lib/pq v1.10.7 h1:p7ZhMD+KsSRozJr34udlUrhboJwWAgCg34+/ZZNvZZw= +github.com/lib/pq v1.10.7/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/mattn/go-isatty v0.0.16 h1:bq3VjFmv/sOjHtdEhmkEV4x1AJtvUvOJ2PFAZ5+peKQ= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A= +golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= +golang.org/x/mod v0.3.0 h1:RM4zey1++hCTbCVQfnWeKs9/IEsaBLA8vTkd0WVtmH4= +golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.6.0 h1:clScbb1cHjoCkyRbWwBEUZ5H/tIFu5TAXIqaZD0Gcjw= +golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= +golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20201124115921-2c860bdd6e78 h1:M8tBwCtWD/cZV9DZpFYRUgaymAYAr+aIUTWzDaM3uPs= +golang.org/x/tools v0.0.0-20201124115921-2c860bdd6e78/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/irc.v3 v3.1.4 h1:DYGMRFbtseXEh+NadmMUFzMraqyuUj4I3iWYFEzDZPc= +gopkg.in/irc.v3 v3.1.4/go.mod h1:shO2gz8+PVeS+4E6GAny88Z0YVVQSxQghdrMVGQsR9s= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +lukechampine.com/uint128 v1.2.0 h1:mBi/5l91vocEN8otkC5bDLhi2KdCticRiwbdB0O+rjI= +lukechampine.com/uint128 v1.2.0/go.mod h1:c4eWIwlEGaxC/+H1VguhU4PHXNWDCDMUlWdIWl2j1gk= +modernc.org/cc/v3 v3.40.0 h1:P3g79IUS/93SYhtoeaHW+kRCIrYaxJ27MFPv+7kaTOw= +modernc.org/cc/v3 v3.40.0/go.mod h1:/bTg4dnWkSXowUO6ssQKnOV0yMVxDYNIsIrzqTFDGH0= +modernc.org/ccgo/v3 v3.16.13 h1:Mkgdzl46i5F/CNR/Kj80Ri59hC8TKAhZrYSaqvkwzUw= +modernc.org/ccgo/v3 v3.16.13/go.mod h1:2Quk+5YgpImhPjv2Qsob1DnZ/4som1lJTodubIcoUkY= +modernc.org/ccorpus v1.11.6 h1:J16RXiiqiCgua6+ZvQot4yUuUy8zxgqbqEEUuGPlISk= +modernc.org/httpfs v1.0.6 h1:AAgIpFZRXuYnkjftxTAZwMIiwEqAfk8aVB2/oA6nAeM= +modernc.org/libc v1.22.3 h1:D/g6O5ftAfavceqlLOFwaZuA5KYafKwmr30A6iSqoyY= +modernc.org/libc v1.22.3/go.mod h1:MQrloYP209xa2zHome2a8HLiLm6k0UT8CoHpV74tOFw= +modernc.org/mathutil v1.5.0 h1:rV0Ko/6SfM+8G+yKiyI830l3Wuz1zRutdslNoQ0kfiQ= +modernc.org/mathutil v1.5.0/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E= +modernc.org/memory v1.5.0 h1:N+/8c5rE6EqugZwHii4IFsaJ7MUhoWX07J5tC/iI5Ds= +modernc.org/memory v1.5.0/go.mod h1:PkUhL0Mugw21sHPeskwZW4D6VscE/GQJOnIpCnW6pSU= +modernc.org/opt v0.1.3 h1:3XOZf2yznlhC+ibLltsDGzABUGVx8J6pnFMS3E4dcq4= +modernc.org/opt v0.1.3/go.mod h1:WdSiB5evDcignE70guQKxYUl14mgWtbClRi5wmkkTX0= +modernc.org/sqlite v1.21.0 h1:4aP4MdUf15i3R3M2mx6Q90WHKz3nZLoz96zlB6tNdow= +modernc.org/sqlite v1.21.0/go.mod h1:XwQ0wZPIh1iKb5mkvCJ3szzbhk+tykC8ZWqTRTgYRwI= +modernc.org/strutil v1.1.3 h1:fNMm+oJklMGYfU9Ylcywl0CO5O6nTfaowNsh2wpPjzY= +modernc.org/strutil v1.1.3/go.mod h1:MEHNA7PdEnEwLvspRMtWTNnp2nnyvMfkimT1NKNAGbw= +modernc.org/tcl v1.15.1 h1:mOQwiEK4p7HruMZcwKTZPw/aqtGM4aY00uzWhlKKYws= +modernc.org/token v1.0.1 h1:A3qvTqOwexpfZZeyI0FeGPDlSWX5pjZu9hF4lU+EKWg= +modernc.org/token v1.0.1/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= +modernc.org/z v1.7.0 h1:xkDw/KepgEjeizO2sNco+hqYkU12taxQFqPEmgm1GWE= diff --git a/branches/master/irc.go b/branches/master/irc.go new file mode 100644 index 0000000..1799e32 --- /dev/null +++ b/branches/master/irc.go @@ -0,0 +1,811 @@ +package suika + +import ( + "fmt" + "sort" + "strings" + "time" + "unicode" + "unicode/utf8" + + "gopkg.in/irc.v3" +) + +const ( + rpl_statsping = "246" + rpl_localusers = "265" + rpl_globalusers = "266" + rpl_creationtime = "329" + rpl_topicwhotime = "333" + rpl_whospcrpl = "354" + rpl_whoisaccount = "330" + err_invalidcapcmd = "410" +) + +const ( + maxMessageLength = 512 + maxMessageParams = 15 + maxSASLLength = 400 +) + +// The server-time layout, as defined in the IRCv3 spec. +const serverTimeLayout = "2006-01-02T15:04:05.000Z" + +func formatServerTime(t time.Time) string { + return t.UTC().Format(serverTimeLayout) +} + +type userModes string + +func (ms userModes) Has(c byte) bool { + return strings.IndexByte(string(ms), c) >= 0 +} + +func (ms *userModes) Add(c byte) { + if !ms.Has(c) { + *ms += userModes(c) + } +} + +func (ms *userModes) Del(c byte) { + i := strings.IndexByte(string(*ms), c) + if i >= 0 { + *ms = (*ms)[:i] + (*ms)[i+1:] + } +} + +func (ms *userModes) Apply(s string) error { + var plusMinus byte + for i := 0; i < len(s); i++ { + switch c := s[i]; c { + case '+', '-': + plusMinus = c + default: + switch plusMinus { + case '+': + ms.Add(c) + case '-': + ms.Del(c) + default: + return fmt.Errorf("malformed modestring %q: missing plus/minus", s) + } + } + } + return nil +} + +type channelModeType byte + +// standard channel mode types, as explained in https://modern.ircdocs.horse/#mode-message +const ( + // modes that add or remove an address to or from a list + modeTypeA channelModeType = iota + // modes that change a setting on a channel, and must always have a parameter + modeTypeB + // modes that change a setting on a channel, and must have a parameter when being set, and no parameter when being unset + modeTypeC + // modes that change a setting on a channel, and must not have a parameter + modeTypeD +) + +var stdChannelModes = map[byte]channelModeType{ + 'b': modeTypeA, // ban list + 'e': modeTypeA, // ban exception list + 'I': modeTypeA, // invite exception list + 'k': modeTypeB, // channel key + 'l': modeTypeC, // channel user limit + 'i': modeTypeD, // channel is invite-only + 'm': modeTypeD, // channel is moderated + 'n': modeTypeD, // channel has no external messages + 's': modeTypeD, // channel is secret + 't': modeTypeD, // channel has protected topic +} + +type channelModes map[byte]string + +// applyChannelModes parses a mode string and mode arguments from a MODE message, +// and applies the corresponding channel mode and user membership changes on that channel. +// +// If ch.modes is nil, channel modes are not updated. +// +// needMarshaling is a list of indexes of mode arguments that represent entities +// that must be marshaled when sent downstream. +func applyChannelModes(ch *upstreamChannel, modeStr string, arguments []string) (needMarshaling map[int]struct{}, err error) { + needMarshaling = make(map[int]struct{}, len(arguments)) + nextArgument := 0 + var plusMinus byte +outer: + for i := 0; i < len(modeStr); i++ { + mode := modeStr[i] + if mode == '+' || mode == '-' { + plusMinus = mode + continue + } + if plusMinus != '+' && plusMinus != '-' { + return nil, fmt.Errorf("malformed modestring %q: missing plus/minus", modeStr) + } + + for _, membership := range ch.conn.availableMemberships { + if membership.Mode == mode { + if nextArgument >= len(arguments) { + return nil, fmt.Errorf("malformed modestring %q: missing mode argument for %c%c", modeStr, plusMinus, mode) + } + member := arguments[nextArgument] + m := ch.Members.Value(member) + if m != nil { + if plusMinus == '+' { + m.Add(ch.conn.availableMemberships, membership) + } else { + // TODO: for upstreams without multi-prefix, query the user modes again + m.Remove(membership) + } + } + needMarshaling[nextArgument] = struct{}{} + nextArgument++ + continue outer + } + } + + mt, ok := ch.conn.availableChannelModes[mode] + if !ok { + continue + } + if mt == modeTypeA { + nextArgument++ + } else if mt == modeTypeB || (mt == modeTypeC && plusMinus == '+') { + if plusMinus == '+' { + var argument string + // some sentitive arguments (such as channel keys) can be omitted for privacy + // (this will only happen for RPL_CHANNELMODEIS, never for MODE messages) + if nextArgument < len(arguments) { + argument = arguments[nextArgument] + } + if ch.modes != nil { + ch.modes[mode] = argument + } + } else { + delete(ch.modes, mode) + } + nextArgument++ + } else if mt == modeTypeC || mt == modeTypeD { + if plusMinus == '+' { + if ch.modes != nil { + ch.modes[mode] = "" + } + } else { + delete(ch.modes, mode) + } + } + } + return needMarshaling, nil +} + +func (cm channelModes) Format() (modeString string, parameters []string) { + var modesWithValues strings.Builder + var modesWithoutValues strings.Builder + parameters = make([]string, 0, 16) + for mode, value := range cm { + if value != "" { + modesWithValues.WriteString(string(mode)) + parameters = append(parameters, value) + } else { + modesWithoutValues.WriteString(string(mode)) + } + } + modeString = "+" + modesWithValues.String() + modesWithoutValues.String() + return +} + +const stdChannelTypes = "#&+!" + +type channelStatus byte + +const ( + channelPublic channelStatus = '=' + channelSecret channelStatus = '@' + channelPrivate channelStatus = '*' +) + +func parseChannelStatus(s string) (channelStatus, error) { + if len(s) > 1 { + return 0, fmt.Errorf("invalid channel status %q: more than one character", s) + } + switch cs := channelStatus(s[0]); cs { + case channelPublic, channelSecret, channelPrivate: + return cs, nil + default: + return 0, fmt.Errorf("invalid channel status %q: unknown status", s) + } +} + +type membership struct { + Mode byte + Prefix byte +} + +var stdMemberships = []membership{ + {'q', '~'}, // founder + {'a', '&'}, // protected + {'o', '@'}, // operator + {'h', '%'}, // halfop + {'v', '+'}, // voice +} + +// memberships always sorted by descending membership rank +type memberships []membership + +func (m *memberships) Add(availableMemberships []membership, newMembership membership) { + l := *m + i := 0 + for _, availableMembership := range availableMemberships { + if i >= len(l) { + break + } + if l[i] == availableMembership { + if availableMembership == newMembership { + // we already have this membership + return + } + i++ + continue + } + if availableMembership == newMembership { + break + } + } + // insert newMembership at i + l = append(l, membership{}) + copy(l[i+1:], l[i:]) + l[i] = newMembership + *m = l +} + +func (m *memberships) Remove(oldMembership membership) { + l := *m + for i, currentMembership := range l { + if currentMembership == oldMembership { + *m = append(l[:i], l[i+1:]...) + return + } + } +} + +func (m memberships) Format(dc *downstreamConn) string { + if !dc.caps["multi-prefix"] { + if len(m) == 0 { + return "" + } + return string(m[0].Prefix) + } + prefixes := make([]byte, len(m)) + for i, membership := range m { + prefixes[i] = membership.Prefix + } + return string(prefixes) +} + +func parseMessageParams(msg *irc.Message, out ...*string) error { + if len(msg.Params) < len(out) { + return newNeedMoreParamsError(msg.Command) + } + for i := range out { + if out[i] != nil { + *out[i] = msg.Params[i] + } + } + return nil +} + +func copyClientTags(tags irc.Tags) irc.Tags { + t := make(irc.Tags, len(tags)) + for k, v := range tags { + if strings.HasPrefix(k, "+") { + t[k] = v + } + } + return t +} + +type batch struct { + Type string + Params []string + Outer *batch // if not-nil, this batch is nested in Outer + Label string +} + +func join(channels, keys []string) []*irc.Message { + // Put channels with a key first + js := joinSorter{channels, keys} + sort.Sort(&js) + + // Two spaces because there are three words (JOIN, channels and keys) + maxLength := maxMessageLength - (len("JOIN") + 2) + + var msgs []*irc.Message + var channelsBuf, keysBuf strings.Builder + for i, channel := range channels { + key := keys[i] + + n := channelsBuf.Len() + keysBuf.Len() + 1 + len(channel) + if key != "" { + n += 1 + len(key) + } + + if channelsBuf.Len() > 0 && n > maxLength { + // No room for the new channel in this message + params := []string{channelsBuf.String()} + if keysBuf.Len() > 0 { + params = append(params, keysBuf.String()) + } + msgs = append(msgs, &irc.Message{Command: "JOIN", Params: params}) + channelsBuf.Reset() + keysBuf.Reset() + } + + if channelsBuf.Len() > 0 { + channelsBuf.WriteByte(',') + } + channelsBuf.WriteString(channel) + if key != "" { + if keysBuf.Len() > 0 { + keysBuf.WriteByte(',') + } + keysBuf.WriteString(key) + } + } + if channelsBuf.Len() > 0 { + params := []string{channelsBuf.String()} + if keysBuf.Len() > 0 { + params = append(params, keysBuf.String()) + } + msgs = append(msgs, &irc.Message{Command: "JOIN", Params: params}) + } + + return msgs +} + +func generateIsupport(prefix *irc.Prefix, nick string, tokens []string) []*irc.Message { + maxTokens := maxMessageParams - 2 // 2 reserved params: nick + text + + var msgs []*irc.Message + for len(tokens) > 0 { + var msgTokens []string + if len(tokens) > maxTokens { + msgTokens = tokens[:maxTokens] + tokens = tokens[maxTokens:] + } else { + msgTokens = tokens + tokens = nil + } + + msgs = append(msgs, &irc.Message{ + Prefix: prefix, + Command: irc.RPL_ISUPPORT, + Params: append(append([]string{nick}, msgTokens...), "are supported"), + }) + } + + return msgs +} + +func generateMOTD(prefix *irc.Prefix, nick string, motd string) []*irc.Message { + var msgs []*irc.Message + msgs = append(msgs, &irc.Message{ + Prefix: prefix, + Command: irc.RPL_MOTDSTART, + Params: []string{nick, fmt.Sprintf("- Message of the Day -")}, + }) + + for _, l := range strings.Split(motd, "\n") { + msgs = append(msgs, &irc.Message{ + Prefix: prefix, + Command: irc.RPL_MOTD, + Params: []string{nick, l}, + }) + } + + msgs = append(msgs, &irc.Message{ + Prefix: prefix, + Command: irc.RPL_ENDOFMOTD, + Params: []string{nick, "End of /MOTD command."}, + }) + + return msgs +} + +func generateMonitor(subcmd string, targets []string) []*irc.Message { + maxLength := maxMessageLength - len("MONITOR "+subcmd+" ") + + var msgs []*irc.Message + var buf []string + n := 0 + for _, target := range targets { + if n+len(target)+1 > maxLength { + msgs = append(msgs, &irc.Message{ + Command: "MONITOR", + Params: []string{subcmd, strings.Join(buf, ",")}, + }) + buf = buf[:0] + n = 0 + } + + buf = append(buf, target) + n += len(target) + 1 + } + + if len(buf) > 0 { + msgs = append(msgs, &irc.Message{ + Command: "MONITOR", + Params: []string{subcmd, strings.Join(buf, ",")}, + }) + } + + return msgs +} + +type joinSorter struct { + channels []string + keys []string +} + +func (js *joinSorter) Len() int { + return len(js.channels) +} + +func (js *joinSorter) Less(i, j int) bool { + if (js.keys[i] != "") != (js.keys[j] != "") { + // Only one of the channels has a key + return js.keys[i] != "" + } + return js.channels[i] < js.channels[j] +} + +func (js *joinSorter) Swap(i, j int) { + js.channels[i], js.channels[j] = js.channels[j], js.channels[i] + js.keys[i], js.keys[j] = js.keys[j], js.keys[i] +} + +// parseCTCPMessage parses a CTCP message. CTCP is defined in +// https://tools.ietf.org/html/draft-oakley-irc-ctcp-02 +func parseCTCPMessage(msg *irc.Message) (cmd string, params string, ok bool) { + if (msg.Command != "PRIVMSG" && msg.Command != "NOTICE") || len(msg.Params) < 2 { + return "", "", false + } + text := msg.Params[1] + + if !strings.HasPrefix(text, "\x01") { + return "", "", false + } + text = strings.Trim(text, "\x01") + + words := strings.SplitN(text, " ", 2) + cmd = strings.ToUpper(words[0]) + if len(words) > 1 { + params = words[1] + } + + return cmd, params, true +} + +type casemapping func(string) string + +func casemapNone(name string) string { + return name +} + +// CasemapASCII of name is the canonical representation of name according to the +// ascii casemapping. +func casemapASCII(name string) string { + nameBytes := []byte(name) + for i, r := range nameBytes { + if 'A' <= r && r <= 'Z' { + nameBytes[i] = r + 'a' - 'A' + } + } + return string(nameBytes) +} + +// casemapRFC1459 of name is the canonical representation of name according to the +// rfc1459 casemapping. +func casemapRFC1459(name string) string { + nameBytes := []byte(name) + for i, r := range nameBytes { + if 'A' <= r && r <= 'Z' { + nameBytes[i] = r + 'a' - 'A' + } else if r == '{' { + nameBytes[i] = '[' + } else if r == '}' { + nameBytes[i] = ']' + } else if r == '\\' { + nameBytes[i] = '|' + } else if r == '~' { + nameBytes[i] = '^' + } + } + return string(nameBytes) +} + +// casemapRFC1459Strict of name is the canonical representation of name +// according to the rfc1459-strict casemapping. +func casemapRFC1459Strict(name string) string { + nameBytes := []byte(name) + for i, r := range nameBytes { + if 'A' <= r && r <= 'Z' { + nameBytes[i] = r + 'a' - 'A' + } else if r == '{' { + nameBytes[i] = '[' + } else if r == '}' { + nameBytes[i] = ']' + } else if r == '\\' { + nameBytes[i] = '|' + } + } + return string(nameBytes) +} + +func parseCasemappingToken(tokenValue string) (casemap casemapping, ok bool) { + switch tokenValue { + case "ascii": + casemap = casemapASCII + case "rfc1459": + casemap = casemapRFC1459 + case "rfc1459-strict": + casemap = casemapRFC1459Strict + default: + return nil, false + } + return casemap, true +} + +func partialCasemap(higher casemapping, name string) string { + nameFullyCM := []byte(higher(name)) + nameBytes := []byte(name) + for i, r := range nameBytes { + if !('A' <= r && r <= 'Z') && !('a' <= r && r <= 'z') { + nameBytes[i] = nameFullyCM[i] + } + } + return string(nameBytes) +} + +type casemapMap struct { + innerMap map[string]casemapEntry + casemap casemapping +} + +type casemapEntry struct { + originalKey string + value interface{} +} + +func newCasemapMap(size int) casemapMap { + return casemapMap{ + innerMap: make(map[string]casemapEntry, size), + casemap: casemapNone, + } +} + +func (cm *casemapMap) OriginalKey(name string) (key string, ok bool) { + entry, ok := cm.innerMap[cm.casemap(name)] + if !ok { + return "", false + } + return entry.originalKey, true +} + +func (cm *casemapMap) Has(name string) bool { + _, ok := cm.innerMap[cm.casemap(name)] + return ok +} + +func (cm *casemapMap) Len() int { + return len(cm.innerMap) +} + +func (cm *casemapMap) SetValue(name string, value interface{}) { + nameCM := cm.casemap(name) + entry, ok := cm.innerMap[nameCM] + if !ok { + cm.innerMap[nameCM] = casemapEntry{ + originalKey: name, + value: value, + } + return + } + entry.value = value + cm.innerMap[nameCM] = entry +} + +func (cm *casemapMap) Delete(name string) { + delete(cm.innerMap, cm.casemap(name)) +} + +func (cm *casemapMap) SetCasemapping(newCasemap casemapping) { + cm.casemap = newCasemap + newInnerMap := make(map[string]casemapEntry, len(cm.innerMap)) + for _, entry := range cm.innerMap { + newInnerMap[cm.casemap(entry.originalKey)] = entry + } + cm.innerMap = newInnerMap +} + +type upstreamChannelCasemapMap struct{ casemapMap } + +func (cm *upstreamChannelCasemapMap) Value(name string) *upstreamChannel { + entry, ok := cm.innerMap[cm.casemap(name)] + if !ok { + return nil + } + return entry.value.(*upstreamChannel) +} + +type channelCasemapMap struct{ casemapMap } + +func (cm *channelCasemapMap) Value(name string) *Channel { + entry, ok := cm.innerMap[cm.casemap(name)] + if !ok { + return nil + } + return entry.value.(*Channel) +} + +type membershipsCasemapMap struct{ casemapMap } + +func (cm *membershipsCasemapMap) Value(name string) *memberships { + entry, ok := cm.innerMap[cm.casemap(name)] + if !ok { + return nil + } + return entry.value.(*memberships) +} + +type deliveredCasemapMap struct{ casemapMap } + +func (cm *deliveredCasemapMap) Value(name string) deliveredClientMap { + entry, ok := cm.innerMap[cm.casemap(name)] + if !ok { + return nil + } + return entry.value.(deliveredClientMap) +} + +type monitorCasemapMap struct{ casemapMap } + +func (cm *monitorCasemapMap) Value(name string) (online bool) { + entry, ok := cm.innerMap[cm.casemap(name)] + if !ok { + return false + } + return entry.value.(bool) +} + +func isWordBoundary(r rune) bool { + switch r { + case '-', '_', '|': // inspired from weechat.look.highlight_regex + return false + default: + return !unicode.IsLetter(r) && !unicode.IsNumber(r) + } +} + +func isHighlight(text, nick string) bool { + for { + i := strings.Index(text, nick) + if i < 0 { + return false + } + + left, _ := utf8.DecodeLastRuneInString(text[:i]) + right, _ := utf8.DecodeRuneInString(text[i+len(nick):]) + if isWordBoundary(left) && isWordBoundary(right) { + return true + } + + text = text[i+len(nick):] + } +} + +// parseChatHistoryBound parses the given CHATHISTORY parameter as a bound. +// The zero time is returned on error. +func parseChatHistoryBound(param string) time.Time { + parts := strings.SplitN(param, "=", 2) + if len(parts) != 2 { + return time.Time{} + } + switch parts[0] { + case "timestamp": + timestamp, err := time.Parse(serverTimeLayout, parts[1]) + if err != nil { + return time.Time{} + } + return timestamp + default: + return time.Time{} + } +} + +// whoxFields is the list of all WHOX field letters, by order of appearance in +// RPL_WHOSPCRPL messages. +var whoxFields = []byte("tcuihsnfdlaor") + +type whoxInfo struct { + Token string + Username string + Hostname string + Server string + Nickname string + Flags string + Account string + Realname string +} + +func (info *whoxInfo) get(field byte) string { + switch field { + case 't': + return info.Token + case 'c': + return "*" + case 'u': + return info.Username + case 'i': + return "255.255.255.255" + case 'h': + return info.Hostname + case 's': + return info.Server + case 'n': + return info.Nickname + case 'f': + return info.Flags + case 'd': + return "0" + case 'l': // idle time + return "0" + case 'a': + account := "0" // WHOX uses "0" to mean "no account" + if info.Account != "" && info.Account != "*" { + account = info.Account + } + return account + case 'o': + return "0" + case 'r': + return info.Realname + } + return "" +} + +func generateWHOXReply(prefix *irc.Prefix, nick, fields string, info *whoxInfo) *irc.Message { + if fields == "" { + return &irc.Message{ + Prefix: prefix, + Command: irc.RPL_WHOREPLY, + Params: []string{nick, "*", info.Username, info.Hostname, info.Server, info.Nickname, info.Flags, "0 " + info.Realname}, + } + } + + fieldSet := make(map[byte]bool) + for i := 0; i < len(fields); i++ { + fieldSet[fields[i]] = true + } + + var values []string + for _, field := range whoxFields { + if !fieldSet[field] { + continue + } + values = append(values, info.get(field)) + } + + return &irc.Message{ + Prefix: prefix, + Command: rpl_whospcrpl, + Params: append([]string{nick}, values...), + } +} + +var isupportEncoder = strings.NewReplacer(" ", "\\x20", "\\", "\\x5C") + +func encodeISUPPORT(s string) string { + return isupportEncoder.Replace(s) +} diff --git a/branches/master/irc_test.go b/branches/master/irc_test.go new file mode 100644 index 0000000..8757bc8 --- /dev/null +++ b/branches/master/irc_test.go @@ -0,0 +1,34 @@ +package suika + +import ( + "testing" +) + +func TestIsHighlight(t *testing.T) { + nick := "SojuUser" + testCases := []struct { + name string + text string + hl bool + }{ + {"noContains", "hi there Soju User!", false}, + {"middle", "hi there SojuUser!", true}, + {"start", "SojuUser: how are you doing?", true}, + {"end", "maybe ask SojuUser", true}, + {"inWord", "but OtherSojuUserSan is a different nick", false}, + {"startWord", "and OtherSojuUser is another different nick", false}, + {"endWord", "and SojuUserSan is yet a different nick", false}, + {"underscore", "and SojuUser_san has nothing to do with me", false}, + {"zeroWidthSpace", "writing S\u200BojuUser shouldn't trigger a highlight", false}, + } + + for _, tc := range testCases { + tc := tc // capture range variable + t.Run(tc.name, func(t *testing.T) { + hl := isHighlight(tc.text, nick) + if hl != tc.hl { + t.Errorf("isHighlight(%q, %q) = %v, but want %v", tc.text, nick, hl, tc.hl) + } + }) + } +} diff --git a/branches/master/msgstore.go b/branches/master/msgstore.go new file mode 100644 index 0000000..ed57429 --- /dev/null +++ b/branches/master/msgstore.go @@ -0,0 +1,123 @@ +package suika + +import ( + "bytes" + "context" + "encoding/base64" + "fmt" + "time" + + "git.sr.ht/~sircmpwn/go-bare" + "gopkg.in/irc.v3" +) + +// messageStore is a per-user store for IRC messages. +type messageStore interface { + Close() error + // LastMsgID queries the last message ID for the given network, entity and + // date. The message ID returned may not refer to a valid message, but can be + // used in history queries. + LastMsgID(network *Network, entity string, t time.Time) (string, error) + // LoadLatestID queries the latest non-event messages for the given network, + // entity and date, up to a count of limit messages, sorted from oldest to newest. + LoadLatestID(ctx context.Context, network *Network, entity, id string, limit int) ([]*irc.Message, error) + Append(network *Network, entity string, msg *irc.Message) (id string, err error) +} + +type chatHistoryTarget struct { + Name string + LatestMessage time.Time +} + +// chatHistoryMessageStore is a message store that supports chat history +// operations. +type chatHistoryMessageStore interface { + messageStore + + // ListTargets lists channels and nicknames by time of the latest message. + // It returns up to limit targets, starting from start and ending on end, + // both excluded. end may be before or after start. + // If events is false, only PRIVMSG/NOTICE messages are considered. + ListTargets(ctx context.Context, network *Network, start, end time.Time, limit int, events bool) ([]chatHistoryTarget, error) + // LoadBeforeTime loads up to limit messages before start down to end. The + // returned messages must be between and excluding the provided bounds. + // end is before start. + // If events is false, only PRIVMSG/NOTICE messages are considered. + LoadBeforeTime(ctx context.Context, network *Network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error) + // LoadBeforeTime loads up to limit messages after start up to end. The + // returned messages must be between and excluding the provided bounds. + // end is after start. + // If events is false, only PRIVMSG/NOTICE messages are considered. + LoadAfterTime(ctx context.Context, network *Network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error) +} + +type msgIDType uint + +const ( + msgIDNone msgIDType = iota + msgIDMemory + msgIDFS +) + +const msgIDVersion uint = 0 + +type msgIDHeader struct { + Version uint + Network bare.Int + Target string + Type msgIDType +} + +type msgIDBody interface { + msgIDType() msgIDType +} + +func formatMsgID(netID int64, target string, body msgIDBody) string { + var buf bytes.Buffer + w := bare.NewWriter(&buf) + + header := msgIDHeader{ + Version: msgIDVersion, + Network: bare.Int(netID), + Target: target, + Type: body.msgIDType(), + } + if err := bare.MarshalWriter(w, &header); err != nil { + panic(err) + } + if err := bare.MarshalWriter(w, body); err != nil { + panic(err) + } + return base64.RawURLEncoding.EncodeToString(buf.Bytes()) +} + +func parseMsgID(s string, body msgIDBody) (netID int64, target string, err error) { + b, err := base64.RawURLEncoding.DecodeString(s) + if err != nil { + return 0, "", fmt.Errorf("invalid internal message ID: %v", err) + } + + r := bare.NewReader(bytes.NewReader(b)) + + var header msgIDHeader + if err := bare.UnmarshalBareReader(r, &header); err != nil { + return 0, "", fmt.Errorf("invalid internal message ID: %v", err) + } + + if header.Version != msgIDVersion { + return 0, "", fmt.Errorf("invalid internal message ID: got version %v, want %v", header.Version, msgIDVersion) + } + + if body != nil { + typ := body.msgIDType() + if header.Type != typ { + return 0, "", fmt.Errorf("invalid internal message ID: got type %v, want %v", header.Type, typ) + } + + if err := bare.UnmarshalBareReader(r, body); err != nil { + return 0, "", fmt.Errorf("invalid internal message ID: %v", err) + } + } + + return int64(header.Network), header.Target, nil +} diff --git a/branches/master/msgstore_fs.go b/branches/master/msgstore_fs.go new file mode 100644 index 0000000..90bd76a --- /dev/null +++ b/branches/master/msgstore_fs.go @@ -0,0 +1,698 @@ +package suika + +import ( + "bufio" + "context" + "fmt" + "io" + "os" + "path/filepath" + "sort" + "strings" + "time" + + "git.sr.ht/~sircmpwn/go-bare" + "gopkg.in/irc.v3" +) + +const ( + fsMessageStoreMaxFiles = 20 + fsMessageStoreMaxTries = 100 +) + +func escapeFilename(unsafe string) (safe string) { + if unsafe == "." { + return "-" + } else if unsafe == ".." { + return "--" + } else { + return strings.NewReplacer("/", "-", "\\", "-").Replace(unsafe) + } +} + +type date struct { + Year, Month, Day int +} + +func newDate(t time.Time) date { + year, month, day := t.Date() + return date{year, int(month), day} +} + +func (d date) Time() time.Time { + return time.Date(d.Year, time.Month(d.Month), d.Day, 0, 0, 0, 0, time.Local) +} + +type fsMsgID struct { + Date date + Offset bare.Int +} + +func (fsMsgID) msgIDType() msgIDType { + return msgIDFS +} + +func parseFSMsgID(s string) (netID int64, entity string, t time.Time, offset int64, err error) { + var id fsMsgID + netID, entity, err = parseMsgID(s, &id) + if err != nil { + return 0, "", time.Time{}, 0, err + } + return netID, entity, id.Date.Time(), int64(id.Offset), nil +} + +func formatFSMsgID(netID int64, entity string, t time.Time, offset int64) string { + id := fsMsgID{ + Date: newDate(t), + Offset: bare.Int(offset), + } + return formatMsgID(netID, entity, &id) +} + +type fsMessageStoreFile struct { + *os.File + lastUse time.Time +} + +// fsMessageStore is a per-user on-disk store for IRC messages. +// +// It mimicks the ZNC log layout and format. See the ZNC source: +// https://github.com/znc/znc/blob/master/modules/log.cpp +type fsMessageStore struct { + root string + user *User + + // Write-only files used by Append + files map[string]*fsMessageStoreFile // indexed by entity +} + +var _ messageStore = (*fsMessageStore)(nil) +var _ chatHistoryMessageStore = (*fsMessageStore)(nil) + +func newFSMessageStore(root string, user *User) *fsMessageStore { + return &fsMessageStore{ + root: filepath.Join(root, escapeFilename(user.Username)), + user: user, + files: make(map[string]*fsMessageStoreFile), + } +} + +func (ms *fsMessageStore) logPath(network *Network, entity string, t time.Time) string { + year, month, day := t.Date() + filename := fmt.Sprintf("%04d-%02d-%02d.log", year, month, day) + return filepath.Join(ms.root, escapeFilename(network.GetName()), escapeFilename(entity), filename) +} + +// nextMsgID queries the message ID for the next message to be written to f. +func nextFSMsgID(network *Network, entity string, t time.Time, f *os.File) (string, error) { + offset, err := f.Seek(0, io.SeekEnd) + if err != nil { + return "", fmt.Errorf("failed to query next FS message ID: %v", err) + } + return formatFSMsgID(network.ID, entity, t, offset), nil +} + +func (ms *fsMessageStore) LastMsgID(network *Network, entity string, t time.Time) (string, error) { + p := ms.logPath(network, entity, t) + fi, err := os.Stat(p) + if os.IsNotExist(err) { + return formatFSMsgID(network.ID, entity, t, -1), nil + } else if err != nil { + return "", fmt.Errorf("failed to query last FS message ID: %v", err) + } + return formatFSMsgID(network.ID, entity, t, fi.Size()-1), nil +} + +func (ms *fsMessageStore) Append(network *Network, entity string, msg *irc.Message) (string, error) { + s := formatMessage(msg) + if s == "" { + return "", nil + } + + var t time.Time + if tag, ok := msg.Tags["time"]; ok { + var err error + t, err = time.Parse(serverTimeLayout, string(tag)) + if err != nil { + return "", fmt.Errorf("failed to parse message time tag: %v", err) + } + t = t.In(time.Local) + } else { + t = time.Now() + } + + f := ms.files[entity] + + // TODO: handle non-monotonic clock behaviour + path := ms.logPath(network, entity, t) + if f == nil || f.Name() != path { + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0750); err != nil { + return "", fmt.Errorf("failed to create message logs directory %q: %v", dir, err) + } + + ff, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0640) + if err != nil { + return "", fmt.Errorf("failed to open message log file %q: %v", path, err) + } + + if f != nil { + f.Close() + } + f = &fsMessageStoreFile{File: ff} + ms.files[entity] = f + } + + f.lastUse = time.Now() + + if len(ms.files) > fsMessageStoreMaxFiles { + entities := make([]string, 0, len(ms.files)) + for name := range ms.files { + entities = append(entities, name) + } + sort.Slice(entities, func(i, j int) bool { + a, b := entities[i], entities[j] + return ms.files[a].lastUse.Before(ms.files[b].lastUse) + }) + entities = entities[0 : len(entities)-fsMessageStoreMaxFiles] + for _, name := range entities { + ms.files[name].Close() + delete(ms.files, name) + } + } + + msgID, err := nextFSMsgID(network, entity, t, f.File) + if err != nil { + return "", fmt.Errorf("failed to generate message ID: %v", err) + } + + _, err = fmt.Fprintf(f, "[%02d:%02d:%02d] %s\n", t.Hour(), t.Minute(), t.Second(), s) + if err != nil { + return "", fmt.Errorf("failed to log message to %q: %v", f.Name(), err) + } + + return msgID, nil +} + +func (ms *fsMessageStore) Close() error { + var closeErr error + for _, f := range ms.files { + if err := f.Close(); err != nil { + closeErr = fmt.Errorf("failed to close message store: %v", err) + } + } + return closeErr +} + +// formatMessage formats a message log line. It assumes a well-formed IRC +// message. +func formatMessage(msg *irc.Message) string { + switch strings.ToUpper(msg.Command) { + case "NICK": + return fmt.Sprintf("*** %s is now known as %s", msg.Prefix.Name, msg.Params[0]) + case "JOIN": + return fmt.Sprintf("*** Joins: %s (%s@%s)", msg.Prefix.Name, msg.Prefix.User, msg.Prefix.Host) + case "PART": + var reason string + if len(msg.Params) > 1 { + reason = msg.Params[1] + } + return fmt.Sprintf("*** Parts: %s (%s@%s) (%s)", msg.Prefix.Name, msg.Prefix.User, msg.Prefix.Host, reason) + case "KICK": + nick := msg.Params[1] + var reason string + if len(msg.Params) > 2 { + reason = msg.Params[2] + } + return fmt.Sprintf("*** %s was kicked by %s (%s)", nick, msg.Prefix.Name, reason) + case "QUIT": + var reason string + if len(msg.Params) > 0 { + reason = msg.Params[0] + } + return fmt.Sprintf("*** Quits: %s (%s@%s) (%s)", msg.Prefix.Name, msg.Prefix.User, msg.Prefix.Host, reason) + case "TOPIC": + var topic string + if len(msg.Params) > 1 { + topic = msg.Params[1] + } + return fmt.Sprintf("*** %s changes topic to '%s'", msg.Prefix.Name, topic) + case "MODE": + return fmt.Sprintf("*** %s sets mode: %s", msg.Prefix.Name, strings.Join(msg.Params[1:], " ")) + case "NOTICE": + return fmt.Sprintf("-%s- %s", msg.Prefix.Name, msg.Params[1]) + case "PRIVMSG": + if cmd, params, ok := parseCTCPMessage(msg); ok && cmd == "ACTION" { + return fmt.Sprintf("* %s %s", msg.Prefix.Name, params) + } else { + return fmt.Sprintf("<%s> %s", msg.Prefix.Name, msg.Params[1]) + } + default: + return "" + } +} + +func (ms *fsMessageStore) parseMessage(line string, network *Network, entity string, ref time.Time, events bool) (*irc.Message, time.Time, error) { + var hour, minute, second int + _, err := fmt.Sscanf(line, "[%02d:%02d:%02d] ", &hour, &minute, &second) + if err != nil { + return nil, time.Time{}, fmt.Errorf("malformed timestamp prefix: %v", err) + } + line = line[11:] + + var cmd string + var prefix *irc.Prefix + var params []string + if events && strings.HasPrefix(line, "*** ") { + parts := strings.SplitN(line[4:], " ", 2) + if len(parts) != 2 { + return nil, time.Time{}, nil + } + switch parts[0] { + case "Joins:", "Parts:", "Quits:": + args := strings.SplitN(parts[1], " ", 3) + if len(args) < 2 { + return nil, time.Time{}, nil + } + nick := args[0] + mask := strings.TrimSuffix(strings.TrimPrefix(args[1], "("), ")") + maskParts := strings.SplitN(mask, "@", 2) + if len(maskParts) != 2 { + return nil, time.Time{}, nil + } + prefix = &irc.Prefix{ + Name: nick, + User: maskParts[0], + Host: maskParts[1], + } + var reason string + if len(args) > 2 { + reason = strings.TrimSuffix(strings.TrimPrefix(args[2], "("), ")") + } + switch parts[0] { + case "Joins:": + cmd = "JOIN" + params = []string{entity} + case "Parts:": + cmd = "PART" + if reason != "" { + params = []string{entity, reason} + } else { + params = []string{entity} + } + case "Quits:": + cmd = "QUIT" + if reason != "" { + params = []string{reason} + } + } + default: + nick := parts[0] + rem := parts[1] + if r := strings.TrimPrefix(rem, "is now known as "); r != rem { + cmd = "NICK" + prefix = &irc.Prefix{ + Name: nick, + } + params = []string{r} + } else if r := strings.TrimPrefix(rem, "was kicked by "); r != rem { + args := strings.SplitN(r, " ", 2) + if len(args) != 2 { + return nil, time.Time{}, nil + } + cmd = "KICK" + prefix = &irc.Prefix{ + Name: args[0], + } + reason := strings.TrimSuffix(strings.TrimPrefix(args[1], "("), ")") + params = []string{entity, nick} + if reason != "" { + params = append(params, reason) + } + } else if r := strings.TrimPrefix(rem, "changes topic to "); r != rem { + cmd = "TOPIC" + prefix = &irc.Prefix{ + Name: nick, + } + topic := strings.TrimSuffix(strings.TrimPrefix(r, "'"), "'") + params = []string{entity, topic} + } else if r := strings.TrimPrefix(rem, "sets mode: "); r != rem { + cmd = "MODE" + prefix = &irc.Prefix{ + Name: nick, + } + params = append([]string{entity}, strings.Split(r, " ")...) + } else { + return nil, time.Time{}, nil + } + } + } else { + var sender, text string + if strings.HasPrefix(line, "<") { + cmd = "PRIVMSG" + parts := strings.SplitN(line[1:], "> ", 2) + if len(parts) != 2 { + return nil, time.Time{}, nil + } + sender, text = parts[0], parts[1] + } else if strings.HasPrefix(line, "-") { + cmd = "NOTICE" + parts := strings.SplitN(line[1:], "- ", 2) + if len(parts) != 2 { + return nil, time.Time{}, nil + } + sender, text = parts[0], parts[1] + } else if strings.HasPrefix(line, "* ") { + cmd = "PRIVMSG" + parts := strings.SplitN(line[2:], " ", 2) + if len(parts) != 2 { + return nil, time.Time{}, nil + } + sender, text = parts[0], "\x01ACTION "+parts[1]+"\x01" + } else { + return nil, time.Time{}, nil + } + + prefix = &irc.Prefix{Name: sender} + if entity == sender { + // This is a direct message from a user to us. We don't store own + // our nickname in the logs, so grab it from the network settings. + // Not very accurate since this may not match our nick at the time + // the message was received, but we can't do a lot better. + entity = GetNick(ms.user, network) + } + params = []string{entity, text} + } + + year, month, day := ref.Date() + t := time.Date(year, month, day, hour, minute, second, 0, time.Local) + + msg := &irc.Message{ + Tags: map[string]irc.TagValue{ + "time": irc.TagValue(formatServerTime(t)), + }, + Prefix: prefix, + Command: cmd, + Params: params, + } + return msg, t, nil +} + +func (ms *fsMessageStore) parseMessagesBefore(network *Network, entity string, ref time.Time, end time.Time, events bool, limit int, afterOffset int64) ([]*irc.Message, error) { + path := ms.logPath(network, entity, ref) + f, err := os.Open(path) + if err != nil { + if os.IsNotExist(err) { + return nil, nil + } + return nil, fmt.Errorf("failed to parse messages before ref: %v", err) + } + defer f.Close() + + historyRing := make([]*irc.Message, limit) + cur := 0 + + sc := bufio.NewScanner(f) + + if afterOffset >= 0 { + if _, err := f.Seek(afterOffset, io.SeekStart); err != nil { + return nil, nil + } + sc.Scan() // skip till next newline + } + + for sc.Scan() { + msg, t, err := ms.parseMessage(sc.Text(), network, entity, ref, events) + if err != nil { + return nil, err + } else if msg == nil || !t.After(end) { + continue + } else if !t.Before(ref) { + break + } + + historyRing[cur%limit] = msg + cur++ + } + if sc.Err() != nil { + return nil, fmt.Errorf("failed to parse messages before ref: scanner error: %v", sc.Err()) + } + + n := limit + if cur < limit { + n = cur + } + start := (cur - n + limit) % limit + + if start+n <= limit { // ring doesnt wrap + return historyRing[start : start+n], nil + } else { // ring wraps + history := make([]*irc.Message, n) + r := copy(history, historyRing[start:]) + copy(history[r:], historyRing[:n-r]) + return history, nil + } +} + +func (ms *fsMessageStore) parseMessagesAfter(network *Network, entity string, ref time.Time, end time.Time, events bool, limit int) ([]*irc.Message, error) { + path := ms.logPath(network, entity, ref) + f, err := os.Open(path) + if err != nil { + if os.IsNotExist(err) { + return nil, nil + } + return nil, fmt.Errorf("failed to parse messages after ref: %v", err) + } + defer f.Close() + + var history []*irc.Message + sc := bufio.NewScanner(f) + for sc.Scan() && len(history) < limit { + msg, t, err := ms.parseMessage(sc.Text(), network, entity, ref, events) + if err != nil { + return nil, err + } else if msg == nil || !t.After(ref) { + continue + } else if !t.Before(end) { + break + } + + history = append(history, msg) + } + if sc.Err() != nil { + return nil, fmt.Errorf("failed to parse messages after ref: scanner error: %v", sc.Err()) + } + + return history, nil +} + +func (ms *fsMessageStore) LoadBeforeTime(ctx context.Context, network *Network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) { + start = start.In(time.Local) + end = end.In(time.Local) + history := make([]*irc.Message, limit) + remaining := limit + tries := 0 + for remaining > 0 && tries < fsMessageStoreMaxTries && end.Before(start) { + buf, err := ms.parseMessagesBefore(network, entity, start, end, events, remaining, -1) + if err != nil { + return nil, err + } + if len(buf) == 0 { + tries++ + } else { + tries = 0 + } + copy(history[remaining-len(buf):], buf) + remaining -= len(buf) + year, month, day := start.Date() + start = time.Date(year, month, day, 0, 0, 0, 0, start.Location()).Add(-1) + + if err := ctx.Err(); err != nil { + return nil, err + } + } + + return history[remaining:], nil +} + +func (ms *fsMessageStore) LoadAfterTime(ctx context.Context, network *Network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) { + start = start.In(time.Local) + end = end.In(time.Local) + var history []*irc.Message + remaining := limit + tries := 0 + for remaining > 0 && tries < fsMessageStoreMaxTries && start.Before(end) { + buf, err := ms.parseMessagesAfter(network, entity, start, end, events, remaining) + if err != nil { + return nil, err + } + if len(buf) == 0 { + tries++ + } else { + tries = 0 + } + history = append(history, buf...) + remaining -= len(buf) + year, month, day := start.Date() + start = time.Date(year, month, day+1, 0, 0, 0, 0, start.Location()) + + if err := ctx.Err(); err != nil { + return nil, err + } + } + return history, nil +} + +func (ms *fsMessageStore) LoadLatestID(ctx context.Context, network *Network, entity, id string, limit int) ([]*irc.Message, error) { + var afterTime time.Time + var afterOffset int64 + if id != "" { + var idNet int64 + var idEntity string + var err error + idNet, idEntity, afterTime, afterOffset, err = parseFSMsgID(id) + if err != nil { + return nil, err + } + if idNet != network.ID || idEntity != entity { + return nil, fmt.Errorf("cannot find message ID: message ID doesn't match network/entity") + } + } + + history := make([]*irc.Message, limit) + t := time.Now() + remaining := limit + tries := 0 + for remaining > 0 && tries < fsMessageStoreMaxTries && !truncateDay(t).Before(afterTime) { + var offset int64 = -1 + if afterOffset >= 0 && truncateDay(t).Equal(afterTime) { + offset = afterOffset + } + + buf, err := ms.parseMessagesBefore(network, entity, t, time.Time{}, false, remaining, offset) + if err != nil { + return nil, err + } + if len(buf) == 0 { + tries++ + } else { + tries = 0 + } + copy(history[remaining-len(buf):], buf) + remaining -= len(buf) + year, month, day := t.Date() + t = time.Date(year, month, day, 0, 0, 0, 0, t.Location()).Add(-1) + + if err := ctx.Err(); err != nil { + return nil, err + } + } + + return history[remaining:], nil +} + +func (ms *fsMessageStore) ListTargets(ctx context.Context, network *Network, start, end time.Time, limit int, events bool) ([]chatHistoryTarget, error) { + start = start.In(time.Local) + end = end.In(time.Local) + rootPath := filepath.Join(ms.root, escapeFilename(network.GetName())) + root, err := os.Open(rootPath) + if os.IsNotExist(err) { + return nil, nil + } else if err != nil { + return nil, err + } + + // The returned targets are escaped, and there is no way to un-escape + // TODO: switch to ReadDir (Go 1.16+) + targetNames, err := root.Readdirnames(0) + root.Close() + if err != nil { + return nil, err + } + + var targets []chatHistoryTarget + for _, target := range targetNames { + // target is already escaped here + targetPath := filepath.Join(rootPath, target) + targetDir, err := os.Open(targetPath) + if err != nil { + return nil, err + } + + entries, err := targetDir.Readdir(0) + targetDir.Close() + if err != nil { + return nil, err + } + + // We use mtime here, which may give imprecise or incorrect results + var t time.Time + for _, entry := range entries { + if entry.ModTime().After(t) { + t = entry.ModTime() + } + } + + // The timestamps we get from logs have second granularity + t = truncateSecond(t) + + // Filter out targets that don't fullfil the time bounds + if !isTimeBetween(t, start, end) { + continue + } + + targets = append(targets, chatHistoryTarget{ + Name: target, + LatestMessage: t, + }) + + if err := ctx.Err(); err != nil { + return nil, err + } + } + + // Sort targets by latest message time, backwards or forwards depending on + // the order of the time bounds + sort.Slice(targets, func(i, j int) bool { + t1, t2 := targets[i].LatestMessage, targets[j].LatestMessage + if start.Before(end) { + return t1.Before(t2) + } else { + return !t1.Before(t2) + } + }) + + // Truncate the result if necessary + if len(targets) > limit { + targets = targets[:limit] + } + + return targets, nil +} + +func (ms *fsMessageStore) RenameNetwork(oldNet, newNet *Network) error { + oldDir := filepath.Join(ms.root, escapeFilename(oldNet.GetName())) + newDir := filepath.Join(ms.root, escapeFilename(newNet.GetName())) + // Avoid loosing data by overwriting an existing directory + if _, err := os.Stat(newDir); err == nil { + return fmt.Errorf("destination %q already exists", newDir) + } + return os.Rename(oldDir, newDir) +} + +func truncateDay(t time.Time) time.Time { + year, month, day := t.Date() + return time.Date(year, month, day, 0, 0, 0, 0, t.Location()) +} + +func truncateSecond(t time.Time) time.Time { + year, month, day := t.Date() + return time.Date(year, month, day, t.Hour(), t.Minute(), t.Second(), 0, t.Location()) +} + +func isTimeBetween(t, start, end time.Time) bool { + if end.Before(start) { + end, start = start, end + } + return start.Before(t) && t.Before(end) +} diff --git a/branches/master/msgstore_memory.go b/branches/master/msgstore_memory.go new file mode 100644 index 0000000..25ab953 --- /dev/null +++ b/branches/master/msgstore_memory.go @@ -0,0 +1,160 @@ +package suika + +import ( + "context" + "fmt" + "time" + + "git.sr.ht/~sircmpwn/go-bare" + "gopkg.in/irc.v3" +) + +const messageRingBufferCap = 4096 + +type memoryMsgID struct { + Seq bare.Uint +} + +func (memoryMsgID) msgIDType() msgIDType { + return msgIDMemory +} + +func parseMemoryMsgID(s string) (netID int64, entity string, seq uint64, err error) { + var id memoryMsgID + netID, entity, err = parseMsgID(s, &id) + if err != nil { + return 0, "", 0, err + } + return netID, entity, uint64(id.Seq), nil +} + +func formatMemoryMsgID(netID int64, entity string, seq uint64) string { + id := memoryMsgID{bare.Uint(seq)} + return formatMsgID(netID, entity, &id) +} + +type ringBufferKey struct { + networkID int64 + entity string +} + +type memoryMessageStore struct { + buffers map[ringBufferKey]*messageRingBuffer +} + +var _ messageStore = (*memoryMessageStore)(nil) + +func newMemoryMessageStore() *memoryMessageStore { + return &memoryMessageStore{ + buffers: make(map[ringBufferKey]*messageRingBuffer), + } +} + +func (ms *memoryMessageStore) Close() error { + ms.buffers = nil + return nil +} + +func (ms *memoryMessageStore) get(network *Network, entity string) *messageRingBuffer { + k := ringBufferKey{networkID: network.ID, entity: entity} + if rb, ok := ms.buffers[k]; ok { + return rb + } + rb := newMessageRingBuffer(messageRingBufferCap) + ms.buffers[k] = rb + return rb +} + +func (ms *memoryMessageStore) LastMsgID(network *Network, entity string, t time.Time) (string, error) { + var seq uint64 + k := ringBufferKey{networkID: network.ID, entity: entity} + if rb, ok := ms.buffers[k]; ok { + seq = rb.cur + } + return formatMemoryMsgID(network.ID, entity, seq), nil +} + +func (ms *memoryMessageStore) Append(network *Network, entity string, msg *irc.Message) (string, error) { + switch msg.Command { + case "PRIVMSG", "NOTICE": + // Only append these messages, because LoadLatestID shouldn't return + // other kinds of message. + default: + return "", nil + } + + k := ringBufferKey{networkID: network.ID, entity: entity} + rb, ok := ms.buffers[k] + if !ok { + rb = newMessageRingBuffer(messageRingBufferCap) + ms.buffers[k] = rb + } + + seq := rb.Append(msg) + return formatMemoryMsgID(network.ID, entity, seq), nil +} + +func (ms *memoryMessageStore) LoadLatestID(ctx context.Context, network *Network, entity, id string, limit int) ([]*irc.Message, error) { + _, _, seq, err := parseMemoryMsgID(id) + if err != nil { + return nil, err + } + + k := ringBufferKey{networkID: network.ID, entity: entity} + rb, ok := ms.buffers[k] + if !ok { + return nil, nil + } + + return rb.LoadLatestSeq(seq, limit) +} + +type messageRingBuffer struct { + buf []*irc.Message + cur uint64 +} + +func newMessageRingBuffer(capacity int) *messageRingBuffer { + return &messageRingBuffer{ + buf: make([]*irc.Message, capacity), + cur: 1, + } +} + +func (rb *messageRingBuffer) cap() uint64 { + return uint64(len(rb.buf)) +} + +func (rb *messageRingBuffer) Append(msg *irc.Message) uint64 { + seq := rb.cur + i := int(seq % rb.cap()) + rb.buf[i] = msg + rb.cur++ + return seq +} + +func (rb *messageRingBuffer) LoadLatestSeq(seq uint64, limit int) ([]*irc.Message, error) { + if seq > rb.cur { + return nil, fmt.Errorf("loading messages from sequence number (%v) greater than current (%v)", seq, rb.cur) + } else if seq == rb.cur { + return nil, nil + } + + // The query excludes the message with the sequence number seq + diff := rb.cur - seq - 1 + if diff > rb.cap() { + // We dropped diff - cap entries + diff = rb.cap() + } + if int(diff) > limit { + diff = uint64(limit) + } + + l := make([]*irc.Message, int(diff)) + for i := 0; i < int(diff); i++ { + j := int((rb.cur - diff + uint64(i)) % rb.cap()) + l[i] = rb.buf[j] + } + + return l, nil +} diff --git a/branches/master/net_go113.go b/branches/master/net_go113.go new file mode 100644 index 0000000..24bde7f --- /dev/null +++ b/branches/master/net_go113.go @@ -0,0 +1,12 @@ +//go:build !go1.16 +// +build !go1.16 + +package suika + +import ( + "strings" +) + +func isErrClosed(err error) bool { + return err != nil && strings.Contains(err.Error(), "use of closed network connection") +} diff --git a/branches/master/net_go116.go b/branches/master/net_go116.go new file mode 100644 index 0000000..ef1256a --- /dev/null +++ b/branches/master/net_go116.go @@ -0,0 +1,13 @@ +//go:build go1.16 +// +build go1.16 + +package suika + +import ( + "errors" + "net" +) + +func isErrClosed(err error) bool { + return errors.Is(err, net.ErrClosed) +} diff --git a/branches/master/rate.go b/branches/master/rate.go new file mode 100644 index 0000000..a14bce2 --- /dev/null +++ b/branches/master/rate.go @@ -0,0 +1,40 @@ +package suika + +import ( + "math/rand" + "time" +) + +// backoffer implements a simple exponential backoff. +type backoffer struct { + min, max, jitter time.Duration + n int64 +} + +func newBackoffer(min, max, jitter time.Duration) *backoffer { + return &backoffer{min: min, max: max, jitter: jitter} +} + +func (b *backoffer) Reset() { + b.n = 0 +} + +func (b *backoffer) Next() time.Duration { + if b.n == 0 { + b.n = 1 + return 0 + } + + d := time.Duration(b.n) * b.min + if d > b.max { + d = b.max + } else { + b.n *= 2 + } + + if b.jitter != 0 { + d += time.Duration(rand.Int63n(int64(b.jitter))) + } + + return d +} diff --git a/branches/master/rc.d/freebsd-rc.d b/branches/master/rc.d/freebsd-rc.d new file mode 100755 index 0000000..0fed078 --- /dev/null +++ b/branches/master/rc.d/freebsd-rc.d @@ -0,0 +1,29 @@ +#!/bin/sh +# $TheSupernovaDuo$ +# vim: ft=sh + +# PROVIDE: suika +# REQUIRE: DAEMON +# BEFORE: LOGIN +# KEYWORD: shutdown + +. /etc/rc.subr + +name="suika" +desc="A drunk IRC bouncer" +rcvar="suika_enable" + +: ${suika_user="ircd"} + +command="%%PREFIX%%/bin/suika" +pidfile="/var/run/suika.pid" +required_files="%%PREFIX%%/etc/suika/config" + +start_cmd="suika_start" + +suika_start() { + /usr/sbin/daemon -f -p ${pidfile} -u ${suika_user} -l daemon ${command} --config ${required_files} +} + +load_rc_config "$name" +run_rc_command "$1" diff --git a/branches/master/rc.d/immortal.yml b/branches/master/rc.d/immortal.yml new file mode 100644 index 0000000..c7a8090 --- /dev/null +++ b/branches/master/rc.d/immortal.yml @@ -0,0 +1,3 @@ +# $TheSupernovaDuo$ +cmd: %%PREFIX%%/bin/suika --config %%PREFIX%%/etc/suika/config +user: ircd diff --git a/branches/master/rc.d/netbsd-rc.d b/branches/master/rc.d/netbsd-rc.d new file mode 100644 index 0000000..3fef470 --- /dev/null +++ b/branches/master/rc.d/netbsd-rc.d @@ -0,0 +1,28 @@ +#!/bin/sh +# $TheSupernovaDuo$ +# vim: ft=sh + +# PROVIDE: suika +# REQUIRE: DAEMON +# BEFORE: LOGIN +# KEYWORD: shutdown + +. /etc/rc.subr + +name="suika" +rcvar="${name}" +command="%%PREFIX/bin/${name}" +command_args="--config %%PREFIX%%/etc/suika/config" +pidfile="/var/run/${name}.pid" +start_cmd="${name}_start" + +suika_start() { + printf "Starting %s..." "${name}" + ${command} ${command_args} + pgrep -n ${name} > ${pidfile} +} + +load_rc_config ${name} +run_rc_command "$1" + + diff --git a/branches/master/rc.d/openbsd-rc.d b/branches/master/rc.d/openbsd-rc.d new file mode 100644 index 0000000..f0bf77c --- /dev/null +++ b/branches/master/rc.d/openbsd-rc.d @@ -0,0 +1,12 @@ +#!/bin/ksh +# $TheSupernovaDuo$ +# vim: ft=sh + +daemon="%%PREFIX%%/bin/suika" +daemon_args="--config %%PREFIX%%/etc/suika/config" + +. /etc/rc.d/rc.subr + +rc_bg=YES + +rc_cmd "$1" diff --git a/branches/master/rc.d/suika.service b/branches/master/rc.d/suika.service new file mode 100644 index 0000000..d8d53b3 --- /dev/null +++ b/branches/master/rc.d/suika.service @@ -0,0 +1,16 @@ +# $TheSupernovaDuo$ +# vim: ft=confini +[Unit] +Description=A drunk IRC bouncer +After=network.target +Wants=network.target +StartLimitBurst=5 +StartLimitIntervalSec=1 +[Service] +Type=simple +Restart=on-abnormal +RestartSec=1 +User=suika +ExecStart=%%PREFIX%%/bin/suika --config %%PREFIX%%/etc/suika/config +[Install] +WantedBy=multi-user.target diff --git a/branches/master/server.go b/branches/master/server.go new file mode 100644 index 0000000..05c55fd --- /dev/null +++ b/branches/master/server.go @@ -0,0 +1,330 @@ +package suika + +import ( + "context" + "errors" + "fmt" + "io" + "log" + "net" + "runtime/debug" + "sync" + "sync/atomic" + "time" + + "gopkg.in/irc.v3" +) + +// TODO: make configurable +var ( + retryConnectMinDelay = time.Minute + retryConnectMaxDelay = 10 * time.Minute + retryConnectJitter = time.Minute + connectTimeout = 15 * time.Second + writeTimeout = 10 * time.Second + upstreamMessageDelay = 2 * time.Second + upstreamMessageBurst = 10 + backlogTimeout = 10 * time.Second + handleDownstreamMessageTimeout = 10 * time.Second + downstreamRegisterTimeout = 30 * time.Second + chatHistoryLimit = 1000 + backlogLimit = 4000 +) + +type Logger interface { + Printf(format string, v ...interface{}) + Debugf(format string, v ...interface{}) +} + +type logger struct { + *log.Logger + debug bool +} + +func (l logger) Debugf(format string, v ...interface{}) { + if !l.debug { + return + } + l.Logger.Printf(format, v...) +} + +func NewLogger(out io.Writer, debug bool) Logger { + return logger{ + Logger: log.New(log.Writer(), "", log.LstdFlags), + debug: debug, + } +} + +type prefixLogger struct { + logger Logger + prefix string +} + +var _ Logger = (*prefixLogger)(nil) + +func (l *prefixLogger) Printf(format string, v ...interface{}) { + v = append([]interface{}{l.prefix}, v...) + l.logger.Printf("%v"+format, v...) +} + +func (l *prefixLogger) Debugf(format string, v ...interface{}) { + v = append([]interface{}{l.prefix}, v...) + l.logger.Debugf("%v"+format, v...) +} + +type int64Gauge struct { + v int64 // atomic +} + +func (g *int64Gauge) Add(delta int64) { + atomic.AddInt64(&g.v, delta) +} + +func (g *int64Gauge) Value() int64 { + return atomic.LoadInt64(&g.v) +} + +func (g *int64Gauge) Float64() float64 { + return float64(g.Value()) +} + +type retryListener struct { + net.Listener + Logger Logger + + delay time.Duration +} + +func (ln *retryListener) Accept() (net.Conn, error) { + for { + conn, err := ln.Listener.Accept() + if ne, ok := err.(net.Error); ok && ne.Temporary() { + if ln.delay == 0 { + ln.delay = 5 * time.Millisecond + } else { + ln.delay *= 2 + } + if max := 1 * time.Second; ln.delay > max { + ln.delay = max + } + if ln.Logger != nil { + ln.Logger.Printf("accept error (retrying in %v): %v", ln.delay, err) + } + time.Sleep(ln.delay) + } else { + ln.delay = 0 + return conn, err + } + } +} + +type Config struct { + Hostname string + Title string + LogPath string + MaxUserNetworks int + MultiUpstream bool + MOTD string + UpstreamUserIPs []*net.IPNet +} + +type Server struct { + Logger Logger + + config atomic.Value // *Config + db Database + stopWG sync.WaitGroup + + lock sync.Mutex + listeners map[net.Listener]struct{} + users map[string]*user +} + +func NewServer(db Database) *Server { + srv := &Server{ + Logger: NewLogger(log.Writer(), true), + db: db, + listeners: make(map[net.Listener]struct{}), + users: make(map[string]*user), + } + srv.config.Store(&Config{ + Hostname: "localhost", + MaxUserNetworks: -1, + MultiUpstream: true, + }) + return srv +} + +func (s *Server) prefix() *irc.Prefix { + return &irc.Prefix{Name: s.Config().Hostname} +} + +func (s *Server) Config() *Config { + return s.config.Load().(*Config) +} + +func (s *Server) SetConfig(cfg *Config) { + s.config.Store(cfg) +} + +func (s *Server) Start() error { + users, err := s.db.ListUsers(context.TODO()) + if err != nil { + return err + } + + s.lock.Lock() + for i := range users { + s.addUserLocked(&users[i]) + } + s.lock.Unlock() + + return nil +} + +func (s *Server) Shutdown() { + s.lock.Lock() + for ln := range s.listeners { + if err := ln.Close(); err != nil { + s.Logger.Printf("failed to stop listener: %v", err) + } + } + for _, u := range s.users { + u.events <- eventStop{} + } + s.lock.Unlock() + + s.stopWG.Wait() + + if err := s.db.Close(); err != nil { + s.Logger.Printf("failed to close DB: %v", err) + } +} + +func (s *Server) createUser(ctx context.Context, user *User) (*user, error) { + s.lock.Lock() + defer s.lock.Unlock() + + if _, ok := s.users[user.Username]; ok { + return nil, fmt.Errorf("user %q already exists", user.Username) + } + + err := s.db.StoreUser(ctx, user) + if err != nil { + return nil, fmt.Errorf("could not create user in db: %v", err) + } + + return s.addUserLocked(user), nil +} + +func (s *Server) forEachUser(f func(*user)) { + s.lock.Lock() + for _, u := range s.users { + f(u) + } + s.lock.Unlock() +} + +func (s *Server) getUser(name string) *user { + s.lock.Lock() + u := s.users[name] + s.lock.Unlock() + return u +} + +func (s *Server) addUserLocked(user *User) *user { + s.Logger.Printf("starting bouncer for user %q", user.Username) + u := newUser(s, user) + s.users[u.Username] = u + + s.stopWG.Add(1) + + go func() { + defer func() { + if err := recover(); err != nil { + s.Logger.Printf("panic serving user %q: %v\n%v", user.Username, err, debug.Stack()) + } + + s.lock.Lock() + delete(s.users, u.Username) + s.lock.Unlock() + + s.stopWG.Done() + }() + + u.run() + }() + + return u +} + +var lastDownstreamID uint64 = 0 + +func (s *Server) handle(ic ircConn) { + defer func() { + if err := recover(); err != nil { + s.Logger.Printf("panic serving downstream %q: %v\n%v", ic.RemoteAddr(), err, debug.Stack()) + } + }() + + id := atomic.AddUint64(&lastDownstreamID, 1) + dc := newDownstreamConn(s, ic, id) + if err := dc.runUntilRegistered(); err != nil { + if !errors.Is(err, io.EOF) { + dc.logger.Printf("%v", err) + } + } else { + dc.user.events <- eventDownstreamConnected{dc} + if err := dc.readMessages(dc.user.events); err != nil { + dc.logger.Printf("%v", err) + } + dc.user.events <- eventDownstreamDisconnected{dc} + } + dc.Close() +} + +func (s *Server) Serve(ln net.Listener) error { + ln = &retryListener{ + Listener: ln, + Logger: &prefixLogger{logger: s.Logger, prefix: fmt.Sprintf("listener %v: ", ln.Addr())}, + } + + s.lock.Lock() + s.listeners[ln] = struct{}{} + s.lock.Unlock() + + s.stopWG.Add(1) + + defer func() { + s.lock.Lock() + delete(s.listeners, ln) + s.lock.Unlock() + + s.stopWG.Done() + }() + + for { + conn, err := ln.Accept() + if isErrClosed(err) { + return nil + } else if err != nil { + return fmt.Errorf("failed to accept connection: %v", err) + } + + go s.handle(newNetIRCConn(conn)) + } +} + +type ServerStats struct { + Users int + Downstreams int64 + Upstreams int64 +} + +func (s *Server) Stats() *ServerStats { + var stats ServerStats + s.lock.Lock() + stats.Users = len(s.users) + s.lock.Unlock() + return &stats +} diff --git a/branches/master/server_test.go b/branches/master/server_test.go new file mode 100644 index 0000000..cf0adca --- /dev/null +++ b/branches/master/server_test.go @@ -0,0 +1,207 @@ +package suika + +import ( + "context" + "net" + "testing" + + "golang.org/x/crypto/bcrypt" + "gopkg.in/irc.v3" +) + +var testServerPrefix = &irc.Prefix{Name: "suika-test-server"} + +const ( + testUsername = "suika-test-user" + testPassword = testUsername +) + +func createTempSqliteDB(t *testing.T) Database { + db, err := OpenDB("sqlite3", ":memory:") + if err != nil { + t.Fatalf("failed to create temporary SQLite database: %v", err) + } + // :memory: will open a separate database for each new connection. Make + // sure the sql package only uses a single connection. An alternative + // solution is to use "file::memory:?cache=shared". + db.(*SqliteDB).db.SetMaxOpenConns(1) + return db +} + +func createTempPostgresDB(t *testing.T) Database { + db := &PostgresDB{db: openTempPostgresDB(t)} + if err := db.upgrade(); err != nil { + t.Fatalf("failed to upgrade PostgreSQL database: %v", err) + } + + return db +} + +func createTestUser(t *testing.T, db Database) *User { + hashed, err := bcrypt.GenerateFromPassword([]byte(testPassword), bcrypt.DefaultCost) + if err != nil { + t.Fatalf("failed to generate bcrypt hash: %v", err) + } + + record := &User{Username: testUsername, Password: string(hashed)} + if err := db.StoreUser(context.Background(), record); err != nil { + t.Fatalf("failed to store test user: %v", err) + } + + return record +} + +func createTestDownstream(t *testing.T, srv *Server) ircConn { + c1, c2 := net.Pipe() + go srv.handle(newNetIRCConn(c1)) + return newNetIRCConn(c2) +} + +func createTestUpstream(t *testing.T, db Database, user *User) (*Network, net.Listener) { + ln, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("failed to create TCP listener: %v", err) + } + + network := &Network{ + Name: "testnet", + Addr: "irc://" + ln.Addr().String(), + Nick: user.Username, + Enabled: true, + } + if err := db.StoreNetwork(context.Background(), user.ID, network); err != nil { + t.Fatalf("failed to store test network: %v", err) + } + + return network, ln +} + +func mustAccept(t *testing.T, ln net.Listener) ircConn { + c, err := ln.Accept() + if err != nil { + t.Fatalf("failed accepting connection: %v", err) + } + return newNetIRCConn(c) +} + +func expectMessage(t *testing.T, c ircConn, cmd string) *irc.Message { + msg, err := c.ReadMessage() + if err != nil { + t.Fatalf("failed to read IRC message (want %q): %v", cmd, err) + } + if msg.Command != cmd { + t.Fatalf("invalid message received: want %q, got: %v", cmd, msg) + } + return msg +} + +func registerDownstreamConn(t *testing.T, c ircConn, network *Network) { + c.WriteMessage(&irc.Message{ + Command: "PASS", + Params: []string{testPassword}, + }) + c.WriteMessage(&irc.Message{ + Command: "NICK", + Params: []string{testUsername}, + }) + c.WriteMessage(&irc.Message{ + Command: "USER", + Params: []string{testUsername + "/" + network.Name, "0", "*", testUsername}, + }) + + expectMessage(t, c, irc.RPL_WELCOME) +} + +func registerUpstreamConn(t *testing.T, c ircConn) { + msg := expectMessage(t, c, "CAP") + if msg.Params[0] != "LS" { + t.Fatalf("invalid CAP LS: got: %v", msg) + } + msg = expectMessage(t, c, "NICK") + nick := msg.Params[0] + if nick != testUsername { + t.Fatalf("invalid NICK: want %q, got: %v", testUsername, msg) + } + expectMessage(t, c, "USER") + + c.WriteMessage(&irc.Message{ + Prefix: testServerPrefix, + Command: irc.RPL_WELCOME, + Params: []string{nick, "Welcome!"}, + }) + c.WriteMessage(&irc.Message{ + Prefix: testServerPrefix, + Command: irc.RPL_YOURHOST, + Params: []string{nick, "Your host is suika-test-server"}, + }) + c.WriteMessage(&irc.Message{ + Prefix: testServerPrefix, + Command: irc.RPL_CREATED, + Params: []string{nick, "Who cares when the server was created?"}, + }) + c.WriteMessage(&irc.Message{ + Prefix: testServerPrefix, + Command: irc.RPL_MYINFO, + Params: []string{nick, testServerPrefix.Name, "suika", "aiwroO", "OovaimnqpsrtklbeI"}, + }) + c.WriteMessage(&irc.Message{ + Prefix: testServerPrefix, + Command: irc.ERR_NOMOTD, + Params: []string{nick, "No MOTD"}, + }) +} + +func testServer(t *testing.T, db Database) { + user := createTestUser(t, db) + network, upstream := createTestUpstream(t, db, user) + defer upstream.Close() + + srv := NewServer(db) + if err := srv.Start(); err != nil { + t.Fatalf("failed to start server: %v", err) + } + defer srv.Shutdown() + + uc := mustAccept(t, upstream) + defer uc.Close() + registerUpstreamConn(t, uc) + + dc := createTestDownstream(t, srv) + defer dc.Close() + registerDownstreamConn(t, dc, network) + + noticeText := "This is a very important server notice." + uc.WriteMessage(&irc.Message{ + Prefix: testServerPrefix, + Command: "NOTICE", + Params: []string{testUsername, noticeText}, + }) + + var msg *irc.Message + for { + var err error + msg, err = dc.ReadMessage() + if err != nil { + t.Fatalf("failed to read IRC message: %v", err) + } + if msg.Command == "NOTICE" { + break + } + } + + if msg.Params[1] != noticeText { + t.Fatalf("invalid NOTICE text: want %q, got: %v", noticeText, msg) + } +} + +func TestServer(t *testing.T) { + t.Run("sqlite", func(t *testing.T) { + db := createTempSqliteDB(t) + testServer(t, db) + }) + + t.Run("postgres", func(t *testing.T) { + db := createTempPostgresDB(t) + testServer(t, db) + }) +} diff --git a/branches/master/service.go b/branches/master/service.go new file mode 100644 index 0000000..07226ef --- /dev/null +++ b/branches/master/service.go @@ -0,0 +1,1150 @@ +package suika + +import ( + "context" + "crypto/sha1" + "crypto/sha256" + "crypto/sha512" + "encoding/hex" + "flag" + "fmt" + "io/ioutil" + "sort" + "strconv" + "strings" + "time" + "unicode" + + "golang.org/x/crypto/bcrypt" + "gopkg.in/irc.v3" +) + +const ( + serviceNick = "BouncerServ" + serviceNickCM = "bouncerserv" + serviceRealname = "suika bouncer service" +) + +// maxRSABits is the maximum number of RSA key bits used when generating a new +// private key. +const maxRSABits = 8192 + +var servicePrefix = &irc.Prefix{ + Name: serviceNick, + User: serviceNick, + Host: serviceNick, +} + +type serviceCommandSet map[string]*serviceCommand + +type serviceCommand struct { + usage string + desc string + handle func(ctx context.Context, dc *downstreamConn, params []string) error + children serviceCommandSet + admin bool +} + +func sendServiceNOTICE(dc *downstreamConn, text string) { + dc.SendMessage(&irc.Message{ + Prefix: servicePrefix, + Command: "NOTICE", + Params: []string{dc.nick, text}, + }) +} + +func sendServicePRIVMSG(dc *downstreamConn, text string) { + dc.SendMessage(&irc.Message{ + Prefix: servicePrefix, + Command: "PRIVMSG", + Params: []string{dc.nick, text}, + }) +} + +func splitWords(s string) ([]string, error) { + var words []string + var lastWord strings.Builder + escape := false + prev := ' ' + wordDelim := ' ' + + for _, r := range s { + if escape { + // last char was a backslash, write the byte as-is. + lastWord.WriteRune(r) + escape = false + } else if r == '\\' { + escape = true + } else if wordDelim == ' ' && unicode.IsSpace(r) { + // end of last word + if !unicode.IsSpace(prev) { + words = append(words, lastWord.String()) + lastWord.Reset() + } + } else if r == wordDelim { + // wordDelim is either " or ', switch back to + // space-delimited words. + wordDelim = ' ' + } else if r == '"' || r == '\'' { + if wordDelim == ' ' { + // start of (double-)quoted word + wordDelim = r + } else { + // either wordDelim is " and r is ' or vice-versa + lastWord.WriteRune(r) + } + } else { + lastWord.WriteRune(r) + } + + prev = r + } + + if !unicode.IsSpace(prev) { + words = append(words, lastWord.String()) + } + + if wordDelim != ' ' { + return nil, fmt.Errorf("unterminated quoted string") + } + if escape { + return nil, fmt.Errorf("unterminated backslash sequence") + } + + return words, nil +} + +func handleServicePRIVMSG(ctx context.Context, dc *downstreamConn, text string) { + words, err := splitWords(text) + if err != nil { + sendServicePRIVMSG(dc, fmt.Sprintf(`error: failed to parse command: %v`, err)) + return + } + + cmd, params, err := serviceCommands.Get(words) + if err != nil { + sendServicePRIVMSG(dc, fmt.Sprintf(`error: %v (type "help" for a list of commands)`, err)) + return + } + if cmd.admin && !dc.user.Admin { + sendServicePRIVMSG(dc, "error: you must be an admin to use this command") + return + } + + if cmd.handle == nil { + if len(cmd.children) > 0 { + var l []string + appendServiceCommandSetHelp(cmd.children, words, dc.user.Admin, &l) + sendServicePRIVMSG(dc, "available commands: "+strings.Join(l, ", ")) + } else { + // Pretend the command does not exist if it has neither children nor handler. + // This is obviously a bug but it is better to not die anyway. + dc.logger.Printf("command without handler and subcommands invoked:", words[0]) + sendServicePRIVMSG(dc, fmt.Sprintf("command %q not found", words[0])) + } + return + } + + if err := cmd.handle(ctx, dc, params); err != nil { + sendServicePRIVMSG(dc, fmt.Sprintf("error: %v", err)) + } +} + +func (cmds serviceCommandSet) Get(params []string) (*serviceCommand, []string, error) { + if len(params) == 0 { + return nil, nil, fmt.Errorf("no command specified") + } + + name := params[0] + params = params[1:] + + cmd, ok := cmds[name] + if !ok { + for k := range cmds { + if !strings.HasPrefix(k, name) { + continue + } + if cmd != nil { + return nil, params, fmt.Errorf("command %q is ambiguous", name) + } + cmd = cmds[k] + } + } + if cmd == nil { + return nil, params, fmt.Errorf("command %q not found", name) + } + + if len(params) == 0 || len(cmd.children) == 0 { + return cmd, params, nil + } + return cmd.children.Get(params) +} + +func (cmds serviceCommandSet) Names() []string { + l := make([]string, 0, len(cmds)) + for name := range cmds { + l = append(l, name) + } + sort.Strings(l) + return l +} + +var serviceCommands serviceCommandSet + +func init() { + serviceCommands = serviceCommandSet{ + "help": { + usage: "[command]", + desc: "print help message", + handle: handleServiceHelp, + }, + "network": { + children: serviceCommandSet{ + "create": { + usage: "-addr [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick] [-enabled enabled] [-connect-command command]...", + desc: "add a new network", + handle: handleServiceNetworkCreate, + }, + "status": { + desc: "show a list of saved networks and their current status", + handle: handleServiceNetworkStatus, + }, + "update": { + usage: "[name] [-addr addr] [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick] [-enabled enabled] [-connect-command command]...", + desc: "update a network", + handle: handleServiceNetworkUpdate, + }, + "delete": { + usage: "[name]", + desc: "delete a network", + handle: handleServiceNetworkDelete, + }, + "quote": { + usage: "[name] ", + desc: "send a raw line to a network", + handle: handleServiceNetworkQuote, + }, + }, + }, + "certfp": { + children: serviceCommandSet{ + "generate": { + usage: "[-key-type rsa|ecdsa|ed25519] [-bits N] [-network name]", + desc: "generate a new self-signed certificate, defaults to using RSA-3072 key", + handle: handleServiceCertFPGenerate, + }, + "fingerprint": { + usage: "[-network name]", + desc: "show fingerprints of certificate", + handle: handleServiceCertFPFingerprints, + }, + }, + }, + "sasl": { + children: serviceCommandSet{ + "status": { + usage: "[-network name]", + desc: "show SASL status", + handle: handleServiceSASLStatus, + }, + "set-plain": { + usage: "[-network name] ", + desc: "set SASL PLAIN credentials", + handle: handleServiceSASLSetPlain, + }, + "reset": { + usage: "[-network name]", + desc: "disable SASL authentication and remove stored credentials", + handle: handleServiceSASLReset, + }, + }, + }, + "user": { + children: serviceCommandSet{ + "create": { + usage: "-username -password [-realname ] [-admin]", + desc: "create a new suika user", + handle: handleUserCreate, + admin: true, + }, + "update": { + usage: "[-password ] [-realname ]", + desc: "update the current user", + handle: handleUserUpdate, + }, + "delete": { + usage: "", + desc: "delete a user", + handle: handleUserDelete, + admin: true, + }, + }, + }, + "channel": { + children: serviceCommandSet{ + "status": { + usage: "[-network name]", + desc: "show a list of saved channels and their current status", + handle: handleServiceChannelStatus, + }, + "update": { + usage: " [-relay-detached ] [-reattach-on ] [-detach-after ] [-detach-on ]", + desc: "update a channel", + handle: handleServiceChannelUpdate, + }, + }, + }, + "server": { + children: serviceCommandSet{ + "status": { + desc: "show server statistics", + handle: handleServiceServerStatus, + admin: true, + }, + "notice": { + desc: "broadcast a notice to all connected bouncer users", + handle: handleServiceServerNotice, + admin: true, + }, + }, + admin: true, + }, + } +} + +func appendServiceCommandSetHelp(cmds serviceCommandSet, prefix []string, admin bool, l *[]string) { + for _, name := range cmds.Names() { + cmd := cmds[name] + if cmd.admin && !admin { + continue + } + words := append(prefix, name) + if len(cmd.children) == 0 { + s := strings.Join(words, " ") + *l = append(*l, s) + } else { + appendServiceCommandSetHelp(cmd.children, words, admin, l) + } + } +} + +func handleServiceHelp(ctx context.Context, dc *downstreamConn, params []string) error { + if len(params) > 0 { + cmd, rest, err := serviceCommands.Get(params) + if err != nil { + return err + } + words := params[:len(params)-len(rest)] + + if len(cmd.children) > 0 { + var l []string + appendServiceCommandSetHelp(cmd.children, words, dc.user.Admin, &l) + sendServicePRIVMSG(dc, "available commands: "+strings.Join(l, ", ")) + } else { + text := strings.Join(words, " ") + if cmd.usage != "" { + text += " " + cmd.usage + } + text += ": " + cmd.desc + + sendServicePRIVMSG(dc, text) + } + } else { + var l []string + appendServiceCommandSetHelp(serviceCommands, nil, dc.user.Admin, &l) + sendServicePRIVMSG(dc, "available commands: "+strings.Join(l, ", ")) + } + return nil +} + +func newFlagSet() *flag.FlagSet { + fs := flag.NewFlagSet("", flag.ContinueOnError) + fs.SetOutput(ioutil.Discard) + return fs +} + +type stringSliceFlag []string + +func (v *stringSliceFlag) String() string { + return fmt.Sprint([]string(*v)) +} + +func (v *stringSliceFlag) Set(s string) error { + *v = append(*v, s) + return nil +} + +// stringPtrFlag is a flag value populating a string pointer. This allows to +// disambiguate between a flag that hasn't been set and a flag that has been +// set to an empty string. +type stringPtrFlag struct { + ptr **string +} + +func (f stringPtrFlag) String() string { + if f.ptr == nil || *f.ptr == nil { + return "" + } + return **f.ptr +} + +func (f stringPtrFlag) Set(s string) error { + *f.ptr = &s + return nil +} + +type boolPtrFlag struct { + ptr **bool +} + +func (f boolPtrFlag) String() string { + if f.ptr == nil || *f.ptr == nil { + return "" + } + return strconv.FormatBool(**f.ptr) +} + +func (f boolPtrFlag) Set(s string) error { + v, err := strconv.ParseBool(s) + if err != nil { + return err + } + *f.ptr = &v + return nil +} + +func getNetworkFromArg(dc *downstreamConn, params []string) (*network, []string, error) { + name, params := popArg(params) + if name == "" { + if dc.network == nil { + return nil, params, fmt.Errorf("no network selected, a name argument is required") + } + return dc.network, params, nil + } else { + net := dc.user.getNetwork(name) + if net == nil { + return nil, params, fmt.Errorf("unknown network %q", name) + } + return net, params, nil + } +} + +type networkFlagSet struct { + *flag.FlagSet + Addr, Name, Nick, Username, Pass, Realname *string + Enabled *bool + ConnectCommands []string +} + +func newNetworkFlagSet() *networkFlagSet { + fs := &networkFlagSet{FlagSet: newFlagSet()} + fs.Var(stringPtrFlag{&fs.Addr}, "addr", "") + fs.Var(stringPtrFlag{&fs.Name}, "name", "") + fs.Var(stringPtrFlag{&fs.Nick}, "nick", "") + fs.Var(stringPtrFlag{&fs.Username}, "username", "") + fs.Var(stringPtrFlag{&fs.Pass}, "pass", "") + fs.Var(stringPtrFlag{&fs.Realname}, "realname", "") + fs.Var(boolPtrFlag{&fs.Enabled}, "enabled", "") + fs.Var((*stringSliceFlag)(&fs.ConnectCommands), "connect-command", "") + return fs +} + +func (fs *networkFlagSet) update(network *Network) error { + if fs.Addr != nil { + if addrParts := strings.SplitN(*fs.Addr, "://", 2); len(addrParts) == 2 { + scheme := addrParts[0] + switch scheme { + case "ircs", "irc", "unix": + default: + return fmt.Errorf("unknown scheme %q (supported schemes: ircs, irc, unix)", scheme) + } + } + network.Addr = *fs.Addr + } + if fs.Name != nil { + network.Name = *fs.Name + } + if fs.Nick != nil { + network.Nick = *fs.Nick + } + if fs.Username != nil { + network.Username = *fs.Username + } + if fs.Pass != nil { + network.Pass = *fs.Pass + } + if fs.Realname != nil { + network.Realname = *fs.Realname + } + if fs.Enabled != nil { + network.Enabled = *fs.Enabled + } + if fs.ConnectCommands != nil { + if len(fs.ConnectCommands) == 1 && fs.ConnectCommands[0] == "" { + network.ConnectCommands = nil + } else { + for _, command := range fs.ConnectCommands { + _, err := irc.ParseMessage(command) + if err != nil { + return fmt.Errorf("flag -connect-command must be a valid raw irc command string: %q: %v", command, err) + } + } + network.ConnectCommands = fs.ConnectCommands + } + } + return nil +} + +func handleServiceNetworkCreate(ctx context.Context, dc *downstreamConn, params []string) error { + fs := newNetworkFlagSet() + if err := fs.Parse(params); err != nil { + return err + } + if fs.Addr == nil { + return fmt.Errorf("flag -addr is required") + } + + record := &Network{ + Addr: *fs.Addr, + Enabled: true, + } + if err := fs.update(record); err != nil { + return err + } + + network, err := dc.user.createNetwork(ctx, record) + if err != nil { + return fmt.Errorf("could not create network: %v", err) + } + + sendServicePRIVMSG(dc, fmt.Sprintf("created network %q", network.GetName())) + return nil +} + +func handleServiceNetworkStatus(ctx context.Context, dc *downstreamConn, params []string) error { + n := 0 + for _, net := range dc.user.networks { + var statuses []string + var details string + if uc := net.conn; uc != nil { + if dc.nick != uc.nick { + statuses = append(statuses, "connected as "+uc.nick) + } else { + statuses = append(statuses, "connected") + } + details = fmt.Sprintf("%v channels", uc.channels.Len()) + } else if !net.Enabled { + statuses = append(statuses, "disabled") + } else { + statuses = append(statuses, "disconnected") + if net.lastError != nil { + details = net.lastError.Error() + } + } + + if net == dc.network { + statuses = append(statuses, "current") + } + + name := net.GetName() + if name != net.Addr { + name = fmt.Sprintf("%v (%v)", name, net.Addr) + } + + s := fmt.Sprintf("%v [%v]", name, strings.Join(statuses, ", ")) + if details != "" { + s += ": " + details + } + sendServicePRIVMSG(dc, s) + + n++ + } + + if n == 0 { + sendServicePRIVMSG(dc, `No network configured, add one with "network create".`) + } + + return nil +} + +func handleServiceNetworkUpdate(ctx context.Context, dc *downstreamConn, params []string) error { + net, params, err := getNetworkFromArg(dc, params) + if err != nil { + return err + } + + fs := newNetworkFlagSet() + if err := fs.Parse(params); err != nil { + return err + } + + record := net.Network // copy network record because we'll mutate it + if err := fs.update(&record); err != nil { + return err + } + + network, err := dc.user.updateNetwork(ctx, &record) + if err != nil { + return fmt.Errorf("could not update network: %v", err) + } + + sendServicePRIVMSG(dc, fmt.Sprintf("updated network %q", network.GetName())) + return nil +} + +func handleServiceNetworkDelete(ctx context.Context, dc *downstreamConn, params []string) error { + net, params, err := getNetworkFromArg(dc, params) + if err != nil { + return err + } + + if err := dc.user.deleteNetwork(ctx, net.ID); err != nil { + return err + } + + sendServicePRIVMSG(dc, fmt.Sprintf("deleted network %q", net.GetName())) + return nil +} + +func handleServiceNetworkQuote(ctx context.Context, dc *downstreamConn, params []string) error { + if len(params) != 1 && len(params) != 2 { + return fmt.Errorf("expected one or two arguments") + } + + raw := params[len(params)-1] + params = params[:len(params)-1] + + net, params, err := getNetworkFromArg(dc, params) + if err != nil { + return err + } + + uc := net.conn + if uc == nil { + return fmt.Errorf("network %q is not currently connected", net.GetName()) + } + + m, err := irc.ParseMessage(raw) + if err != nil { + return fmt.Errorf("failed to parse command %q: %v", raw, err) + } + uc.SendMessage(ctx, m) + + sendServicePRIVMSG(dc, fmt.Sprintf("sent command to %q", net.GetName())) + return nil +} + +func sendCertfpFingerprints(dc *downstreamConn, cert []byte) { + sha1Sum := sha1.Sum(cert) + sendServicePRIVMSG(dc, "SHA-1 fingerprint: "+hex.EncodeToString(sha1Sum[:])) + sha256Sum := sha256.Sum256(cert) + sendServicePRIVMSG(dc, "SHA-256 fingerprint: "+hex.EncodeToString(sha256Sum[:])) + sha512Sum := sha512.Sum512(cert) + sendServicePRIVMSG(dc, "SHA-512 fingerprint: "+hex.EncodeToString(sha512Sum[:])) +} + +func getNetworkFromFlag(dc *downstreamConn, name string) (*network, error) { + if name == "" { + if dc.network == nil { + return nil, fmt.Errorf("no network selected, -network is required") + } + return dc.network, nil + } else { + net := dc.user.getNetwork(name) + if net == nil { + return nil, fmt.Errorf("unknown network %q", name) + } + return net, nil + } +} + +func handleServiceCertFPGenerate(ctx context.Context, dc *downstreamConn, params []string) error { + fs := newFlagSet() + netName := fs.String("network", "", "select a network") + keyType := fs.String("key-type", "rsa", "key type to generate (rsa, ecdsa, ed25519)") + bits := fs.Int("bits", 3072, "size of key to generate, meaningful only for RSA") + + if err := fs.Parse(params); err != nil { + return err + } + + if *bits <= 0 || *bits > maxRSABits { + return fmt.Errorf("invalid value for -bits") + } + + net, err := getNetworkFromFlag(dc, *netName) + if err != nil { + return err + } + + privKey, cert, err := generateCertFP(*keyType, *bits) + if err != nil { + return err + } + + net.SASL.External.CertBlob = cert + net.SASL.External.PrivKeyBlob = privKey + net.SASL.Mechanism = "EXTERNAL" + + if err := dc.srv.db.StoreNetwork(ctx, dc.user.ID, &net.Network); err != nil { + return err + } + + sendServicePRIVMSG(dc, "certificate generated") + sendCertfpFingerprints(dc, cert) + return nil +} + +func handleServiceCertFPFingerprints(ctx context.Context, dc *downstreamConn, params []string) error { + fs := newFlagSet() + netName := fs.String("network", "", "select a network") + + if err := fs.Parse(params); err != nil { + return err + } + + net, err := getNetworkFromFlag(dc, *netName) + if err != nil { + return err + } + + if net.SASL.Mechanism != "EXTERNAL" { + return fmt.Errorf("CertFP not set up") + } + + sendCertfpFingerprints(dc, net.SASL.External.CertBlob) + return nil +} + +func handleServiceSASLStatus(ctx context.Context, dc *downstreamConn, params []string) error { + fs := newFlagSet() + netName := fs.String("network", "", "select a network") + + if err := fs.Parse(params); err != nil { + return err + } + + net, err := getNetworkFromFlag(dc, *netName) + if err != nil { + return err + } + + switch net.SASL.Mechanism { + case "PLAIN": + sendServicePRIVMSG(dc, fmt.Sprintf("SASL PLAIN enabled with username %q", net.SASL.Plain.Username)) + case "EXTERNAL": + sendServicePRIVMSG(dc, "SASL EXTERNAL (CertFP) enabled") + case "": + sendServicePRIVMSG(dc, "SASL is disabled") + } + + if uc := net.conn; uc != nil { + if uc.account != "" { + sendServicePRIVMSG(dc, fmt.Sprintf("Authenticated on upstream network with account %q", uc.account)) + } else { + sendServicePRIVMSG(dc, "Unauthenticated on upstream network") + } + } else { + sendServicePRIVMSG(dc, "Disconnected from upstream network") + } + + return nil +} + +func handleServiceSASLSetPlain(ctx context.Context, dc *downstreamConn, params []string) error { + fs := newFlagSet() + netName := fs.String("network", "", "select a network") + + if err := fs.Parse(params); err != nil { + return err + } + + if len(fs.Args()) != 2 { + return fmt.Errorf("expected exactly 2 arguments") + } + + net, err := getNetworkFromFlag(dc, *netName) + if err != nil { + return err + } + + net.SASL.Plain.Username = fs.Arg(0) + net.SASL.Plain.Password = fs.Arg(1) + net.SASL.Mechanism = "PLAIN" + + if err := dc.srv.db.StoreNetwork(ctx, dc.user.ID, &net.Network); err != nil { + return err + } + + sendServicePRIVMSG(dc, "credentials saved") + return nil +} + +func handleServiceSASLReset(ctx context.Context, dc *downstreamConn, params []string) error { + fs := newFlagSet() + netName := fs.String("network", "", "select a network") + + if err := fs.Parse(params); err != nil { + return err + } + + net, err := getNetworkFromFlag(dc, *netName) + if err != nil { + return err + } + + net.SASL.Plain.Username = "" + net.SASL.Plain.Password = "" + net.SASL.External.CertBlob = nil + net.SASL.External.PrivKeyBlob = nil + net.SASL.Mechanism = "" + + if err := dc.srv.db.StoreNetwork(ctx, dc.user.ID, &net.Network); err != nil { + return err + } + + sendServicePRIVMSG(dc, "credentials reset") + return nil +} + +func handleUserCreate(ctx context.Context, dc *downstreamConn, params []string) error { + fs := newFlagSet() + username := fs.String("username", "", "") + password := fs.String("password", "", "") + realname := fs.String("realname", "", "") + admin := fs.Bool("admin", false, "") + + if err := fs.Parse(params); err != nil { + return err + } + if *username == "" { + return fmt.Errorf("flag -username is required") + } + if *password == "" { + return fmt.Errorf("flag -password is required") + } + + hashed, err := bcrypt.GenerateFromPassword([]byte(*password), bcrypt.DefaultCost) + if err != nil { + return fmt.Errorf("failed to hash password: %v", err) + } + + user := &User{ + Username: *username, + Password: string(hashed), + Realname: *realname, + Admin: *admin, + } + if _, err := dc.srv.createUser(ctx, user); err != nil { + return fmt.Errorf("could not create user: %v", err) + } + + sendServicePRIVMSG(dc, fmt.Sprintf("created user %q", *username)) + return nil +} + +func popArg(params []string) (string, []string) { + if len(params) > 0 && !strings.HasPrefix(params[0], "-") { + return params[0], params[1:] + } + return "", params +} + +func handleUserUpdate(ctx context.Context, dc *downstreamConn, params []string) error { + var password, realname *string + var admin *bool + fs := newFlagSet() + fs.Var(stringPtrFlag{&password}, "password", "") + fs.Var(stringPtrFlag{&realname}, "realname", "") + fs.Var(boolPtrFlag{&admin}, "admin", "") + + username, params := popArg(params) + if err := fs.Parse(params); err != nil { + return err + } + if len(fs.Args()) > 0 { + return fmt.Errorf("unexpected argument") + } + + var hashed *string + if password != nil { + hashedBytes, err := bcrypt.GenerateFromPassword([]byte(*password), bcrypt.DefaultCost) + if err != nil { + return fmt.Errorf("failed to hash password: %v", err) + } + hashedStr := string(hashedBytes) + hashed = &hashedStr + } + + if username != "" && username != dc.user.Username { + if !dc.user.Admin { + return fmt.Errorf("you must be an admin to update other users") + } + if realname != nil { + return fmt.Errorf("cannot update -realname of other user") + } + + u := dc.srv.getUser(username) + if u == nil { + return fmt.Errorf("unknown username %q", username) + } + + done := make(chan error, 1) + event := eventUserUpdate{ + password: hashed, + admin: admin, + done: done, + } + select { + case <-ctx.Done(): + return ctx.Err() + case u.events <- event: + } + // TODO: send context to the other side + if err := <-done; err != nil { + return err + } + + sendServicePRIVMSG(dc, fmt.Sprintf("updated user %q", username)) + } else { + // copy the user record because we'll mutate it + record := dc.user.User + + if hashed != nil { + record.Password = *hashed + } + if realname != nil { + record.Realname = *realname + } + if admin != nil { + return fmt.Errorf("cannot update -admin of own user") + } + + if err := dc.user.updateUser(ctx, &record); err != nil { + return err + } + + sendServicePRIVMSG(dc, fmt.Sprintf("updated user %q", dc.user.Username)) + } + + return nil +} + +func handleUserDelete(ctx context.Context, dc *downstreamConn, params []string) error { + if len(params) != 1 { + return fmt.Errorf("expected exactly one argument") + } + username := params[0] + + u := dc.srv.getUser(username) + if u == nil { + return fmt.Errorf("unknown username %q", username) + } + + u.stop() + + if err := dc.srv.db.DeleteUser(ctx, u.ID); err != nil { + return fmt.Errorf("failed to delete user: %v", err) + } + + sendServicePRIVMSG(dc, fmt.Sprintf("deleted user %q", username)) + return nil +} + +func handleServiceChannelStatus(ctx context.Context, dc *downstreamConn, params []string) error { + var defaultNetworkName string + if dc.network != nil { + defaultNetworkName = dc.network.GetName() + } + + fs := newFlagSet() + networkName := fs.String("network", defaultNetworkName, "") + + if err := fs.Parse(params); err != nil { + return err + } + + n := 0 + + sendNetwork := func(net *network) { + var channels []*Channel + for _, entry := range net.channels.innerMap { + channels = append(channels, entry.value.(*Channel)) + } + + sort.Slice(channels, func(i, j int) bool { + return strings.ReplaceAll(channels[i].Name, "#", "") < + strings.ReplaceAll(channels[j].Name, "#", "") + }) + + for _, ch := range channels { + var uch *upstreamChannel + if net.conn != nil { + uch = net.conn.channels.Value(ch.Name) + } + + name := ch.Name + if *networkName == "" { + name += "/" + net.GetName() + } + + var status string + if uch != nil { + status = "joined" + } else if net.conn != nil { + status = "parted" + } else { + status = "disconnected" + } + + if ch.Detached { + status += ", detached" + } + + s := fmt.Sprintf("%v [%v]", name, status) + sendServicePRIVMSG(dc, s) + + n++ + } + } + + if *networkName == "" { + for _, net := range dc.user.networks { + sendNetwork(net) + } + } else { + net := dc.user.getNetwork(*networkName) + if net == nil { + return fmt.Errorf("unknown network %q", *networkName) + } + sendNetwork(net) + } + + if n == 0 { + sendServicePRIVMSG(dc, "No channel configured.") + } + + return nil +} + +type channelFlagSet struct { + *flag.FlagSet + RelayDetached, ReattachOn, DetachAfter, DetachOn *string +} + +func newChannelFlagSet() *channelFlagSet { + fs := &channelFlagSet{FlagSet: newFlagSet()} + fs.Var(stringPtrFlag{&fs.RelayDetached}, "relay-detached", "") + fs.Var(stringPtrFlag{&fs.ReattachOn}, "reattach-on", "") + fs.Var(stringPtrFlag{&fs.DetachAfter}, "detach-after", "") + fs.Var(stringPtrFlag{&fs.DetachOn}, "detach-on", "") + return fs +} + +func (fs *channelFlagSet) update(channel *Channel) error { + if fs.RelayDetached != nil { + filter, err := parseFilter(*fs.RelayDetached) + if err != nil { + return err + } + channel.RelayDetached = filter + } + if fs.ReattachOn != nil { + filter, err := parseFilter(*fs.ReattachOn) + if err != nil { + return err + } + channel.ReattachOn = filter + } + if fs.DetachAfter != nil { + dur, err := time.ParseDuration(*fs.DetachAfter) + if err != nil || dur < 0 { + return fmt.Errorf("unknown duration for -detach-after %q (duration format: 0, 300s, 22h30m, ...)", *fs.DetachAfter) + } + channel.DetachAfter = dur + } + if fs.DetachOn != nil { + filter, err := parseFilter(*fs.DetachOn) + if err != nil { + return err + } + channel.DetachOn = filter + } + return nil +} + +func handleServiceChannelUpdate(ctx context.Context, dc *downstreamConn, params []string) error { + if len(params) < 1 { + return fmt.Errorf("expected at least one argument") + } + name := params[0] + + fs := newChannelFlagSet() + if err := fs.Parse(params[1:]); err != nil { + return err + } + + uc, upstreamName, err := dc.unmarshalEntity(name) + if err != nil { + return fmt.Errorf("unknown channel %q", name) + } + + ch := uc.network.channels.Value(upstreamName) + if ch == nil { + return fmt.Errorf("unknown channel %q", name) + } + + if err := fs.update(ch); err != nil { + return err + } + + uc.updateChannelAutoDetach(upstreamName) + + if err := dc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil { + return fmt.Errorf("failed to update channel: %v", err) + } + + sendServicePRIVMSG(dc, fmt.Sprintf("updated channel %q", name)) + return nil +} +func handleServiceServerStatus(ctx context.Context, dc *downstreamConn, params []string) error { + dbStats, err := dc.user.srv.db.Stats(ctx) + if err != nil { + return err + } + serverStats := dc.user.srv.Stats() + sendServicePRIVMSG(dc, fmt.Sprintf("%v/%v users, %v downstreams, %v upstreams, %v networks, %v channels", serverStats.Users, dbStats.Users, serverStats.Downstreams, serverStats.Upstreams, dbStats.Networks, dbStats.Channels)) + return nil +} + +func handleServiceServerNotice(ctx context.Context, dc *downstreamConn, params []string) error { + if len(params) != 1 { + return fmt.Errorf("expected exactly one argument") + } + text := params[0] + + dc.logger.Printf("broadcasting bouncer-wide NOTICE: %v", text) + + broadcastMsg := &irc.Message{ + Prefix: servicePrefix, + Command: "NOTICE", + Params: []string{"$" + dc.srv.Config().Hostname, text}, + } + var err error + sent := 0 + total := 0 + dc.srv.forEachUser(func(u *user) { + total++ + select { + case <-ctx.Done(): + err = ctx.Err() + case u.events <- eventBroadcast{broadcastMsg}: + sent++ + } + }) + + dc.logger.Printf("broadcast bouncer-wide NOTICE to %v/%v downstreams", sent, total) + sendServicePRIVMSG(dc, fmt.Sprintf("sent to %v/%v downstream connections", sent, total)) + + return err +} diff --git a/branches/master/service_test.go b/branches/master/service_test.go new file mode 100644 index 0000000..1f3fe6b --- /dev/null +++ b/branches/master/service_test.go @@ -0,0 +1,54 @@ +package suika + +import ( + "testing" +) + +func assertSplit(t *testing.T, input string, expected []string) { + actual, err := splitWords(input) + if err != nil { + t.Errorf("%q: %v", input, err) + return + } + if len(actual) != len(expected) { + t.Errorf("%q: expected %d words, got %d\nexpected: %v\ngot: %v", input, len(expected), len(actual), expected, actual) + return + } + for i := 0; i < len(actual); i++ { + if actual[i] != expected[i] { + t.Errorf("%q: expected word #%d to be %q, got %q\nexpected: %v\ngot: %v", input, i, expected[i], actual[i], expected, actual) + } + } +} + +func TestSplit(t *testing.T) { + assertSplit(t, " ch 'up' #suika 'relay'-det\"ache\"d message ", []string{ + "ch", + "up", + "#suika", + "relay-detached", + "message", + }) + assertSplit(t, "net update \\\"free\\\"node -pass 'political \"stance\" desu!' -realname '' -nick lee", []string{ + "net", + "update", + "\"free\"node", + "-pass", + "political \"stance\" desu!", + "-realname", + "", + "-nick", + "lee", + }) + assertSplit(t, "Omedeto,\\ Yui! ''", []string{ + "Omedeto, Yui!", + "", + }) + + if _, err := splitWords("end of 'file"); err == nil { + t.Errorf("expected error on unterminated single quote") + } + if _, err := splitWords("end of backquote \\"); err == nil { + t.Errorf("expected error on unterminated backquote sequence") + } +} diff --git a/branches/master/suika_psql_schema.sql b/branches/master/suika_psql_schema.sql new file mode 100644 index 0000000..d148b4d --- /dev/null +++ b/branches/master/suika_psql_schema.sql @@ -0,0 +1,58 @@ +CREATE TABLE IF NOT EXISTS "User" ( + id SERIAL PRIMARY KEY, + username VARCHAR(255) NOT NULL UNIQUE, + password VARCHAR(255), + admin BOOLEAN NOT NULL DEFAULT FALSE, + realname VARCHAR(255) +); + +CREATE TYPE sasl_mechanism AS ENUM ('PLAIN', 'EXTERNAL'); + +CREATE TABLE IF NOT EXISTS "Network" ( + id SERIAL PRIMARY KEY, + name VARCHAR(255), + "user" INTEGER NOT NULL REFERENCES "User"(id) ON DELETE CASCADE, + addr VARCHAR(255) NOT NULL, + nick VARCHAR(255), + username VARCHAR(255), + realname VARCHAR(255), + pass VARCHAR(255), + connect_commands VARCHAR(1023), + sasl_mechanism sasl_mechanism, + sasl_plain_username VARCHAR(255), + sasl_plain_password VARCHAR(255), + sasl_external_cert BYTEA, + sasl_external_key BYTEA, + enabled BOOLEAN NOT NULL DEFAULT TRUE, + UNIQUE("user", addr, nick), + UNIQUE("user", name) +); +CREATE TABLE IF NOT EXISTS "Channel" ( + id SERIAL PRIMARY KEY, + network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE, + name VARCHAR(255) NOT NULL, + key VARCHAR(255), + detached BOOLEAN NOT NULL DEFAULT FALSE, + detached_internal_msgid VARCHAR(255), + relay_detached INTEGER NOT NULL DEFAULT 0, + reattach_on INTEGER NOT NULL DEFAULT 0, + detach_after INTEGER NOT NULL DEFAULT 0, + detach_on INTEGER NOT NULL DEFAULT 0, + UNIQUE(network, name) +); +CREATE TABLE IF NOT EXISTS "DeliveryReceipt" ( + id SERIAL PRIMARY KEY, + network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE, + target VARCHAR(255) NOT NULL, + client VARCHAR(255) NOT NULL DEFAULT '', + internal_msgid VARCHAR(255) NOT NULL, + UNIQUE(network, target, client) +); +CREATE TABLE IF NOT EXISTS "ReadReceipt" ( + id SERIAL PRIMARY KEY, + network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE, + target VARCHAR(255) NOT NULL, + timestamp TIMESTAMP WITH TIME ZONE NOT NULL, + UNIQUE(network, target) +); + diff --git a/branches/master/suika_sqlite_schema.sql b/branches/master/suika_sqlite_schema.sql new file mode 100644 index 0000000..517d06f --- /dev/null +++ b/branches/master/suika_sqlite_schema.sql @@ -0,0 +1,61 @@ +CREATE TABLE IF NOT EXISTS User ( + id INTEGER PRIMARY KEY, + username TEXT NOT NULL UNIQUE, + password TEXT, + admin INTEGER NOT NULL DEFAULT 0, + realname TEXT +); +CREATE TABLE IF NOT EXISTS Network ( + id INTEGER PRIMARY KEY, + name TEXT, + user INTEGER NOT NULL, + addr TEXT NOT NULL, + nick TEXT, + username TEXT, + realname TEXT, + pass TEXT, + connect_commands TEXT, + sasl_mechanism TEXT, + sasl_plain_username TEXT, + sasl_plain_password TEXT, + sasl_external_cert BLOB, + sasl_external_key BLOB, + enabled INTEGER NOT NULL DEFAULT 1, + FOREIGN KEY(user) REFERENCES User(id), + UNIQUE(user, addr, nick), + UNIQUE(user, name) +); +CREATE TABLE IF NOT EXISTS Channel ( + id INTEGER PRIMARY KEY, + network INTEGER NOT NULL, + name TEXT NOT NULL, + key TEXT, + detached INTEGER NOT NULL DEFAULT 0, + detached_internal_msgid TEXT, + relay_detached INTEGER NOT NULL DEFAULT 0, + reattach_on INTEGER NOT NULL DEFAULT 0, + detach_after INTEGER NOT NULL DEFAULT 0, + detach_on INTEGER NOT NULL DEFAULT 0, + FOREIGN KEY(network) REFERENCES Network(id), + UNIQUE(network, name) +); + +CREATE TABLE IF NOT EXISTS DeliveryReceipt ( + id INTEGER PRIMARY KEY, + network INTEGER NOT NULL, + target TEXT NOT NULL, + client TEXT, + internal_msgid TEXT NOT NULL, + FOREIGN KEY(network) REFERENCES Network(id), + UNIQUE(network, target, client) +); + +CREATE TABLE IF NOT EXISTS ReadReceipt ( + id INTEGER PRIMARY KEY, + network INTEGER NOT NULL, + target TEXT NOT NULL, + timestamp TEXT NOT NULL, + FOREIGN KEY(network) REFERENCES Network(id), + UNIQUE(network, target) +); + diff --git a/branches/master/upstream.go b/branches/master/upstream.go new file mode 100644 index 0000000..285dca4 --- /dev/null +++ b/branches/master/upstream.go @@ -0,0 +1,2182 @@ +package suika + +import ( + "context" + "crypto" + "crypto/sha256" + "crypto/tls" + "crypto/x509" + "encoding/base64" + "errors" + "fmt" + "io" + "net" + "strconv" + "strings" + "time" + + "github.com/emersion/go-sasl" + "gopkg.in/irc.v3" +) + +// permanentUpstreamCaps is the static list of upstream capabilities always +// requested when supported. +var permanentUpstreamCaps = map[string]bool{ + "account-notify": true, + "account-tag": true, + "away-notify": true, + "batch": true, + "extended-join": true, + "invite-notify": true, + "labeled-response": true, + "message-tags": true, + "multi-prefix": true, + "sasl": true, + "server-time": true, + "setname": true, + + "draft/account-registration": true, + "draft/extended-monitor": true, +} + +type registrationError struct { + *irc.Message +} + +func (err registrationError) Error() string { + return fmt.Sprintf("registration error (%v): %v", err.Command, err.Reason()) +} + +func (err registrationError) Reason() string { + if len(err.Params) > 0 { + return err.Params[len(err.Params)-1] + } + return err.Command +} + +func (err registrationError) Temporary() bool { + // Only return false if we're 100% sure that fixing the error requires a + // network configuration change + switch err.Command { + case irc.ERR_PASSWDMISMATCH, irc.ERR_ERRONEUSNICKNAME: + return false + case "FAIL": + return err.Params[1] != "ACCOUNT_REQUIRED" + default: + return true + } +} + +type upstreamChannel struct { + Name string + conn *upstreamConn + Topic string + TopicWho *irc.Prefix + TopicTime time.Time + Status channelStatus + modes channelModes + creationTime string + Members membershipsCasemapMap + complete bool + detachTimer *time.Timer +} + +func (uc *upstreamChannel) updateAutoDetach(dur time.Duration) { + if uc.detachTimer != nil { + uc.detachTimer.Stop() + uc.detachTimer = nil + } + + if dur == 0 { + return + } + + uc.detachTimer = time.AfterFunc(dur, func() { + uc.conn.network.user.events <- eventChannelDetach{ + uc: uc.conn, + name: uc.Name, + } + }) +} + +type pendingUpstreamCommand struct { + downstreamID uint64 + msg *irc.Message +} + +type upstreamConn struct { + conn + + network *network + user *user + + serverName string + availableUserModes string + availableChannelModes map[byte]channelModeType + availableChannelTypes string + availableMemberships []membership + isupport map[string]*string + + registered bool + nick string + nickCM string + username string + realname string + modes userModes + channels upstreamChannelCasemapMap + supportedCaps map[string]string + caps map[string]bool + batches map[string]batch + away bool + account string + nextLabelID uint64 + monitored monitorCasemapMap + + saslClient sasl.Client + saslStarted bool + + casemapIsSet bool + + // Queue of commands in progress, indexed by type. The first entry has been + // sent to the server and is awaiting reply. The following entries have not + // been sent yet. + pendingCmds map[string][]pendingUpstreamCommand + + gotMotd bool +} + +func connectToUpstream(ctx context.Context, network *network) (*upstreamConn, error) { + logger := &prefixLogger{network.user.logger, fmt.Sprintf("upstream %q: ", network.GetName())} + + dialer := net.Dialer{Timeout: connectTimeout} + + u, err := network.URL() + if err != nil { + return nil, err + } + + var netConn net.Conn + switch u.Scheme { + case "ircs": + addr := u.Host + host, _, err := net.SplitHostPort(u.Host) + if err != nil { + host = u.Host + addr = u.Host + ":6697" + } + + dialer.LocalAddr, err = network.user.localTCPAddrForHost(ctx, host) + if err != nil { + return nil, fmt.Errorf("failed to pick local IP for remote host %q: %v", host, err) + } + + logger.Printf("connecting to TLS server at address %q", addr) + + tlsConfig := &tls.Config{ServerName: host, NextProtos: []string{"irc"}} + if network.SASL.Mechanism == "EXTERNAL" { + if network.SASL.External.CertBlob == nil { + return nil, fmt.Errorf("missing certificate for authentication") + } + if network.SASL.External.PrivKeyBlob == nil { + return nil, fmt.Errorf("missing private key for authentication") + } + key, err := x509.ParsePKCS8PrivateKey(network.SASL.External.PrivKeyBlob) + if err != nil { + return nil, fmt.Errorf("failed to parse private key: %v", err) + } + tlsConfig.Certificates = []tls.Certificate{ + { + Certificate: [][]byte{network.SASL.External.CertBlob}, + PrivateKey: key.(crypto.PrivateKey), + }, + } + logger.Printf("using TLS client certificate %x", sha256.Sum256(network.SASL.External.CertBlob)) + } + + netConn, err = dialer.DialContext(ctx, "tcp", addr) + if err != nil { + return nil, fmt.Errorf("failed to dial %q: %v", addr, err) + } + + // Don't do the TLS handshake immediately, because we need to register + // the new connection with identd ASAP. + netConn = tls.Client(netConn, tlsConfig) + case "irc": + addr := u.Host + host, _, err := net.SplitHostPort(addr) + if err != nil { + host = u.Host + addr = u.Host + ":6667" + } + + dialer.LocalAddr, err = network.user.localTCPAddrForHost(ctx, host) + if err != nil { + return nil, fmt.Errorf("failed to pick local IP for remote host %q: %v", host, err) + } + + logger.Printf("connecting to plain-text server at address %q", addr) + netConn, err = dialer.DialContext(ctx, "tcp", addr) + if err != nil { + return nil, fmt.Errorf("failed to dial %q: %v", addr, err) + } + case "irc+unix", "unix": + logger.Printf("connecting to Unix socket at path %q", u.Path) + netConn, err = dialer.DialContext(ctx, "unix", u.Path) + if err != nil { + return nil, fmt.Errorf("failed to connect to Unix socket %q: %v", u.Path, err) + } + default: + return nil, fmt.Errorf("failed to dial %q: unknown scheme: %v", network.Addr, u.Scheme) + } + + options := connOptions{ + Logger: logger, + RateLimitDelay: upstreamMessageDelay, + RateLimitBurst: upstreamMessageBurst, + } + + uc := &upstreamConn{ + conn: *newConn(network.user.srv, newNetIRCConn(netConn), &options), + network: network, + user: network.user, + channels: upstreamChannelCasemapMap{newCasemapMap(0)}, + supportedCaps: make(map[string]string), + caps: make(map[string]bool), + batches: make(map[string]batch), + availableChannelTypes: stdChannelTypes, + availableChannelModes: stdChannelModes, + availableMemberships: stdMemberships, + isupport: make(map[string]*string), + pendingCmds: make(map[string][]pendingUpstreamCommand), + monitored: monitorCasemapMap{newCasemapMap(0)}, + } + return uc, nil +} + +func (uc *upstreamConn) forEachDownstream(f func(*downstreamConn)) { + uc.network.forEachDownstream(f) +} + +func (uc *upstreamConn) forEachDownstreamByID(id uint64, f func(*downstreamConn)) { + uc.forEachDownstream(func(dc *downstreamConn) { + if id != 0 && id != dc.id { + return + } + f(dc) + }) +} + +func (uc *upstreamConn) downstreamByID(id uint64) *downstreamConn { + for _, dc := range uc.user.downstreamConns { + if dc.id == id { + return dc + } + } + return nil +} + +func (uc *upstreamConn) getChannel(name string) (*upstreamChannel, error) { + ch := uc.channels.Value(name) + if ch == nil { + return nil, fmt.Errorf("unknown channel %q", name) + } + return ch, nil +} + +func (uc *upstreamConn) isChannel(entity string) bool { + return strings.ContainsRune(uc.availableChannelTypes, rune(entity[0])) +} + +func (uc *upstreamConn) isOurNick(nick string) bool { + return uc.nickCM == uc.network.casemap(nick) +} + +func (uc *upstreamConn) abortPendingCommands() { + for _, l := range uc.pendingCmds { + for _, pendingCmd := range l { + dc := uc.downstreamByID(pendingCmd.downstreamID) + if dc == nil { + continue + } + + switch pendingCmd.msg.Command { + case "LIST": + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_LISTEND, + Params: []string{dc.nick, "Command aborted"}, + }) + case "WHO": + mask := "*" + if len(pendingCmd.msg.Params) > 0 { + mask = pendingCmd.msg.Params[0] + } + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_ENDOFWHO, + Params: []string{dc.nick, mask, "Command aborted"}, + }) + case "AUTHENTICATE": + dc.endSASL(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_SASLABORTED, + Params: []string{dc.nick, "SASL authentication aborted"}, + }) + case "REGISTER", "VERIFY": + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "FAIL", + Params: []string{pendingCmd.msg.Command, "TEMPORARILY_UNAVAILABLE", pendingCmd.msg.Params[0], "Command aborted"}, + }) + default: + panic(fmt.Errorf("Unsupported pending command %q", pendingCmd.msg.Command)) + } + } + } + + uc.pendingCmds = make(map[string][]pendingUpstreamCommand) +} + +func (uc *upstreamConn) sendNextPendingCommand(cmd string) { + if len(uc.pendingCmds[cmd]) == 0 { + return + } + uc.SendMessage(context.TODO(), uc.pendingCmds[cmd][0].msg) +} + +func (uc *upstreamConn) enqueueCommand(dc *downstreamConn, msg *irc.Message) { + switch msg.Command { + case "LIST", "WHO", "AUTHENTICATE", "REGISTER", "VERIFY": + // Supported + default: + panic(fmt.Errorf("Unsupported pending command %q", msg.Command)) + } + + uc.pendingCmds[msg.Command] = append(uc.pendingCmds[msg.Command], pendingUpstreamCommand{ + downstreamID: dc.id, + msg: msg, + }) + + if len(uc.pendingCmds[msg.Command]) == 1 { + uc.sendNextPendingCommand(msg.Command) + } +} + +func (uc *upstreamConn) currentPendingCommand(cmd string) (*downstreamConn, *irc.Message) { + if len(uc.pendingCmds[cmd]) == 0 { + return nil, nil + } + + pendingCmd := uc.pendingCmds[cmd][0] + return uc.downstreamByID(pendingCmd.downstreamID), pendingCmd.msg +} + +func (uc *upstreamConn) dequeueCommand(cmd string) (*downstreamConn, *irc.Message) { + dc, msg := uc.currentPendingCommand(cmd) + + if len(uc.pendingCmds[cmd]) > 0 { + copy(uc.pendingCmds[cmd], uc.pendingCmds[cmd][1:]) + uc.pendingCmds[cmd] = uc.pendingCmds[cmd][:len(uc.pendingCmds[cmd])-1] + } + + uc.sendNextPendingCommand(cmd) + + return dc, msg +} + +func (uc *upstreamConn) cancelPendingCommandsByDownstreamID(downstreamID uint64) { + for cmd := range uc.pendingCmds { + // We can't cancel the currently running command stored in + // uc.pendingCmds[cmd][0] + for i := len(uc.pendingCmds[cmd]) - 1; i >= 1; i-- { + if uc.pendingCmds[cmd][i].downstreamID == downstreamID { + uc.pendingCmds[cmd] = append(uc.pendingCmds[cmd][:i], uc.pendingCmds[cmd][i+1:]...) + } + } + } +} + +func (uc *upstreamConn) parseMembershipPrefix(s string) (ms *memberships, nick string) { + memberships := make(memberships, 0, 4) + i := 0 + for _, m := range uc.availableMemberships { + if i >= len(s) { + break + } + if s[i] == m.Prefix { + memberships = append(memberships, m) + i++ + } + } + return &memberships, s[i:] +} + +func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) error { + var label string + if l, ok := msg.GetTag("label"); ok { + label = l + delete(msg.Tags, "label") + } + + var msgBatch *batch + if batchName, ok := msg.GetTag("batch"); ok { + b, ok := uc.batches[batchName] + if !ok { + return fmt.Errorf("unexpected batch reference: batch was not defined: %q", batchName) + } + msgBatch = &b + if label == "" { + label = msgBatch.Label + } + delete(msg.Tags, "batch") + } + + var downstreamID uint64 = 0 + if label != "" { + var labelOffset uint64 + n, err := fmt.Sscanf(label, "sd-%d-%d", &downstreamID, &labelOffset) + if err == nil && n < 2 { + err = errors.New("not enough arguments") + } + if err != nil { + return fmt.Errorf("unexpected message label: invalid downstream reference for label %q: %v", label, err) + } + } + + if _, ok := msg.Tags["time"]; !ok { + msg.Tags["time"] = irc.TagValue(formatServerTime(time.Now())) + } + + switch msg.Command { + case "PING": + uc.SendMessage(ctx, &irc.Message{ + Command: "PONG", + Params: msg.Params, + }) + return nil + case "NOTICE", "PRIVMSG", "TAGMSG": + if msg.Prefix == nil { + return fmt.Errorf("expected a prefix") + } + + var entity, text string + if msg.Command != "TAGMSG" { + if err := parseMessageParams(msg, &entity, &text); err != nil { + return err + } + } else { + if err := parseMessageParams(msg, &entity); err != nil { + return err + } + } + + if msg.Prefix.Name == serviceNick { + uc.logger.Printf("skipping %v from suika's service: %v", msg.Command, msg) + break + } + if entity == serviceNick { + uc.logger.Printf("skipping %v to suika's service: %v", msg.Command, msg) + break + } + + if msg.Prefix.User == "" && msg.Prefix.Host == "" { // server message + uc.produce("", msg, nil) + } else { // regular user message + target := entity + if uc.isOurNick(target) { + target = msg.Prefix.Name + } + + ch := uc.network.channels.Value(target) + if ch != nil && msg.Command != "TAGMSG" { + if ch.Detached { + uc.handleDetachedMessage(ctx, ch, msg) + } + + highlight := uc.network.isHighlight(msg) + if ch.DetachOn == FilterMessage || ch.DetachOn == FilterDefault || (ch.DetachOn == FilterHighlight && highlight) { + uc.updateChannelAutoDetach(target) + } + } + + uc.produce(target, msg, nil) + } + case "CAP": + var subCmd string + if err := parseMessageParams(msg, nil, &subCmd); err != nil { + return err + } + subCmd = strings.ToUpper(subCmd) + subParams := msg.Params[2:] + switch subCmd { + case "LS": + if len(subParams) < 1 { + return newNeedMoreParamsError(msg.Command) + } + caps := subParams[len(subParams)-1] + more := len(subParams) >= 2 && msg.Params[len(subParams)-2] == "*" + + uc.handleSupportedCaps(caps) + + if more { + break // wait to receive all capabilities + } + + uc.requestCaps() + + if uc.requestSASL() { + break // we'll send CAP END after authentication is completed + } + + uc.SendMessage(ctx, &irc.Message{ + Command: "CAP", + Params: []string{"END"}, + }) + case "ACK", "NAK": + if len(subParams) < 1 { + return newNeedMoreParamsError(msg.Command) + } + caps := strings.Fields(subParams[0]) + + for _, name := range caps { + if err := uc.handleCapAck(ctx, strings.ToLower(name), subCmd == "ACK"); err != nil { + return err + } + } + + if uc.registered { + uc.forEachDownstream(func(dc *downstreamConn) { + dc.updateSupportedCaps() + }) + } + case "NEW": + if len(subParams) < 1 { + return newNeedMoreParamsError(msg.Command) + } + uc.handleSupportedCaps(subParams[0]) + uc.requestCaps() + case "DEL": + if len(subParams) < 1 { + return newNeedMoreParamsError(msg.Command) + } + caps := strings.Fields(subParams[0]) + + for _, c := range caps { + delete(uc.supportedCaps, c) + delete(uc.caps, c) + } + + if uc.registered { + uc.forEachDownstream(func(dc *downstreamConn) { + dc.updateSupportedCaps() + }) + } + default: + uc.logger.Printf("unhandled message: %v", msg) + } + case "AUTHENTICATE": + if uc.saslClient == nil { + return fmt.Errorf("received unexpected AUTHENTICATE message") + } + + // TODO: if a challenge is 400 bytes long, buffer it + var challengeStr string + if err := parseMessageParams(msg, &challengeStr); err != nil { + uc.SendMessage(ctx, &irc.Message{ + Command: "AUTHENTICATE", + Params: []string{"*"}, + }) + return err + } + + var challenge []byte + if challengeStr != "+" { + var err error + challenge, err = base64.StdEncoding.DecodeString(challengeStr) + if err != nil { + uc.SendMessage(ctx, &irc.Message{ + Command: "AUTHENTICATE", + Params: []string{"*"}, + }) + return err + } + } + + var resp []byte + var err error + if !uc.saslStarted { + _, resp, err = uc.saslClient.Start() + uc.saslStarted = true + } else { + resp, err = uc.saslClient.Next(challenge) + } + if err != nil { + uc.SendMessage(ctx, &irc.Message{ + Command: "AUTHENTICATE", + Params: []string{"*"}, + }) + return err + } + + // <= instead of < because we need to send a final empty response if + // the last chunk is exactly 400 bytes long + for i := 0; i <= len(resp); i += maxSASLLength { + j := i + maxSASLLength + if j > len(resp) { + j = len(resp) + } + + chunk := resp[i:j] + + var respStr = "+" + if len(chunk) != 0 { + respStr = base64.StdEncoding.EncodeToString(chunk) + } + + uc.SendMessage(ctx, &irc.Message{ + Command: "AUTHENTICATE", + Params: []string{respStr}, + }) + } + case irc.RPL_LOGGEDIN: + if err := parseMessageParams(msg, nil, nil, &uc.account); err != nil { + return err + } + uc.logger.Printf("logged in with account %q", uc.account) + uc.forEachDownstream(func(dc *downstreamConn) { + dc.updateAccount() + }) + case irc.RPL_LOGGEDOUT: + uc.account = "" + uc.logger.Printf("logged out") + uc.forEachDownstream(func(dc *downstreamConn) { + dc.updateAccount() + }) + case irc.ERR_NICKLOCKED, irc.RPL_SASLSUCCESS, irc.ERR_SASLFAIL, irc.ERR_SASLTOOLONG, irc.ERR_SASLABORTED: + var info string + if err := parseMessageParams(msg, nil, &info); err != nil { + return err + } + switch msg.Command { + case irc.ERR_NICKLOCKED: + uc.logger.Printf("invalid nick used with SASL authentication: %v", info) + case irc.ERR_SASLFAIL: + uc.logger.Printf("SASL authentication failed: %v", info) + case irc.ERR_SASLTOOLONG: + uc.logger.Printf("SASL message too long: %v", info) + } + + uc.saslClient = nil + uc.saslStarted = false + + if dc, _ := uc.dequeueCommand("AUTHENTICATE"); dc != nil && dc.sasl != nil { + if msg.Command == irc.RPL_SASLSUCCESS { + uc.network.autoSaveSASLPlain(ctx, dc.sasl.plainUsername, dc.sasl.plainPassword) + } + + dc.endSASL(msg) + } + + if !uc.registered { + uc.SendMessage(ctx, &irc.Message{ + Command: "CAP", + Params: []string{"END"}, + }) + } + case "REGISTER", "VERIFY": + if dc, cmd := uc.dequeueCommand(msg.Command); dc != nil { + if msg.Command == "REGISTER" { + var account, password string + if err := parseMessageParams(msg, nil, &account); err != nil { + return err + } + if err := parseMessageParams(cmd, nil, nil, &password); err != nil { + return err + } + uc.network.autoSaveSASLPlain(ctx, account, password) + } + + dc.SendMessage(msg) + } + case irc.RPL_WELCOME: + if err := parseMessageParams(msg, &uc.nick); err != nil { + return err + } + + uc.registered = true + uc.nickCM = uc.network.casemap(uc.nick) + uc.logger.Printf("connection registered with nick %q", uc.nick) + + if uc.network.channels.Len() > 0 { + var channels, keys []string + for _, entry := range uc.network.channels.innerMap { + ch := entry.value.(*Channel) + channels = append(channels, ch.Name) + keys = append(keys, ch.Key) + } + + for _, msg := range join(channels, keys) { + uc.SendMessage(ctx, msg) + } + } + case irc.RPL_MYINFO: + if err := parseMessageParams(msg, nil, &uc.serverName, nil, &uc.availableUserModes, nil); err != nil { + return err + } + case irc.RPL_ISUPPORT: + if err := parseMessageParams(msg, nil, nil); err != nil { + return err + } + + var downstreamIsupport []string + for _, token := range msg.Params[1 : len(msg.Params)-1] { + parameter := token + var negate, hasValue bool + var value string + if strings.HasPrefix(token, "-") { + negate = true + token = token[1:] + } else if i := strings.IndexByte(token, '='); i >= 0 { + parameter = token[:i] + value = token[i+1:] + hasValue = true + } + + if hasValue { + uc.isupport[parameter] = &value + } else if !negate { + uc.isupport[parameter] = nil + } else { + delete(uc.isupport, parameter) + } + + var err error + switch parameter { + case "CASEMAPPING": + casemap, ok := parseCasemappingToken(value) + if !ok { + casemap = casemapRFC1459 + } + uc.network.updateCasemapping(casemap) + uc.nickCM = uc.network.casemap(uc.nick) + uc.casemapIsSet = true + case "CHANMODES": + if !negate { + err = uc.handleChanModes(value) + } else { + uc.availableChannelModes = stdChannelModes + } + case "CHANTYPES": + if !negate { + uc.availableChannelTypes = value + } else { + uc.availableChannelTypes = stdChannelTypes + } + case "PREFIX": + if !negate { + err = uc.handleMemberships(value) + } else { + uc.availableMemberships = stdMemberships + } + } + if err != nil { + return err + } + + if passthroughIsupport[parameter] { + downstreamIsupport = append(downstreamIsupport, token) + } + } + + uc.updateMonitor() + + uc.forEachDownstream(func(dc *downstreamConn) { + if dc.network == nil { + return + } + msgs := generateIsupport(dc.srv.prefix(), dc.nick, downstreamIsupport) + for _, msg := range msgs { + dc.SendMessage(msg) + } + }) + case irc.ERR_NOMOTD, irc.RPL_ENDOFMOTD: + if !uc.casemapIsSet { + // upstream did not send any CASEMAPPING token, thus + // we assume it implements the old RFCs with rfc1459. + uc.casemapIsSet = true + uc.network.updateCasemapping(casemapRFC1459) + uc.nickCM = uc.network.casemap(uc.nick) + } + + if !uc.gotMotd { + // Ignore the initial MOTD upon connection, but forward + // subsequent MOTD messages downstream + uc.gotMotd = true + return nil + } + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + dc.SendMessage(&irc.Message{ + Prefix: uc.srv.prefix(), + Command: msg.Command, + Params: msg.Params, + }) + }) + case "BATCH": + var tag string + if err := parseMessageParams(msg, &tag); err != nil { + return err + } + + if strings.HasPrefix(tag, "+") { + tag = tag[1:] + if _, ok := uc.batches[tag]; ok { + return fmt.Errorf("unexpected BATCH reference tag: batch was already defined: %q", tag) + } + var batchType string + if err := parseMessageParams(msg, nil, &batchType); err != nil { + return err + } + label := label + if label == "" && msgBatch != nil { + label = msgBatch.Label + } + uc.batches[tag] = batch{ + Type: batchType, + Params: msg.Params[2:], + Outer: msgBatch, + Label: label, + } + } else if strings.HasPrefix(tag, "-") { + tag = tag[1:] + if _, ok := uc.batches[tag]; !ok { + return fmt.Errorf("unknown BATCH reference tag: %q", tag) + } + delete(uc.batches, tag) + } else { + return fmt.Errorf("unexpected BATCH reference tag: missing +/- prefix: %q", tag) + } + case "NICK": + if msg.Prefix == nil { + return fmt.Errorf("expected a prefix") + } + + var newNick string + if err := parseMessageParams(msg, &newNick); err != nil { + return err + } + + me := false + if uc.isOurNick(msg.Prefix.Name) { + uc.logger.Printf("changed nick from %q to %q", uc.nick, newNick) + me = true + uc.nick = newNick + uc.nickCM = uc.network.casemap(uc.nick) + } + + for _, entry := range uc.channels.innerMap { + ch := entry.value.(*upstreamChannel) + memberships := ch.Members.Value(msg.Prefix.Name) + if memberships != nil { + ch.Members.Delete(msg.Prefix.Name) + ch.Members.SetValue(newNick, memberships) + uc.appendLog(ch.Name, msg) + } + } + + if !me { + uc.forEachDownstream(func(dc *downstreamConn) { + dc.SendMessage(dc.marshalMessage(msg, uc.network)) + }) + } else { + uc.forEachDownstream(func(dc *downstreamConn) { + dc.updateNick() + }) + uc.updateMonitor() + } + case "SETNAME": + if msg.Prefix == nil { + return fmt.Errorf("expected a prefix") + } + + var newRealname string + if err := parseMessageParams(msg, &newRealname); err != nil { + return err + } + + // TODO: consider appending this message to logs + + if uc.isOurNick(msg.Prefix.Name) { + uc.logger.Printf("changed realname from %q to %q", uc.realname, newRealname) + uc.realname = newRealname + + uc.forEachDownstream(func(dc *downstreamConn) { + dc.updateRealname() + }) + } else { + uc.forEachDownstream(func(dc *downstreamConn) { + dc.SendMessage(dc.marshalMessage(msg, uc.network)) + }) + } + case "JOIN": + if msg.Prefix == nil { + return fmt.Errorf("expected a prefix") + } + + var channels string + if err := parseMessageParams(msg, &channels); err != nil { + return err + } + + for _, ch := range strings.Split(channels, ",") { + if uc.isOurNick(msg.Prefix.Name) { + uc.logger.Printf("joined channel %q", ch) + members := membershipsCasemapMap{newCasemapMap(0)} + members.casemap = uc.network.casemap + uc.channels.SetValue(ch, &upstreamChannel{ + Name: ch, + conn: uc, + Members: members, + }) + uc.updateChannelAutoDetach(ch) + + uc.SendMessage(ctx, &irc.Message{ + Command: "MODE", + Params: []string{ch}, + }) + } else { + ch, err := uc.getChannel(ch) + if err != nil { + return err + } + ch.Members.SetValue(msg.Prefix.Name, &memberships{}) + } + + chMsg := msg.Copy() + chMsg.Params[0] = ch + uc.produce(ch, chMsg, nil) + } + case "PART": + if msg.Prefix == nil { + return fmt.Errorf("expected a prefix") + } + + var channels string + if err := parseMessageParams(msg, &channels); err != nil { + return err + } + + for _, ch := range strings.Split(channels, ",") { + if uc.isOurNick(msg.Prefix.Name) { + uc.logger.Printf("parted channel %q", ch) + uch := uc.channels.Value(ch) + if uch != nil { + uc.channels.Delete(ch) + uch.updateAutoDetach(0) + } + } else { + ch, err := uc.getChannel(ch) + if err != nil { + return err + } + ch.Members.Delete(msg.Prefix.Name) + } + + chMsg := msg.Copy() + chMsg.Params[0] = ch + uc.produce(ch, chMsg, nil) + } + case "KICK": + if msg.Prefix == nil { + return fmt.Errorf("expected a prefix") + } + + var channel, user string + if err := parseMessageParams(msg, &channel, &user); err != nil { + return err + } + + if uc.isOurNick(user) { + uc.logger.Printf("kicked from channel %q by %s", channel, msg.Prefix.Name) + uc.channels.Delete(channel) + } else { + ch, err := uc.getChannel(channel) + if err != nil { + return err + } + ch.Members.Delete(user) + } + + uc.produce(channel, msg, nil) + case "QUIT": + if msg.Prefix == nil { + return fmt.Errorf("expected a prefix") + } + + if uc.isOurNick(msg.Prefix.Name) { + uc.logger.Printf("quit") + } + + for _, entry := range uc.channels.innerMap { + ch := entry.value.(*upstreamChannel) + if ch.Members.Has(msg.Prefix.Name) { + ch.Members.Delete(msg.Prefix.Name) + + uc.appendLog(ch.Name, msg) + } + } + + if msg.Prefix.Name != uc.nick { + uc.forEachDownstream(func(dc *downstreamConn) { + dc.SendMessage(dc.marshalMessage(msg, uc.network)) + }) + } + case irc.RPL_TOPIC, irc.RPL_NOTOPIC: + var name, topic string + if err := parseMessageParams(msg, nil, &name, &topic); err != nil { + return err + } + ch, err := uc.getChannel(name) + if err != nil { + return err + } + if msg.Command == irc.RPL_TOPIC { + ch.Topic = topic + } else { + ch.Topic = "" + } + case "TOPIC": + if msg.Prefix == nil { + return fmt.Errorf("expected a prefix") + } + + var name string + if err := parseMessageParams(msg, &name); err != nil { + return err + } + ch, err := uc.getChannel(name) + if err != nil { + return err + } + if len(msg.Params) > 1 { + ch.Topic = msg.Params[1] + ch.TopicWho = msg.Prefix.Copy() + ch.TopicTime = time.Now() // TODO use msg.Tags["time"] + } else { + ch.Topic = "" + } + uc.produce(ch.Name, msg, nil) + case "MODE": + var name, modeStr string + if err := parseMessageParams(msg, &name, &modeStr); err != nil { + return err + } + + if !uc.isChannel(name) { // user mode change + if name != uc.nick { + return fmt.Errorf("received MODE message for unknown nick %q", name) + } + + if err := uc.modes.Apply(modeStr); err != nil { + return err + } + + uc.forEachDownstream(func(dc *downstreamConn) { + if dc.upstream() == nil { + return + } + + dc.SendMessage(msg) + }) + } else { // channel mode change + ch, err := uc.getChannel(name) + if err != nil { + return err + } + + needMarshaling, err := applyChannelModes(ch, modeStr, msg.Params[2:]) + if err != nil { + return err + } + + uc.appendLog(ch.Name, msg) + + c := uc.network.channels.Value(name) + if c == nil || !c.Detached { + uc.forEachDownstream(func(dc *downstreamConn) { + params := make([]string, len(msg.Params)) + params[0] = dc.marshalEntity(uc.network, name) + params[1] = modeStr + + copy(params[2:], msg.Params[2:]) + for i, modeParam := range params[2:] { + if _, ok := needMarshaling[i]; ok { + params[2+i] = dc.marshalEntity(uc.network, modeParam) + } + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.marshalUserPrefix(uc.network, msg.Prefix), + Command: "MODE", + Params: params, + }) + }) + } + } + case irc.RPL_UMODEIS: + if err := parseMessageParams(msg, nil); err != nil { + return err + } + modeStr := "" + if len(msg.Params) > 1 { + modeStr = msg.Params[1] + } + + uc.modes = "" + if err := uc.modes.Apply(modeStr); err != nil { + return err + } + + uc.forEachDownstream(func(dc *downstreamConn) { + if dc.upstream() == nil { + return + } + + dc.SendMessage(msg) + }) + case irc.RPL_CHANNELMODEIS: + var channel string + if err := parseMessageParams(msg, nil, &channel); err != nil { + return err + } + modeStr := "" + if len(msg.Params) > 2 { + modeStr = msg.Params[2] + } + + ch, err := uc.getChannel(channel) + if err != nil { + return err + } + + firstMode := ch.modes == nil + ch.modes = make(map[byte]string) + if _, err := applyChannelModes(ch, modeStr, msg.Params[3:]); err != nil { + return err + } + + c := uc.network.channels.Value(channel) + if firstMode && (c == nil || !c.Detached) { + modeStr, modeParams := ch.modes.Format() + + uc.forEachDownstream(func(dc *downstreamConn) { + params := []string{dc.nick, dc.marshalEntity(uc.network, channel), modeStr} + params = append(params, modeParams...) + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_CHANNELMODEIS, + Params: params, + }) + }) + } + case rpl_creationtime: + var channel, creationTime string + if err := parseMessageParams(msg, nil, &channel, &creationTime); err != nil { + return err + } + + ch, err := uc.getChannel(channel) + if err != nil { + return err + } + + firstCreationTime := ch.creationTime == "" + ch.creationTime = creationTime + + c := uc.network.channels.Value(channel) + if firstCreationTime && (c == nil || !c.Detached) { + uc.forEachDownstream(func(dc *downstreamConn) { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: rpl_creationtime, + Params: []string{dc.nick, dc.marshalEntity(uc.network, ch.Name), creationTime}, + }) + }) + } + case rpl_topicwhotime: + var channel, who, timeStr string + if err := parseMessageParams(msg, nil, &channel, &who, &timeStr); err != nil { + return err + } + + ch, err := uc.getChannel(channel) + if err != nil { + return err + } + + firstTopicWhoTime := ch.TopicWho == nil + ch.TopicWho = irc.ParsePrefix(who) + sec, err := strconv.ParseInt(timeStr, 10, 64) + if err != nil { + return fmt.Errorf("failed to parse topic time: %v", err) + } + ch.TopicTime = time.Unix(sec, 0) + + c := uc.network.channels.Value(channel) + if firstTopicWhoTime && (c == nil || !c.Detached) { + uc.forEachDownstream(func(dc *downstreamConn) { + topicWho := dc.marshalUserPrefix(uc.network, ch.TopicWho) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: rpl_topicwhotime, + Params: []string{ + dc.nick, + dc.marshalEntity(uc.network, ch.Name), + topicWho.String(), + timeStr, + }, + }) + }) + } + case irc.RPL_LIST: + var channel, clients, topic string + if err := parseMessageParams(msg, nil, &channel, &clients, &topic); err != nil { + return err + } + + dc, cmd := uc.currentPendingCommand("LIST") + if cmd == nil { + return fmt.Errorf("unexpected RPL_LIST: no matching pending LIST") + } else if dc == nil { + return nil + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_LIST, + Params: []string{dc.nick, dc.marshalEntity(uc.network, channel), clients, topic}, + }) + case irc.RPL_LISTEND: + dc, cmd := uc.dequeueCommand("LIST") + if cmd == nil { + return fmt.Errorf("unexpected RPL_LISTEND: no matching pending LIST") + } else if dc == nil { + return nil + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_LISTEND, + Params: []string{dc.nick, "End of /LIST"}, + }) + case irc.RPL_NAMREPLY: + var name, statusStr, members string + if err := parseMessageParams(msg, nil, &statusStr, &name, &members); err != nil { + return err + } + + ch := uc.channels.Value(name) + if ch == nil { + // NAMES on a channel we have not joined, forward to downstream + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + channel := dc.marshalEntity(uc.network, name) + members := splitSpace(members) + for i, member := range members { + memberships, nick := uc.parseMembershipPrefix(member) + members[i] = memberships.Format(dc) + dc.marshalEntity(uc.network, nick) + } + memberStr := strings.Join(members, " ") + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_NAMREPLY, + Params: []string{dc.nick, statusStr, channel, memberStr}, + }) + }) + return nil + } + + status, err := parseChannelStatus(statusStr) + if err != nil { + return err + } + ch.Status = status + + for _, s := range splitSpace(members) { + memberships, nick := uc.parseMembershipPrefix(s) + ch.Members.SetValue(nick, memberships) + } + case irc.RPL_ENDOFNAMES: + var name string + if err := parseMessageParams(msg, nil, &name); err != nil { + return err + } + + ch := uc.channels.Value(name) + if ch == nil { + // NAMES on a channel we have not joined, forward to downstream + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + channel := dc.marshalEntity(uc.network, name) + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_ENDOFNAMES, + Params: []string{dc.nick, channel, "End of /NAMES list"}, + }) + }) + return nil + } + + if ch.complete { + return fmt.Errorf("received unexpected RPL_ENDOFNAMES") + } + ch.complete = true + + c := uc.network.channels.Value(name) + if c == nil || !c.Detached { + uc.forEachDownstream(func(dc *downstreamConn) { + forwardChannel(ctx, dc, ch) + }) + } + case irc.RPL_WHOREPLY: + var channel, username, host, server, nick, flags, trailing string + if err := parseMessageParams(msg, nil, &channel, &username, &host, &server, &nick, &flags, &trailing); err != nil { + return err + } + + dc, cmd := uc.currentPendingCommand("WHO") + if cmd == nil { + return fmt.Errorf("unexpected RPL_WHOREPLY: no matching pending WHO") + } else if dc == nil { + return nil + } + + if channel != "*" { + channel = dc.marshalEntity(uc.network, channel) + } + nick = dc.marshalEntity(uc.network, nick) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_WHOREPLY, + Params: []string{dc.nick, channel, username, host, server, nick, flags, trailing}, + }) + case rpl_whospcrpl: + dc, cmd := uc.currentPendingCommand("WHO") + if cmd == nil { + return fmt.Errorf("unexpected RPL_WHOSPCRPL: no matching pending WHO") + } else if dc == nil { + return nil + } + + // Only supported in single-upstream mode, so forward as-is + dc.SendMessage(msg) + case irc.RPL_ENDOFWHO: + var name string + if err := parseMessageParams(msg, nil, &name); err != nil { + return err + } + + dc, cmd := uc.dequeueCommand("WHO") + if cmd == nil { + return fmt.Errorf("unexpected RPL_ENDOFWHO: no matching pending WHO") + } else if dc == nil { + return nil + } + + mask := "*" + if len(cmd.Params) > 0 { + mask = cmd.Params[0] + } + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_ENDOFWHO, + Params: []string{dc.nick, mask, "End of /WHO list"}, + }) + case irc.RPL_WHOISUSER: + var nick, username, host, realname string + if err := parseMessageParams(msg, nil, &nick, &username, &host, nil, &realname); err != nil { + return err + } + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + nick := dc.marshalEntity(uc.network, nick) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_WHOISUSER, + Params: []string{dc.nick, nick, username, host, "*", realname}, + }) + }) + case irc.RPL_WHOISSERVER: + var nick, server, serverInfo string + if err := parseMessageParams(msg, nil, &nick, &server, &serverInfo); err != nil { + return err + } + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + nick := dc.marshalEntity(uc.network, nick) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_WHOISSERVER, + Params: []string{dc.nick, nick, server, serverInfo}, + }) + }) + case irc.RPL_WHOISOPERATOR: + var nick string + if err := parseMessageParams(msg, nil, &nick); err != nil { + return err + } + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + nick := dc.marshalEntity(uc.network, nick) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_WHOISOPERATOR, + Params: []string{dc.nick, nick, "is an IRC operator"}, + }) + }) + case irc.RPL_WHOISIDLE: + var nick string + if err := parseMessageParams(msg, nil, &nick, nil); err != nil { + return err + } + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + nick := dc.marshalEntity(uc.network, nick) + params := []string{dc.nick, nick} + params = append(params, msg.Params[2:]...) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_WHOISIDLE, + Params: params, + }) + }) + case irc.RPL_WHOISCHANNELS: + var nick, channelList string + if err := parseMessageParams(msg, nil, &nick, &channelList); err != nil { + return err + } + channels := splitSpace(channelList) + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + nick := dc.marshalEntity(uc.network, nick) + channelList := make([]string, len(channels)) + for i, channel := range channels { + prefix, channel := uc.parseMembershipPrefix(channel) + channel = dc.marshalEntity(uc.network, channel) + channelList[i] = prefix.Format(dc) + channel + } + channels := strings.Join(channelList, " ") + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_WHOISCHANNELS, + Params: []string{dc.nick, nick, channels}, + }) + }) + case irc.RPL_ENDOFWHOIS: + var nick string + if err := parseMessageParams(msg, nil, &nick); err != nil { + return err + } + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + nick := dc.marshalEntity(uc.network, nick) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_ENDOFWHOIS, + Params: []string{dc.nick, nick, "End of /WHOIS list"}, + }) + }) + case "INVITE": + var nick, channel string + if err := parseMessageParams(msg, &nick, &channel); err != nil { + return err + } + + weAreInvited := uc.isOurNick(nick) + + uc.forEachDownstream(func(dc *downstreamConn) { + if !weAreInvited && !dc.caps["invite-notify"] { + return + } + dc.SendMessage(&irc.Message{ + Prefix: dc.marshalUserPrefix(uc.network, msg.Prefix), + Command: "INVITE", + Params: []string{dc.marshalEntity(uc.network, nick), dc.marshalEntity(uc.network, channel)}, + }) + }) + case irc.RPL_INVITING: + var nick, channel string + if err := parseMessageParams(msg, nil, &nick, &channel); err != nil { + return err + } + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_INVITING, + Params: []string{dc.nick, dc.marshalEntity(uc.network, nick), dc.marshalEntity(uc.network, channel)}, + }) + }) + case irc.RPL_MONONLINE, irc.RPL_MONOFFLINE: + var targetsStr string + if err := parseMessageParams(msg, nil, &targetsStr); err != nil { + return err + } + targets := strings.Split(targetsStr, ",") + + online := msg.Command == irc.RPL_MONONLINE + for _, target := range targets { + prefix := irc.ParsePrefix(target) + uc.monitored.SetValue(prefix.Name, online) + } + + // Check if the nick we want is now free + wantNick := GetNick(&uc.user.User, &uc.network.Network) + wantNickCM := uc.network.casemap(wantNick) + if !online && uc.nickCM != wantNickCM { + found := false + for _, target := range targets { + prefix := irc.ParsePrefix(target) + if uc.network.casemap(prefix.Name) == wantNickCM { + found = true + break + } + } + if found { + uc.logger.Printf("desired nick %q is now available", wantNick) + uc.SendMessage(ctx, &irc.Message{ + Command: "NICK", + Params: []string{wantNick}, + }) + } + } + + uc.forEachDownstream(func(dc *downstreamConn) { + for _, target := range targets { + prefix := irc.ParsePrefix(target) + if dc.monitored.Has(prefix.Name) { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: msg.Command, + Params: []string{dc.nick, target}, + }) + } + } + }) + case irc.ERR_MONLISTFULL: + var limit, targetsStr string + if err := parseMessageParams(msg, nil, &limit, &targetsStr); err != nil { + return err + } + + targets := strings.Split(targetsStr, ",") + uc.forEachDownstream(func(dc *downstreamConn) { + for _, target := range targets { + if dc.monitored.Has(target) { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: msg.Command, + Params: []string{dc.nick, limit, target}, + }) + } + } + }) + case irc.RPL_AWAY: + var nick, reason string + if err := parseMessageParams(msg, nil, &nick, &reason); err != nil { + return err + } + + uc.forEachDownstream(func(dc *downstreamConn) { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_AWAY, + Params: []string{dc.nick, dc.marshalEntity(uc.network, nick), reason}, + }) + }) + case "AWAY", "ACCOUNT": + if msg.Prefix == nil { + return fmt.Errorf("expected a prefix") + } + + uc.forEachDownstream(func(dc *downstreamConn) { + dc.SendMessage(&irc.Message{ + Prefix: dc.marshalUserPrefix(uc.network, msg.Prefix), + Command: msg.Command, + Params: msg.Params, + }) + }) + case irc.RPL_BANLIST, irc.RPL_INVITELIST, irc.RPL_EXCEPTLIST: + var channel, mask string + if err := parseMessageParams(msg, nil, &channel, &mask); err != nil { + return err + } + var addNick, addTime string + if len(msg.Params) >= 5 { + addNick = msg.Params[3] + addTime = msg.Params[4] + } + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + channel := dc.marshalEntity(uc.network, channel) + + var params []string + if addNick != "" && addTime != "" { + addNick := dc.marshalEntity(uc.network, addNick) + params = []string{dc.nick, channel, mask, addNick, addTime} + } else { + params = []string{dc.nick, channel, mask} + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: msg.Command, + Params: params, + }) + }) + case irc.RPL_ENDOFBANLIST, irc.RPL_ENDOFINVITELIST, irc.RPL_ENDOFEXCEPTLIST: + var channel, trailing string + if err := parseMessageParams(msg, nil, &channel, &trailing); err != nil { + return err + } + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + upstreamChannel := dc.marshalEntity(uc.network, channel) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: msg.Command, + Params: []string{dc.nick, upstreamChannel, trailing}, + }) + }) + case irc.ERR_UNKNOWNCOMMAND, irc.RPL_TRYAGAIN: + var command, reason string + if err := parseMessageParams(msg, nil, &command, &reason); err != nil { + return err + } + + if dc, _ := uc.dequeueCommand(command); dc != nil && downstreamID == 0 { + downstreamID = dc.id + } + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + dc.SendMessage(&irc.Message{ + Prefix: uc.srv.prefix(), + Command: msg.Command, + Params: []string{dc.nick, command, reason}, + }) + }) + case "FAIL": + var command, code string + if err := parseMessageParams(msg, &command, &code); err != nil { + return err + } + + if !uc.registered && command == "*" && code == "ACCOUNT_REQUIRED" { + return registrationError{msg} + } + + if dc, _ := uc.dequeueCommand(command); dc != nil && downstreamID == 0 { + downstreamID = dc.id + } + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + dc.SendMessage(msg) + }) + case "ACK": + // Ignore + case irc.RPL_NOWAWAY, irc.RPL_UNAWAY: + // Ignore + case irc.RPL_YOURHOST, irc.RPL_CREATED: + // Ignore + case irc.RPL_LUSERCLIENT, irc.RPL_LUSEROP, irc.RPL_LUSERUNKNOWN, irc.RPL_LUSERCHANNELS, irc.RPL_LUSERME: + fallthrough + case irc.RPL_STATSVLINE, rpl_statsping, irc.RPL_STATSBLINE, irc.RPL_STATSDLINE: + fallthrough + case rpl_localusers, rpl_globalusers: + fallthrough + case irc.RPL_MOTDSTART, irc.RPL_MOTD: + // Ignore these messages if they're part of the initial registration + // message burst. Forward them if the user explicitly asked for them. + if !uc.gotMotd { + return nil + } + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + dc.SendMessage(&irc.Message{ + Prefix: uc.srv.prefix(), + Command: msg.Command, + Params: msg.Params, + }) + }) + case irc.RPL_LISTSTART: + // Ignore + case "ERROR": + var text string + if err := parseMessageParams(msg, &text); err != nil { + return err + } + return fmt.Errorf("fatal server error: %v", text) + case irc.ERR_NICKNAMEINUSE: + // At this point, we haven't received ISUPPORT so we don't know the + // maximum nickname length or whether the server supports MONITOR. Many + // servers have NICKLEN=30 so let's just use that. + if !uc.registered && len(uc.nick)+1 < 30 { + uc.nick = uc.nick + "_" + uc.nickCM = uc.network.casemap(uc.nick) + uc.logger.Printf("desired nick is not available, falling back to %q", uc.nick) + uc.SendMessage(ctx, &irc.Message{ + Command: "NICK", + Params: []string{uc.nick}, + }) + return nil + } + fallthrough + case irc.ERR_PASSWDMISMATCH, irc.ERR_ERRONEUSNICKNAME, irc.ERR_NICKCOLLISION, irc.ERR_UNAVAILRESOURCE, irc.ERR_NOPERMFORHOST, irc.ERR_YOUREBANNEDCREEP: + if !uc.registered { + return registrationError{msg} + } + fallthrough + default: + uc.logger.Printf("unhandled message: %v", msg) + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + // best effort marshaling for unknown messages, replies and errors: + // most numerics start with the user nick, marshal it if that's the case + // otherwise, conservately keep the params without marshaling + params := msg.Params + if _, err := strconv.Atoi(msg.Command); err == nil { // numeric + if len(msg.Params) > 0 && isOurNick(uc.network, msg.Params[0]) { + params[0] = dc.nick + } + } + dc.SendMessage(&irc.Message{ + Prefix: uc.srv.prefix(), + Command: msg.Command, + Params: params, + }) + }) + } + return nil +} + +func (uc *upstreamConn) handleDetachedMessage(ctx context.Context, ch *Channel, msg *irc.Message) { + if uc.network.detachedMessageNeedsRelay(ch, msg) { + uc.forEachDownstream(func(dc *downstreamConn) { + dc.relayDetachedMessage(uc.network, msg) + }) + } + if ch.ReattachOn == FilterMessage || (ch.ReattachOn == FilterHighlight && uc.network.isHighlight(msg)) { + uc.network.attach(ctx, ch) + if err := uc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil { + uc.logger.Printf("failed to update channel %q: %v", ch.Name, err) + } + } +} + +func (uc *upstreamConn) handleChanModes(s string) error { + parts := strings.SplitN(s, ",", 5) + if len(parts) < 4 { + return fmt.Errorf("malformed ISUPPORT CHANMODES value: %v", s) + } + modes := make(map[byte]channelModeType) + for i, mt := range []channelModeType{modeTypeA, modeTypeB, modeTypeC, modeTypeD} { + for j := 0; j < len(parts[i]); j++ { + mode := parts[i][j] + modes[mode] = mt + } + } + uc.availableChannelModes = modes + return nil +} + +func (uc *upstreamConn) handleMemberships(s string) error { + if s == "" { + uc.availableMemberships = nil + return nil + } + + if s[0] != '(' { + return fmt.Errorf("malformed ISUPPORT PREFIX value: %v", s) + } + sep := strings.IndexByte(s, ')') + if sep < 0 || len(s) != sep*2 { + return fmt.Errorf("malformed ISUPPORT PREFIX value: %v", s) + } + memberships := make([]membership, len(s)/2-1) + for i := range memberships { + memberships[i] = membership{ + Mode: s[i+1], + Prefix: s[sep+i+1], + } + } + uc.availableMemberships = memberships + return nil +} + +func (uc *upstreamConn) handleSupportedCaps(capsStr string) { + caps := strings.Fields(capsStr) + for _, s := range caps { + kv := strings.SplitN(s, "=", 2) + k := strings.ToLower(kv[0]) + var v string + if len(kv) == 2 { + v = kv[1] + } + uc.supportedCaps[k] = v + } +} + +func (uc *upstreamConn) requestCaps() { + var requestCaps []string + for c := range permanentUpstreamCaps { + if _, ok := uc.supportedCaps[c]; ok && !uc.caps[c] { + requestCaps = append(requestCaps, c) + } + } + + if len(requestCaps) == 0 { + return + } + + uc.SendMessage(context.TODO(), &irc.Message{ + Command: "CAP", + Params: []string{"REQ", strings.Join(requestCaps, " ")}, + }) +} + +func (uc *upstreamConn) supportsSASL(mech string) bool { + v, ok := uc.supportedCaps["sasl"] + if !ok { + return false + } + + if v == "" { + return true + } + + mechanisms := strings.Split(v, ",") + for _, mech := range mechanisms { + if strings.EqualFold(mech, mech) { + return true + } + } + return false +} + +func (uc *upstreamConn) requestSASL() bool { + if uc.network.SASL.Mechanism == "" { + return false + } + return uc.supportsSASL(uc.network.SASL.Mechanism) +} + +func (uc *upstreamConn) handleCapAck(ctx context.Context, name string, ok bool) error { + uc.caps[name] = ok + + switch name { + case "sasl": + if !uc.requestSASL() { + return nil + } + if !ok { + uc.logger.Printf("server refused to acknowledge the SASL capability") + return nil + } + + auth := &uc.network.SASL + switch auth.Mechanism { + case "PLAIN": + uc.logger.Printf("starting SASL PLAIN authentication with username %q", auth.Plain.Username) + uc.saslClient = sasl.NewPlainClient("", auth.Plain.Username, auth.Plain.Password) + case "EXTERNAL": + uc.logger.Printf("starting SASL EXTERNAL authentication") + uc.saslClient = sasl.NewExternalClient("") + default: + return fmt.Errorf("unsupported SASL mechanism %q", name) + } + + uc.SendMessage(ctx, &irc.Message{ + Command: "AUTHENTICATE", + Params: []string{auth.Mechanism}, + }) + default: + if permanentUpstreamCaps[name] { + break + } + uc.logger.Printf("received CAP ACK/NAK for a cap we don't support: %v", name) + } + return nil +} + +func splitSpace(s string) []string { + return strings.FieldsFunc(s, func(r rune) bool { + return r == ' ' + }) +} + +func (uc *upstreamConn) register(ctx context.Context) { + uc.nick = GetNick(&uc.user.User, &uc.network.Network) + uc.nickCM = uc.network.casemap(uc.nick) + uc.username = GetUsername(&uc.user.User, &uc.network.Network) + uc.realname = GetRealname(&uc.user.User, &uc.network.Network) + + uc.SendMessage(ctx, &irc.Message{ + Command: "CAP", + Params: []string{"LS", "302"}, + }) + + if uc.network.Pass != "" { + uc.SendMessage(ctx, &irc.Message{ + Command: "PASS", + Params: []string{uc.network.Pass}, + }) + } + + uc.SendMessage(ctx, &irc.Message{ + Command: "NICK", + Params: []string{uc.nick}, + }) + uc.SendMessage(ctx, &irc.Message{ + Command: "USER", + Params: []string{uc.username, "0", "*", uc.realname}, + }) +} + +func (uc *upstreamConn) ReadMessage() (*irc.Message, error) { + msg, err := uc.conn.ReadMessage() + if err != nil { + return nil, err + } + return msg, nil +} + +func (uc *upstreamConn) runUntilRegistered(ctx context.Context) error { + for !uc.registered { + msg, err := uc.ReadMessage() + if err != nil { + return fmt.Errorf("failed to read message: %v", err) + } + + if err := uc.handleMessage(ctx, msg); err != nil { + if _, ok := err.(registrationError); ok { + return err + } else { + msg.Tags = nil // prevent message tags from cluttering logs + return fmt.Errorf("failed to handle message %q: %v", msg, err) + } + } + } + + for _, command := range uc.network.ConnectCommands { + m, err := irc.ParseMessage(command) + if err != nil { + uc.logger.Printf("failed to parse connect command %q: %v", command, err) + } else { + uc.SendMessage(ctx, m) + } + } + + return nil +} + +func (uc *upstreamConn) readMessages(ch chan<- event) error { + for { + msg, err := uc.ReadMessage() + if errors.Is(err, io.EOF) { + break + } else if err != nil { + return fmt.Errorf("failed to read IRC command: %v", err) + } + + ch <- eventUpstreamMessage{msg, uc} + } + + return nil +} + +func (uc *upstreamConn) SendMessage(ctx context.Context, msg *irc.Message) { + if !uc.caps["message-tags"] { + msg = msg.Copy() + msg.Tags = nil + } + + uc.conn.SendMessage(ctx, msg) +} + +func (uc *upstreamConn) SendMessageLabeled(ctx context.Context, downstreamID uint64, msg *irc.Message) { + if uc.caps["labeled-response"] { + if msg.Tags == nil { + msg.Tags = make(map[string]irc.TagValue) + } + msg.Tags["label"] = irc.TagValue(fmt.Sprintf("sd-%d-%d", downstreamID, uc.nextLabelID)) + uc.nextLabelID++ + } + uc.SendMessage(ctx, msg) +} + +// appendLog appends a message to the log file. +// +// The internal message ID is returned. If the message isn't recorded in the +// log file, an empty string is returned. +func (uc *upstreamConn) appendLog(entity string, msg *irc.Message) (msgID string) { + if uc.user.msgStore == nil { + return "" + } + + // Don't store messages with a server mask target + if strings.HasPrefix(entity, "$") { + return "" + } + + entityCM := uc.network.casemap(entity) + if entityCM == "nickserv" { + // The messages sent/received from NickServ may contain + // security-related information (like passwords). Don't store these. + return "" + } + + if !uc.network.delivered.HasTarget(entity) { + // This is the first message we receive from this target. Save the last + // message ID in delivery receipts, so that we can send the new message + // in the backlog if an offline client reconnects. + lastID, err := uc.user.msgStore.LastMsgID(&uc.network.Network, entityCM, time.Now()) + if err != nil { + uc.logger.Printf("failed to log message: failed to get last message ID: %v", err) + return "" + } + + uc.network.delivered.ForEachClient(func(clientName string) { + uc.network.delivered.StoreID(entity, clientName, lastID) + }) + } + + msgID, err := uc.user.msgStore.Append(&uc.network.Network, entityCM, msg) + if err != nil { + uc.logger.Printf("failed to append message to store: %v", err) + return "" + } + + return msgID +} + +// produce appends a message to the logs and forwards it to connected downstream +// connections. +// +// If origin is not nil and origin doesn't support echo-message, the message is +// forwarded to all connections except origin. +func (uc *upstreamConn) produce(target string, msg *irc.Message, origin *downstreamConn) { + var msgID string + if target != "" { + msgID = uc.appendLog(target, msg) + } + + // Don't forward messages if it's a detached channel + ch := uc.network.channels.Value(target) + detached := ch != nil && ch.Detached + + uc.forEachDownstream(func(dc *downstreamConn) { + if !detached && (dc != origin || dc.caps["echo-message"]) { + dc.sendMessageWithID(dc.marshalMessage(msg, uc.network), msgID) + } else { + dc.advanceMessageWithID(msg, msgID) + } + }) +} + +func (uc *upstreamConn) updateAway() { + ctx := context.TODO() + + away := true + uc.forEachDownstream(func(*downstreamConn) { + away = false + }) + if away == uc.away { + return + } + if away { + uc.SendMessage(ctx, &irc.Message{ + Command: "AWAY", + Params: []string{"Auto away"}, + }) + } else { + uc.SendMessage(ctx, &irc.Message{ + Command: "AWAY", + }) + } + uc.away = away +} + +func (uc *upstreamConn) updateChannelAutoDetach(name string) { + uch := uc.channels.Value(name) + if uch == nil { + return + } + ch := uc.network.channels.Value(name) + if ch == nil || ch.Detached { + return + } + uch.updateAutoDetach(ch.DetachAfter) +} + +func (uc *upstreamConn) updateMonitor() { + if _, ok := uc.isupport["MONITOR"]; !ok { + return + } + + ctx := context.TODO() + + add := make(map[string]struct{}) + var addList []string + seen := make(map[string]struct{}) + uc.forEachDownstream(func(dc *downstreamConn) { + for targetCM := range dc.monitored.innerMap { + if !uc.monitored.Has(targetCM) { + if _, ok := add[targetCM]; !ok { + addList = append(addList, targetCM) + add[targetCM] = struct{}{} + } + } else { + seen[targetCM] = struct{}{} + } + } + }) + + wantNick := GetNick(&uc.user.User, &uc.network.Network) + wantNickCM := uc.network.casemap(wantNick) + if _, ok := add[wantNickCM]; !ok && !uc.monitored.Has(wantNick) && !uc.isOurNick(wantNick) { + addList = append(addList, wantNickCM) + add[wantNickCM] = struct{}{} + } + + removeAll := true + var removeList []string + for targetCM, entry := range uc.monitored.innerMap { + if _, ok := seen[targetCM]; ok { + removeAll = false + } else { + removeList = append(removeList, entry.originalKey) + } + } + + // TODO: better handle the case where len(uc.monitored) + len(addList) + // exceeds the limit, probably by immediately sending ERR_MONLISTFULL? + + if removeAll && len(addList) == 0 && len(removeList) > 0 { + // Optimization when the last MONITOR-aware downstream disconnects + uc.SendMessage(ctx, &irc.Message{ + Command: "MONITOR", + Params: []string{"C"}, + }) + } else { + msgs := generateMonitor("-", removeList) + msgs = append(msgs, generateMonitor("+", addList)...) + for _, msg := range msgs { + uc.SendMessage(ctx, msg) + } + } + + for _, target := range removeList { + uc.monitored.Delete(target) + } +} diff --git a/branches/master/user.go b/branches/master/user.go new file mode 100644 index 0000000..0b523ce --- /dev/null +++ b/branches/master/user.go @@ -0,0 +1,1068 @@ +package suika + +import ( + "context" + "crypto/sha256" + "encoding/binary" + "encoding/hex" + "fmt" + "math/big" + "net" + "sort" + "strings" + "time" + + "gopkg.in/irc.v3" +) + +type event interface{} + +type eventUpstreamMessage struct { + msg *irc.Message + uc *upstreamConn +} + +type eventUpstreamConnectionError struct { + net *network + err error +} + +type eventUpstreamConnected struct { + uc *upstreamConn +} + +type eventUpstreamDisconnected struct { + uc *upstreamConn +} + +type eventUpstreamError struct { + uc *upstreamConn + err error +} + +type eventDownstreamMessage struct { + msg *irc.Message + dc *downstreamConn +} + +type eventDownstreamConnected struct { + dc *downstreamConn +} + +type eventDownstreamDisconnected struct { + dc *downstreamConn +} + +type eventChannelDetach struct { + uc *upstreamConn + name string +} + +type eventBroadcast struct { + msg *irc.Message +} + +type eventStop struct{} + +type eventUserUpdate struct { + password *string + admin *bool + done chan error +} + +type deliveredClientMap map[string]string // client name -> msg ID + +type deliveredStore struct { + m deliveredCasemapMap +} + +func newDeliveredStore() deliveredStore { + return deliveredStore{deliveredCasemapMap{newCasemapMap(0)}} +} + +func (ds deliveredStore) HasTarget(target string) bool { + return ds.m.Value(target) != nil +} + +func (ds deliveredStore) LoadID(target, clientName string) string { + clients := ds.m.Value(target) + if clients == nil { + return "" + } + return clients[clientName] +} + +func (ds deliveredStore) StoreID(target, clientName, msgID string) { + clients := ds.m.Value(target) + if clients == nil { + clients = make(deliveredClientMap) + ds.m.SetValue(target, clients) + } + clients[clientName] = msgID +} + +func (ds deliveredStore) ForEachTarget(f func(target string)) { + for _, entry := range ds.m.innerMap { + f(entry.originalKey) + } +} + +func (ds deliveredStore) ForEachClient(f func(clientName string)) { + clients := make(map[string]struct{}) + for _, entry := range ds.m.innerMap { + delivered := entry.value.(deliveredClientMap) + for clientName := range delivered { + clients[clientName] = struct{}{} + } + } + + for clientName := range clients { + f(clientName) + } +} + +type network struct { + Network + user *user + logger Logger + stopped chan struct{} + + conn *upstreamConn + channels channelCasemapMap + delivered deliveredStore + lastError error + casemap casemapping +} + +func newNetwork(user *user, record *Network, channels []Channel) *network { + logger := &prefixLogger{user.logger, fmt.Sprintf("network %q: ", record.GetName())} + + m := channelCasemapMap{newCasemapMap(0)} + for _, ch := range channels { + ch := ch + m.SetValue(ch.Name, &ch) + } + + return &network{ + Network: *record, + user: user, + logger: logger, + stopped: make(chan struct{}), + channels: m, + delivered: newDeliveredStore(), + casemap: casemapRFC1459, + } +} + +func (net *network) forEachDownstream(f func(*downstreamConn)) { + net.user.forEachDownstream(func(dc *downstreamConn) { + if dc.network == nil && !dc.isMultiUpstream { + return + } + if dc.network != nil && dc.network != net { + return + } + f(dc) + }) +} + +func (net *network) isStopped() bool { + select { + case <-net.stopped: + return true + default: + return false + } +} + +func userIdent(u *User) string { + // The ident is a string we will send to upstream servers in clear-text. + // For privacy reasons, make sure it doesn't expose any meaningful user + // metadata. We just use the base64-encoded hashed ID, so that people don't + // start relying on the string being an integer or following a pattern. + var b [64]byte + binary.LittleEndian.PutUint64(b[:], uint64(u.ID)) + h := sha256.Sum256(b[:]) + return hex.EncodeToString(h[:16]) +} + +func (net *network) run() { + if !net.Enabled { + return + } + + var lastTry time.Time + backoff := newBackoffer(retryConnectMinDelay, retryConnectMaxDelay, retryConnectJitter) + for { + if net.isStopped() { + return + } + + delay := backoff.Next() - time.Now().Sub(lastTry) + if delay > 0 { + net.logger.Printf("waiting %v before trying to reconnect to %q", delay.Truncate(time.Second), net.Addr) + time.Sleep(delay) + } + lastTry = time.Now() + + + uc, err := connectToUpstream(context.TODO(), net) + if err != nil { + net.logger.Printf("failed to connect to upstream server %q: %v", net.Addr, err) + net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("failed to connect: %v", err)} + continue + } + + uc.register(context.TODO()) + if err := uc.runUntilRegistered(context.TODO()); err != nil { + text := err.Error() + temp := true + if regErr, ok := err.(registrationError); ok { + text = regErr.Reason() + temp = regErr.Temporary() + } + uc.logger.Printf("failed to register: %v", text) + net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("failed to register: %v", text)} + uc.Close() + if !temp { + return + } + continue + } + + // TODO: this is racy with net.stopped. If the network is stopped + // before the user goroutine receives eventUpstreamConnected, the + // connection won't be closed. + net.user.events <- eventUpstreamConnected{uc} + if err := uc.readMessages(net.user.events); err != nil { + uc.logger.Printf("failed to handle messages: %v", err) + net.user.events <- eventUpstreamError{uc, fmt.Errorf("failed to handle messages: %v", err)} + } + uc.Close() + net.user.events <- eventUpstreamDisconnected{uc} + + backoff.Reset() + } +} + +func (net *network) stop() { + if !net.isStopped() { + close(net.stopped) + } + + if net.conn != nil { + net.conn.Close() + } +} + +func (net *network) detach(ch *Channel) { + if ch.Detached { + return + } + + net.logger.Printf("detaching channel %q", ch.Name) + + ch.Detached = true + + if net.user.msgStore != nil { + nameCM := net.casemap(ch.Name) + lastID, err := net.user.msgStore.LastMsgID(&net.Network, nameCM, time.Now()) + if err != nil { + net.logger.Printf("failed to get last message ID for channel %q: %v", ch.Name, err) + } + ch.DetachedInternalMsgID = lastID + } + + if net.conn != nil { + uch := net.conn.channels.Value(ch.Name) + if uch != nil { + uch.updateAutoDetach(0) + } + } + + net.forEachDownstream(func(dc *downstreamConn) { + dc.SendMessage(&irc.Message{ + Prefix: dc.prefix(), + Command: "PART", + Params: []string{dc.marshalEntity(net, ch.Name), "Detach"}, + }) + }) +} + +func (net *network) attach(ctx context.Context, ch *Channel) { + if !ch.Detached { + return + } + + net.logger.Printf("attaching channel %q", ch.Name) + + detachedMsgID := ch.DetachedInternalMsgID + ch.Detached = false + ch.DetachedInternalMsgID = "" + + var uch *upstreamChannel + if net.conn != nil { + uch = net.conn.channels.Value(ch.Name) + + net.conn.updateChannelAutoDetach(ch.Name) + } + + net.forEachDownstream(func(dc *downstreamConn) { + dc.SendMessage(&irc.Message{ + Prefix: dc.prefix(), + Command: "JOIN", + Params: []string{dc.marshalEntity(net, ch.Name)}, + }) + + if uch != nil { + forwardChannel(ctx, dc, uch) + } + + if detachedMsgID != "" { + dc.sendTargetBacklog(ctx, net, ch.Name, detachedMsgID) + } + }) +} + +func (net *network) deleteChannel(ctx context.Context, name string) error { + ch := net.channels.Value(name) + if ch == nil { + return fmt.Errorf("unknown channel %q", name) + } + if net.conn != nil { + uch := net.conn.channels.Value(ch.Name) + if uch != nil { + uch.updateAutoDetach(0) + } + } + + if err := net.user.srv.db.DeleteChannel(ctx, ch.ID); err != nil { + return err + } + net.channels.Delete(name) + return nil +} + +func (net *network) updateCasemapping(newCasemap casemapping) { + net.casemap = newCasemap + net.channels.SetCasemapping(newCasemap) + net.delivered.m.SetCasemapping(newCasemap) + if uc := net.conn; uc != nil { + uc.channels.SetCasemapping(newCasemap) + for _, entry := range uc.channels.innerMap { + uch := entry.value.(*upstreamChannel) + uch.Members.SetCasemapping(newCasemap) + } + uc.monitored.SetCasemapping(newCasemap) + } + net.forEachDownstream(func(dc *downstreamConn) { + dc.monitored.SetCasemapping(newCasemap) + }) +} + +func (net *network) storeClientDeliveryReceipts(ctx context.Context, clientName string) { + if !net.user.hasPersistentMsgStore() { + return + } + + var receipts []DeliveryReceipt + net.delivered.ForEachTarget(func(target string) { + msgID := net.delivered.LoadID(target, clientName) + if msgID == "" { + return + } + receipts = append(receipts, DeliveryReceipt{ + Target: target, + InternalMsgID: msgID, + }) + }) + + if err := net.user.srv.db.StoreClientDeliveryReceipts(ctx, net.ID, clientName, receipts); err != nil { + net.logger.Printf("failed to store delivery receipts for client %q: %v", clientName, err) + } +} + +func (net *network) isHighlight(msg *irc.Message) bool { + if msg.Command != "PRIVMSG" && msg.Command != "NOTICE" { + return false + } + + text := msg.Params[1] + + nick := net.Nick + if net.conn != nil { + nick = net.conn.nick + } + + // TODO: use case-mapping aware comparison here + return msg.Prefix.Name != nick && isHighlight(text, nick) +} + +func (net *network) detachedMessageNeedsRelay(ch *Channel, msg *irc.Message) bool { + highlight := net.isHighlight(msg) + return ch.RelayDetached == FilterMessage || ((ch.RelayDetached == FilterHighlight || ch.RelayDetached == FilterDefault) && highlight) +} + +func (net *network) autoSaveSASLPlain(ctx context.Context, username, password string) { + // User may have e.g. EXTERNAL mechanism configured. We do not want to + // automatically erase the key pair or any other credentials. + if net.SASL.Mechanism != "" && net.SASL.Mechanism != "PLAIN" { + return + } + + net.logger.Printf("auto-saving SASL PLAIN credentials with username %q", username) + net.SASL.Mechanism = "PLAIN" + net.SASL.Plain.Username = username + net.SASL.Plain.Password = password + if err := net.user.srv.db.StoreNetwork(ctx, net.user.ID, &net.Network); err != nil { + net.logger.Printf("failed to save SASL PLAIN credentials: %v", err) + } +} + +type user struct { + User + srv *Server + logger Logger + + events chan event + done chan struct{} + + networks []*network + downstreamConns []*downstreamConn + msgStore messageStore +} + +func newUser(srv *Server, record *User) *user { + logger := &prefixLogger{srv.Logger, fmt.Sprintf("user %q: ", record.Username)} + + var msgStore messageStore + if logPath := srv.Config().LogPath; logPath != "" { + msgStore = newFSMessageStore(logPath, record) + } else { + msgStore = newMemoryMessageStore() + } + + return &user{ + User: *record, + srv: srv, + logger: logger, + events: make(chan event, 64), + done: make(chan struct{}), + msgStore: msgStore, + } +} + +func (u *user) forEachUpstream(f func(uc *upstreamConn)) { + for _, network := range u.networks { + if network.conn == nil { + continue + } + f(network.conn) + } +} + +func (u *user) forEachDownstream(f func(dc *downstreamConn)) { + for _, dc := range u.downstreamConns { + f(dc) + } +} + +func (u *user) getNetwork(name string) *network { + for _, network := range u.networks { + if network.Addr == name { + return network + } + if network.Name != "" && network.Name == name { + return network + } + } + return nil +} + +func (u *user) getNetworkByID(id int64) *network { + for _, net := range u.networks { + if net.ID == id { + return net + } + } + return nil +} + +func (u *user) run() { + defer func() { + if u.msgStore != nil { + if err := u.msgStore.Close(); err != nil { + u.logger.Printf("failed to close message store for user %q: %v", u.Username, err) + } + } + close(u.done) + }() + + networks, err := u.srv.db.ListNetworks(context.TODO(), u.ID) + if err != nil { + u.logger.Printf("failed to list networks for user %q: %v", u.Username, err) + return + } + + sort.Slice(networks, func(i, j int) bool { + return networks[i].ID < networks[j].ID + }) + + for _, record := range networks { + record := record + channels, err := u.srv.db.ListChannels(context.TODO(), record.ID) + if err != nil { + u.logger.Printf("failed to list channels for user %q, network %q: %v", u.Username, record.GetName(), err) + continue + } + + network := newNetwork(u, &record, channels) + u.networks = append(u.networks, network) + + if u.hasPersistentMsgStore() { + receipts, err := u.srv.db.ListDeliveryReceipts(context.TODO(), record.ID) + if err != nil { + u.logger.Printf("failed to load delivery receipts for user %q, network %q: %v", u.Username, network.GetName(), err) + return + } + + for _, rcpt := range receipts { + network.delivered.StoreID(rcpt.Target, rcpt.Client, rcpt.InternalMsgID) + } + } + + go network.run() + } + + for e := range u.events { + switch e := e.(type) { + case eventUpstreamConnected: + uc := e.uc + + uc.network.conn = uc + + uc.updateAway() + uc.updateMonitor() + + netIDStr := fmt.Sprintf("%v", uc.network.ID) + uc.forEachDownstream(func(dc *downstreamConn) { + dc.updateSupportedCaps() + + if !dc.caps["soju.im/bouncer-networks"] { + sendServiceNOTICE(dc, fmt.Sprintf("connected to %s", uc.network.GetName())) + } + + dc.updateNick() + dc.updateRealname() + dc.updateAccount() + }) + u.forEachDownstream(func(dc *downstreamConn) { + if dc.caps["soju.im/bouncer-networks-notify"] { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "BOUNCER", + Params: []string{"NETWORK", netIDStr, "state=connected"}, + }) + } + }) + uc.network.lastError = nil + case eventUpstreamDisconnected: + u.handleUpstreamDisconnected(e.uc) + case eventUpstreamConnectionError: + net := e.net + + stopped := false + select { + case <-net.stopped: + stopped = true + default: + } + + if !stopped && (net.lastError == nil || net.lastError.Error() != e.err.Error()) { + net.forEachDownstream(func(dc *downstreamConn) { + sendServiceNOTICE(dc, fmt.Sprintf("failed connecting/registering to %s: %v", net.GetName(), e.err)) + }) + } + net.lastError = e.err + case eventUpstreamError: + uc := e.uc + + uc.forEachDownstream(func(dc *downstreamConn) { + sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s: %v", uc.network.GetName(), e.err)) + }) + uc.network.lastError = e.err + case eventUpstreamMessage: + msg, uc := e.msg, e.uc + if uc.isClosed() { + uc.logger.Printf("ignoring message on closed connection: %v", msg) + break + } + if err := uc.handleMessage(context.TODO(), msg); err != nil { + uc.logger.Printf("failed to handle message %q: %v", msg, err) + } + case eventChannelDetach: + uc, name := e.uc, e.name + c := uc.network.channels.Value(name) + if c == nil || c.Detached { + continue + } + uc.network.detach(c) + if err := uc.srv.db.StoreChannel(context.TODO(), uc.network.ID, c); err != nil { + u.logger.Printf("failed to store updated detached channel %q: %v", c.Name, err) + } + case eventDownstreamConnected: + dc := e.dc + + if dc.network != nil { + dc.monitored.SetCasemapping(dc.network.casemap) + } + + if err := dc.welcome(context.TODO()); err != nil { + dc.logger.Printf("failed to handle new registered connection: %v", err) + break + } + + u.downstreamConns = append(u.downstreamConns, dc) + + dc.forEachNetwork(func(network *network) { + if network.lastError != nil { + sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s: %v", network.GetName(), network.lastError)) + } + }) + + u.forEachUpstream(func(uc *upstreamConn) { + uc.updateAway() + }) + case eventDownstreamDisconnected: + dc := e.dc + + for i := range u.downstreamConns { + if u.downstreamConns[i] == dc { + u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...) + break + } + } + + dc.forEachNetwork(func(net *network) { + net.storeClientDeliveryReceipts(context.TODO(), dc.clientName) + }) + + u.forEachUpstream(func(uc *upstreamConn) { + uc.cancelPendingCommandsByDownstreamID(dc.id) + uc.updateAway() + uc.updateMonitor() + }) + case eventDownstreamMessage: + msg, dc := e.msg, e.dc + if dc.isClosed() { + dc.logger.Printf("ignoring message on closed connection: %v", msg) + break + } + err := dc.handleMessage(context.TODO(), msg) + if ircErr, ok := err.(ircError); ok { + ircErr.Message.Prefix = dc.srv.prefix() + dc.SendMessage(ircErr.Message) + } else if err != nil { + dc.logger.Printf("failed to handle message %q: %v", msg, err) + dc.Close() + } + case eventBroadcast: + msg := e.msg + u.forEachDownstream(func(dc *downstreamConn) { + dc.SendMessage(msg) + }) + case eventUserUpdate: + // copy the user record because we'll mutate it + record := u.User + + if e.password != nil { + record.Password = *e.password + } + if e.admin != nil { + record.Admin = *e.admin + } + + e.done <- u.updateUser(context.TODO(), &record) + + // If the password was updated, kill all downstream connections to + // force them to re-authenticate with the new credentials. + if e.password != nil { + u.forEachDownstream(func(dc *downstreamConn) { + dc.Close() + }) + } + case eventStop: + u.forEachDownstream(func(dc *downstreamConn) { + dc.Close() + }) + for _, n := range u.networks { + n.stop() + + n.delivered.ForEachClient(func(clientName string) { + n.storeClientDeliveryReceipts(context.TODO(), clientName) + }) + } + return + default: + panic(fmt.Sprintf("received unknown event type: %T", e)) + } + } +} + +func (u *user) handleUpstreamDisconnected(uc *upstreamConn) { + uc.network.conn = nil + + uc.abortPendingCommands() + + for _, entry := range uc.channels.innerMap { + uch := entry.value.(*upstreamChannel) + uch.updateAutoDetach(0) + } + + netIDStr := fmt.Sprintf("%v", uc.network.ID) + uc.forEachDownstream(func(dc *downstreamConn) { + dc.updateSupportedCaps() + }) + + // If the network has been removed, don't send a state change notification + found := false + for _, net := range u.networks { + if net == uc.network { + found = true + break + } + } + if !found { + return + } + + u.forEachDownstream(func(dc *downstreamConn) { + if dc.caps["soju.im/bouncer-networks-notify"] { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "BOUNCER", + Params: []string{"NETWORK", netIDStr, "state=disconnected"}, + }) + } + }) + + if uc.network.lastError == nil { + uc.forEachDownstream(func(dc *downstreamConn) { + if !dc.caps["soju.im/bouncer-networks"] { + sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s", uc.network.GetName())) + } + }) + } +} + +func (u *user) addNetwork(network *network) { + u.networks = append(u.networks, network) + + sort.Slice(u.networks, func(i, j int) bool { + return u.networks[i].ID < u.networks[j].ID + }) + + go network.run() +} + +func (u *user) removeNetwork(network *network) { + network.stop() + + u.forEachDownstream(func(dc *downstreamConn) { + if dc.network != nil && dc.network == network { + dc.Close() + } + }) + + for i, net := range u.networks { + if net == network { + u.networks = append(u.networks[:i], u.networks[i+1:]...) + return + } + } + + panic("tried to remove a non-existing network") +} + +func (u *user) checkNetwork(record *Network) error { + url, err := record.URL() + if err != nil { + return err + } + if url.User != nil { + return fmt.Errorf("%v:// URL must not have username and password information", url.Scheme) + } + if url.RawQuery != "" { + return fmt.Errorf("%v:// URL must not have query values", url.Scheme) + } + if url.Fragment != "" { + return fmt.Errorf("%v:// URL must not have a fragment", url.Scheme) + } + switch url.Scheme { + case "ircs", "irc": + if url.Host == "" { + return fmt.Errorf("%v:// URL must have a host", url.Scheme) + } + if url.Path != "" { + return fmt.Errorf("%v:// URL must not have a path", url.Scheme) + } + case "irc+unix", "unix": + if url.Host != "" { + return fmt.Errorf("%v:// URL must not have a host", url.Scheme) + } + if url.Path == "" { + return fmt.Errorf("%v:// URL must have a path", url.Scheme) + } + default: + return fmt.Errorf("unknown URL scheme %q", url.Scheme) + } + + if record.GetName() == "" { + return fmt.Errorf("network name cannot be empty") + } + if strings.HasPrefix(record.GetName(), "-") { + // Can be mixed up with flags when sending commands to the service + return fmt.Errorf("network name cannot start with a dash character") + } + + for _, net := range u.networks { + if net.GetName() == record.GetName() && net.ID != record.ID { + return fmt.Errorf("a network with the name %q already exists", record.GetName()) + } + } + + return nil +} + +func (u *user) createNetwork(ctx context.Context, record *Network) (*network, error) { + if record.ID != 0 { + panic("tried creating an already-existing network") + } + + if err := u.checkNetwork(record); err != nil { + return nil, err + } + + if max := u.srv.Config().MaxUserNetworks; max >= 0 && len(u.networks) >= max { + return nil, fmt.Errorf("maximum number of networks reached") + } + + network := newNetwork(u, record, nil) + err := u.srv.db.StoreNetwork(ctx, u.ID, &network.Network) + if err != nil { + return nil, err + } + + u.addNetwork(network) + + idStr := fmt.Sprintf("%v", network.ID) + attrs := getNetworkAttrs(network) + u.forEachDownstream(func(dc *downstreamConn) { + if dc.caps["soju.im/bouncer-networks-notify"] { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "BOUNCER", + Params: []string{"NETWORK", idStr, attrs.String()}, + }) + } + }) + + return network, nil +} + +func (u *user) updateNetwork(ctx context.Context, record *Network) (*network, error) { + if record.ID == 0 { + panic("tried updating a new network") + } + + // If the realname is reset to the default, just wipe the per-network + // setting + if record.Realname == u.Realname { + record.Realname = "" + } + + if err := u.checkNetwork(record); err != nil { + return nil, err + } + + network := u.getNetworkByID(record.ID) + if network == nil { + panic("tried updating a non-existing network") + } + + if err := u.srv.db.StoreNetwork(ctx, u.ID, record); err != nil { + return nil, err + } + + // Most network changes require us to re-connect to the upstream server + + channels := make([]Channel, 0, network.channels.Len()) + for _, entry := range network.channels.innerMap { + ch := entry.value.(*Channel) + channels = append(channels, *ch) + } + + updatedNetwork := newNetwork(u, record, channels) + + // If we're currently connected, disconnect and perform the necessary + // bookkeeping + if network.conn != nil { + network.stop() + // Note: this will set network.conn to nil + u.handleUpstreamDisconnected(network.conn) + } + + // Patch downstream connections to use our fresh updated network + u.forEachDownstream(func(dc *downstreamConn) { + if dc.network != nil && dc.network == network { + dc.network = updatedNetwork + } + }) + + // We need to remove the network after patching downstream connections, + // otherwise they'll get closed + u.removeNetwork(network) + + // The filesystem message store needs to be notified whenever the network + // is renamed + fsMsgStore, isFS := u.msgStore.(*fsMessageStore) + if isFS && updatedNetwork.GetName() != network.GetName() { + if err := fsMsgStore.RenameNetwork(&network.Network, &updatedNetwork.Network); err != nil { + network.logger.Printf("failed to update FS message store network name to %q: %v", updatedNetwork.GetName(), err) + } + } + + // This will re-connect to the upstream server + u.addNetwork(updatedNetwork) + + // TODO: only broadcast attributes that have changed + idStr := fmt.Sprintf("%v", updatedNetwork.ID) + attrs := getNetworkAttrs(updatedNetwork) + u.forEachDownstream(func(dc *downstreamConn) { + if dc.caps["soju.im/bouncer-networks-notify"] { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "BOUNCER", + Params: []string{"NETWORK", idStr, attrs.String()}, + }) + } + }) + + return updatedNetwork, nil +} + +func (u *user) deleteNetwork(ctx context.Context, id int64) error { + network := u.getNetworkByID(id) + if network == nil { + panic("tried deleting a non-existing network") + } + + if err := u.srv.db.DeleteNetwork(ctx, network.ID); err != nil { + return err + } + + u.removeNetwork(network) + + idStr := fmt.Sprintf("%v", network.ID) + u.forEachDownstream(func(dc *downstreamConn) { + if dc.caps["soju.im/bouncer-networks-notify"] { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "BOUNCER", + Params: []string{"NETWORK", idStr, "*"}, + }) + } + }) + + return nil +} + +func (u *user) updateUser(ctx context.Context, record *User) error { + if u.ID != record.ID { + panic("ID mismatch when updating user") + } + + realnameUpdated := u.Realname != record.Realname + if err := u.srv.db.StoreUser(ctx, record); err != nil { + return fmt.Errorf("failed to update user %q: %v", u.Username, err) + } + u.User = *record + + if realnameUpdated { + // Re-connect to networks which use the default realname + var needUpdate []Network + for _, net := range u.networks { + if net.Realname == "" { + needUpdate = append(needUpdate, net.Network) + } + } + + var netErr error + for _, net := range needUpdate { + if _, err := u.updateNetwork(ctx, &net); err != nil { + netErr = err + } + } + if netErr != nil { + return netErr + } + } + + return nil +} + +func (u *user) stop() { + u.events <- eventStop{} + <-u.done +} + +func (u *user) hasPersistentMsgStore() bool { + if u.msgStore == nil { + return false + } + _, isMem := u.msgStore.(*memoryMessageStore) + return !isMem +} + +// localAddrForHost returns the local address to use when connecting to host. +// A nil address is returned when the OS should automatically pick one. +func (u *user) localTCPAddrForHost(ctx context.Context, host string) (*net.TCPAddr, error) { + upstreamUserIPs := u.srv.Config().UpstreamUserIPs + if len(upstreamUserIPs) == 0 { + return nil, nil + } + + ips, err := net.DefaultResolver.LookupIP(ctx, "ip", host) + if err != nil { + return nil, err + } + + wantIPv6 := false + for _, ip := range ips { + if ip.To4() == nil { + wantIPv6 = true + break + } + } + + var ipNet *net.IPNet + for _, in := range upstreamUserIPs { + if wantIPv6 == (in.IP.To4() == nil) { + ipNet = in + break + } + } + if ipNet == nil { + return nil, nil + } + + var ipInt big.Int + ipInt.SetBytes(ipNet.IP) + ipInt.Add(&ipInt, big.NewInt(u.ID+1)) + ip := net.IP(ipInt.Bytes()) + if !ipNet.Contains(ip) { + return nil, fmt.Errorf("IP network %v too small", ipNet) + } + + return &net.TCPAddr{IP: ip}, nil +} diff --git a/branches/master/version.go b/branches/master/version.go new file mode 100644 index 0000000..0e5d7a6 --- /dev/null +++ b/branches/master/version.go @@ -0,0 +1,50 @@ +package suika + +import ( + "fmt" + "runtime/debug" + "strings" +) + +const ( + defaultVersion = "0.0.0" + defaultCommit = "HEAD" + defaultBuild = "0000-01-01:00:00+00:00" +) + +var ( + // Version is the tagged release version in the form .. + // following semantic versioning and is overwritten by the build system. + Version = defaultVersion + + // Commit is the commit sha of the build (normally from Git) and is overwritten + // by the build system. + Commit = defaultCommit + + // Build is the date and time of the build as an RFC3339 formatted string + // and is overwritten by the build system. + Build = defaultBuild +) + +// FullVersion display the full version and build +func FullVersion() string { + var sb strings.Builder + + isDefault := Version == defaultVersion && Commit == defaultCommit && Build == defaultBuild + + if !isDefault { + sb.WriteString(fmt.Sprintf("%s@%s %s", Version, Commit, Build)) + } + + if info, ok := debug.ReadBuildInfo(); ok { + if isDefault { + sb.WriteString(fmt.Sprintf(" %s", info.Main.Version)) + } + sb.WriteString(fmt.Sprintf(" %s", info.GoVersion)) + if info.Main.Sum != "" { + sb.WriteString(fmt.Sprintf(" %s", info.Main.Sum)) + } + } + + return sb.String() +} diff --git a/branches/origin-master/.gitignore b/branches/origin-master/.gitignore new file mode 100644 index 0000000..4aa22f5 --- /dev/null +++ b/branches/origin-master/.gitignore @@ -0,0 +1,5 @@ +vendor +/suika +/suikadb +/suika-znc-import +/suika.db diff --git a/branches/origin-master/LICENSE b/branches/origin-master/LICENSE new file mode 100644 index 0000000..0ad25db --- /dev/null +++ b/branches/origin-master/LICENSE @@ -0,0 +1,661 @@ + GNU AFFERO GENERAL PUBLIC LICENSE + Version 3, 19 November 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU Affero General Public License is a free, copyleft license for +software and other kinds of works, specifically designed to ensure +cooperation with the community in the case of network server software. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +our General Public Licenses are intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + Developers that use our General Public Licenses protect your rights +with two steps: (1) assert copyright on the software, and (2) offer +you this License which gives you legal permission to copy, distribute +and/or modify the software. + + A secondary benefit of defending all users' freedom is that +improvements made in alternate versions of the program, if they +receive widespread use, become available for other developers to +incorporate. Many developers of free software are heartened and +encouraged by the resulting cooperation. However, in the case of +software used on network servers, this result may fail to come about. +The GNU General Public License permits making a modified version and +letting the public access it on a server without ever releasing its +source code to the public. + + The GNU Affero General Public License is designed specifically to +ensure that, in such cases, the modified source code becomes available +to the community. It requires the operator of a network server to +provide the source code of the modified version running there to the +users of that server. Therefore, public use of a modified version, on +a publicly accessible server, gives the public access to the source +code of the modified version. + + An older license, called the Affero General Public License and +published by Affero, was designed to accomplish similar goals. This is +a different license, not a version of the Affero GPL, but Affero has +released a new version of the Affero GPL which permits relicensing under +this license. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU Affero General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. + + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Remote Network Interaction; Use with the GNU General Public License. + + Notwithstanding any other provision of this License, if you modify the +Program, your modified version must prominently offer all users +interacting with it remotely through a computer network (if your version +supports such interaction) an opportunity to receive the Corresponding +Source of your version by providing access to the Corresponding Source +from a network server at no charge, through some standard or customary +means of facilitating copying of software. This Corresponding Source +shall include the Corresponding Source for any work covered by version 3 +of the GNU General Public License that is incorporated pursuant to the +following paragraph. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the work with which it is combined will remain governed by version +3 of the GNU General Public License. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU Affero General Public License from time to time. Such new versions +will be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU Affero General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU Affero General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU Affero General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + + Copyright (C) + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published + by the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . + +Also add information on how to contact you by electronic and paper mail. + + If your software can interact with users remotely through a computer +network, you should also make sure that it provides a way for users to +get its source. For example, if your program is a web application, its +interface could display a "Source" link that leads users to an archive +of the code. There are many ways you could offer source, and different +solutions will be better for different programs; see section 13 for the +specific requirements. + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU AGPL, see +. diff --git a/branches/origin-master/Makefile b/branches/origin-master/Makefile new file mode 100644 index 0000000..a77bc97 --- /dev/null +++ b/branches/origin-master/Makefile @@ -0,0 +1,45 @@ +GO ?= go +RM ?= rm +GOFLAGS ?= -v -ldflags "-w -X `go list`.Version=${VERSION} -X `go list`.Commit=${COMMIT} -X `go list`.Build=${BUILD}" -mod=vendor +PREFIX ?= /usr/local +BINDIR ?= bin +MANDIR ?= share/man +MKDIR ?= mkdir +CP ?= cp +SYSCONFDIR ?= /etc +ASCIIDOCTOR ?= asciidoctor + +VERSION = `git describe --abbrev=0 --tags 2>/dev/null || echo "$VERSION"` +COMMIT = `git rev-parse --short HEAD || echo "$COMMIT"` +BRANCH = `git rev-parse --abbrev-ref HEAD` +BUILD = `git show -s --pretty=format:%cI` + +GOARCH ?= amd64 +GOOS ?= linux + +all: build + +build: vendor + ${GO} build ${GOFLAGS} ./cmd/suika + ${GO} build ${GOFLAGS} ./cmd/suikadb + ${GO} build ${GOFLAGS} ./cmd/suika-znc-import +clean: + ${RM} -f suika suikadb suika-znc-import +install: + ${MKDIR} -p ${DESTDIR}${PREFIX}/${BINDIR} + ${MKDIR} -p ${DESTDIR}${PREFIX}/${MANDIR}/man1 + ${MKDIR} -p ${DESTDIR}${PREFIX}/${MANDIR}/man5 + ${MKDIR} -p ${DESTDIR}${PREFIX}/${MANDIR}/man7 + ${MKDIR} -p ${DESTDIR}${SYSCONFDIR}/suika + ${MKDIR} -p ${DESTDIR}/var/lib/suika + ${CP} -f suika suikadb suika-znc-import ${DESTDIR}${PREFIX}/${BINDIR} + ${CP} -f doc/suika.1 ${DESTDIR}${PREFIX}/${MANDIR}/man1 + ${CP} -f doc/suikadb.1 ${DESTDIR}${PREFIX}/${MANDIR}/man1 + ${CP} -f doc/suika-znc-import.1 ${DESTDIR}/${MANDIR}/man1 + ${CP} -f doc/suika-config.5 ${DESTDIR}${PREFIX}/${MANDIR}/man5 + [ -f ${DESTDIR}${SYSCONFDIR}/suika/config ] || ${CP} -f config.in ${DESTDIR}${SYSCONFDIR}/suika/config +test: + go test +vendor: + go mod vendor +.PHONY: build clean install diff --git a/branches/origin-master/README.md b/branches/origin-master/README.md new file mode 100644 index 0000000..2171ba5 --- /dev/null +++ b/branches/origin-master/README.md @@ -0,0 +1,33 @@ +# suika + +[![Go Documentation](https://godocs.io/marisa.chaotic.ninja/suika?status.svg)](https://godocs.io/marisa.chaotic.ninja/suika) + +A user-friendly IRC bouncer. Hard-fork of the 0.3 series of [soju](https://soju.im), named after [Suika Ibuki](https://en.touhouwiki.net/wiki/Suika_Ibuki) from [Touhou 7.5: Immaterial and Missing Power](https://en.touhouwiki.net/wiki/Immaterial_and_Missing_Power) + +- Multi-user +- Support multiple clients for a single user, with proper backlog + synchronization +- Support connecting to multiple upstream servers via a single IRC connection + to the bouncer + +## Building and installing + +Dependencies: + +- Go +- BSD or GNU make + +For end users, a `Makefile` is provided: + + make + doas make install + +For development, you can use `go run ./cmd/suika` as usual. + +## License +AGPLv3, see [LICENSE](LICENSE). + +* Copyright (C) 2020 The soju Contributors +* Copyright (C) 2023-present Izuru Yakumo + +The code for `version.go` is stolen verbatim from one of [@prologic](https://git.mills.io/prologic)'s projects. It's probably under MIT diff --git a/branches/origin-master/bridge.go b/branches/origin-master/bridge.go new file mode 100644 index 0000000..7bb7aa8 --- /dev/null +++ b/branches/origin-master/bridge.go @@ -0,0 +1,116 @@ +package suika + +import ( + "context" + "fmt" + "strconv" + "strings" + + "gopkg.in/irc.v3" +) + +func forwardChannel(ctx context.Context, dc *downstreamConn, ch *upstreamChannel) { + if !ch.complete { + panic("Tried to forward a partial channel") + } + + // RPL_NOTOPIC shouldn't be sent on JOIN + if ch.Topic != "" { + sendTopic(dc, ch) + } + + if dc.caps["soju.im/read"] { + channelCM := ch.conn.network.casemap(ch.Name) + r, err := dc.srv.db.GetReadReceipt(ctx, ch.conn.network.ID, channelCM) + if err != nil { + dc.logger.Printf("failed to get the read receipt for %q: %v", ch.Name, err) + } else { + timestampStr := "*" + if r != nil { + timestampStr = fmt.Sprintf("timestamp=%s", formatServerTime(r.Timestamp)) + } + dc.SendMessage(&irc.Message{ + Prefix: dc.prefix(), + Command: "READ", + Params: []string{dc.marshalEntity(ch.conn.network, ch.Name), timestampStr}, + }) + } + } + + sendNames(dc, ch) +} + +func sendTopic(dc *downstreamConn, ch *upstreamChannel) { + downstreamName := dc.marshalEntity(ch.conn.network, ch.Name) + + if ch.Topic != "" { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_TOPIC, + Params: []string{dc.nick, downstreamName, ch.Topic}, + }) + if ch.TopicWho != nil { + topicWho := dc.marshalUserPrefix(ch.conn.network, ch.TopicWho) + topicTime := strconv.FormatInt(ch.TopicTime.Unix(), 10) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: rpl_topicwhotime, + Params: []string{dc.nick, downstreamName, topicWho.String(), topicTime}, + }) + } + } else { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_NOTOPIC, + Params: []string{dc.nick, downstreamName, "No topic is set"}, + }) + } +} + +func sendNames(dc *downstreamConn, ch *upstreamChannel) { + downstreamName := dc.marshalEntity(ch.conn.network, ch.Name) + + emptyNameReply := &irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_NAMREPLY, + Params: []string{dc.nick, string(ch.Status), downstreamName, ""}, + } + maxLength := maxMessageLength - len(emptyNameReply.String()) + + var buf strings.Builder + for _, entry := range ch.Members.innerMap { + nick := entry.originalKey + memberships := entry.value.(*memberships) + s := memberships.Format(dc) + dc.marshalEntity(ch.conn.network, nick) + + n := buf.Len() + 1 + len(s) + if buf.Len() != 0 && n > maxLength { + // There's not enough space for the next space + nick. + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_NAMREPLY, + Params: []string{dc.nick, string(ch.Status), downstreamName, buf.String()}, + }) + buf.Reset() + } + + if buf.Len() != 0 { + buf.WriteByte(' ') + } + buf.WriteString(s) + } + + if buf.Len() != 0 { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_NAMREPLY, + Params: []string{dc.nick, string(ch.Status), downstreamName, buf.String()}, + }) + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_ENDOFNAMES, + Params: []string{dc.nick, downstreamName, "End of /NAMES list"}, + }) +} diff --git a/branches/origin-master/certfp.go b/branches/origin-master/certfp.go new file mode 100644 index 0000000..9180f7c --- /dev/null +++ b/branches/origin-master/certfp.go @@ -0,0 +1,72 @@ +package suika + +import ( + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "math/big" + "time" +) + +func generateCertFP(keyType string, bits int) (privKeyBytes, certBytes []byte, err error) { + var ( + privKey crypto.PrivateKey + pubKey crypto.PublicKey + ) + switch keyType { + case "rsa": + key, err := rsa.GenerateKey(rand.Reader, bits) + if err != nil { + return nil, nil, err + } + privKey = key + pubKey = key.Public() + case "ecdsa": + key, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + if err != nil { + return nil, nil, err + } + privKey = key + pubKey = key.Public() + case "ed25519": + var err error + pubKey, privKey, err = ed25519.GenerateKey(rand.Reader) + if err != nil { + return nil, nil, err + } + } + + // Using PKCS#8 allows easier extension for new key types. + privKeyBytes, err = x509.MarshalPKCS8PrivateKey(privKey) + if err != nil { + return nil, nil, err + } + + notBefore := time.Now() + // Lets make a fair assumption nobody will use the same cert for more than 20 years... + notAfter := notBefore.Add(24 * time.Hour * 365 * 20) + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + return nil, nil, err + } + cert := &x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{CommonName: "suika auto-generated certificate"}, + NotBefore: notBefore, + NotAfter: notAfter, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + } + certBytes, err = x509.CreateCertificate(rand.Reader, cert, cert, pubKey, privKey) + if err != nil { + return nil, nil, err + } + + return privKeyBytes, certBytes, nil +} diff --git a/branches/origin-master/cmd/suika-znc-import/main.go b/branches/origin-master/cmd/suika-znc-import/main.go new file mode 100644 index 0000000..78fefc6 --- /dev/null +++ b/branches/origin-master/cmd/suika-znc-import/main.go @@ -0,0 +1,474 @@ +package main + +import ( + "bufio" + "context" + "flag" + "fmt" + "io" + "log" + "net/url" + "os" + "strings" + "unicode" + + "marisa.chaotic.ninja/suika" + "marisa.chaotic.ninja/suika/config" +) + +const usage = `usage: suika-znc-import [options...] + +Imports configuration from a ZNC file. Users and networks are merged if they +already exist in the suika database. ZNC settings overwrite existing suika +settings. + +Options: + + -help Show this help message + -config Path to suika config file + -user Limit import to username (may be specified multiple times) + -network Limit import to network (may be specified multiple times) +` + +func init() { + flag.Usage = func() { + fmt.Fprintf(flag.CommandLine.Output(), usage) + } +} + +func main() { + var configPath string + users := make(map[string]bool) + networks := make(map[string]bool) + flag.StringVar(&configPath, "config", "", "path to configuration file") + flag.Var((*stringSetFlag)(&users), "user", "") + flag.Var((*stringSetFlag)(&networks), "network", "") + flag.Parse() + + zncPath := flag.Arg(0) + if zncPath == "" { + flag.Usage() + os.Exit(1) + } + + var cfg *config.Server + if configPath != "" { + var err error + cfg, err = config.Load(configPath) + if err != nil { + log.Fatalf("failed to load config file: %v", err) + } + } else { + cfg = config.Defaults() + } + + ctx := context.Background() + + db, err := suika.OpenDB(cfg.SQLDriver, cfg.SQLSource) + if err != nil { + log.Fatalf("failed to open database: %v", err) + } + defer db.Close() + + f, err := os.Open(zncPath) + if err != nil { + log.Fatalf("failed to open ZNC configuration file: %v", err) + } + defer f.Close() + + zp := zncParser{bufio.NewReader(f), 1} + root, err := zp.sectionBody("", "") + if err != nil { + log.Fatalf("failed to parse %q: line %v: %v", zncPath, zp.line, err) + } + + l, err := db.ListUsers(ctx) + if err != nil { + log.Fatalf("failed to list users in DB: %v", err) + } + existingUsers := make(map[string]*suika.User, len(l)) + for i, u := range l { + existingUsers[u.Username] = &l[i] + } + + usersCreated := 0 + usersImported := 0 + networksImported := 0 + channelsImported := 0 + root.ForEach("User", func(section *zncSection) { + username := section.Name + if len(users) > 0 && !users[username] { + return + } + usersImported++ + + u, ok := existingUsers[username] + if ok { + log.Printf("user %q: updating existing user", username) + } else { + // "!!" is an invalid crypt format, thus disables password auth + u = &suika.User{Username: username, Password: "!!"} + usersCreated++ + log.Printf("user %q: creating new user", username) + } + + u.Admin = section.Values.Get("Admin") == "true" + + if err := db.StoreUser(ctx, u); err != nil { + log.Fatalf("failed to store user %q: %v", username, err) + } + userID := u.ID + + l, err := db.ListNetworks(ctx, userID) + if err != nil { + log.Fatalf("failed to list networks for user %q: %v", username, err) + } + existingNetworks := make(map[string]*suika.Network, len(l)) + for i, n := range l { + existingNetworks[n.GetName()] = &l[i] + } + + nick := section.Values.Get("Nick") + realname := section.Values.Get("RealName") + ident := section.Values.Get("Ident") + + section.ForEach("Network", func(section *zncSection) { + netName := section.Name + if len(networks) > 0 && !networks[netName] { + return + } + networksImported++ + + logPrefix := fmt.Sprintf("user %q: network %q: ", username, netName) + logger := log.New(os.Stderr, logPrefix, log.LstdFlags|log.Lmsgprefix) + + netNick := section.Values.Get("Nick") + if netNick == "" { + netNick = nick + } + netRealname := section.Values.Get("RealName") + if netRealname == "" { + netRealname = realname + } + netIdent := section.Values.Get("Ident") + if netIdent == "" { + netIdent = ident + } + + for _, name := range section.Values["LoadModule"] { + switch name { + case "sasl": + logger.Printf("warning: SASL credentials not imported") + case "nickserv": + logger.Printf("warning: NickServ credentials not imported") + case "perform": + logger.Printf("warning: \"perform\" plugin commands not imported") + } + } + + u, pass, err := importNetworkServer(section.Values.Get("Server")) + if err != nil { + logger.Fatalf("failed to import server %q: %v", section.Values.Get("Server"), err) + } + + n, ok := existingNetworks[netName] + if ok { + logger.Printf("updating existing network") + } else { + n = &suika.Network{Name: netName} + logger.Printf("creating new network") + } + + n.Addr = u.String() + n.Nick = netNick + n.Username = netIdent + n.Realname = netRealname + n.Pass = pass + n.Enabled = section.Values.Get("IRCConnectEnabled") != "false" + + if err := db.StoreNetwork(ctx, userID, n); err != nil { + logger.Fatalf("failed to store network: %v", err) + } + + l, err := db.ListChannels(ctx, n.ID) + if err != nil { + logger.Fatalf("failed to list channels: %v", err) + } + existingChannels := make(map[string]*suika.Channel, len(l)) + for i, ch := range l { + existingChannels[ch.Name] = &l[i] + } + + section.ForEach("Chan", func(section *zncSection) { + chName := section.Name + + if section.Values.Get("Disabled") == "true" { + logger.Printf("skipping import of disabled channel %q", chName) + return + } + + channelsImported++ + + ch, ok := existingChannels[chName] + if ok { + logger.Printf("channel %q: updating existing channel", chName) + } else { + ch = &suika.Channel{Name: chName} + logger.Printf("channel %q: creating new channel", chName) + } + + ch.Key = section.Values.Get("Key") + ch.Detached = section.Values.Get("Detached") == "true" + + if err := db.StoreChannel(ctx, n.ID, ch); err != nil { + logger.Printf("channel %q: failed to store channel: %v", chName, err) + } + }) + }) + }) + + if err := db.Close(); err != nil { + log.Printf("failed to close database: %v", err) + } + + if usersCreated > 0 { + log.Printf("warning: user passwords haven't been imported, please set them with `suikactl change-password `") + } + + log.Printf("imported %v users, %v networks and %v channels", usersImported, networksImported, channelsImported) +} + +func importNetworkServer(s string) (u *url.URL, pass string, err error) { + parts := strings.Fields(s) + if len(parts) < 2 { + return nil, "", fmt.Errorf("expected space-separated host and port") + } + + scheme := "irc" + host := parts[0] + port := parts[1] + if strings.HasPrefix(port, "+") { + port = port[1:] + scheme = "ircs" + } + + if len(parts) > 2 { + pass = parts[2] + } + + u = &url.URL{ + Scheme: scheme, + Host: host + ":" + port, + } + return u, pass, nil +} + +type zncSection struct { + Type string + Name string + Values zncValues + Children []zncSection +} + +func (s *zncSection) ForEach(typ string, f func(*zncSection)) { + for _, section := range s.Children { + if section.Type == typ { + f(§ion) + } + } +} + +type zncValues map[string][]string + +func (zv zncValues) Get(k string) string { + if len(zv[k]) == 0 { + return "" + } + return zv[k][0] +} + +type zncParser struct { + br *bufio.Reader + line int +} + +func (zp *zncParser) readByte() (byte, error) { + b, err := zp.br.ReadByte() + if b == '\n' { + zp.line++ + } + return b, err +} + +func (zp *zncParser) readRune() (rune, int, error) { + r, n, err := zp.br.ReadRune() + if r == '\n' { + zp.line++ + } + return r, n, err +} + +func (zp *zncParser) sectionBody(typ, name string) (*zncSection, error) { + section := &zncSection{Type: typ, Name: name, Values: make(zncValues)} + +Loop: + for { + if err := zp.skipSpace(); err != nil { + return nil, err + } + + b, err := zp.br.Peek(2) + if err == io.EOF { + break + } else if err != nil { + return nil, err + } + + switch b[0] { + case '<': + if b[1] == '/' { + break Loop + } else { + childType, childName, err := zp.sectionHeader() + if err != nil { + return nil, err + } + child, err := zp.sectionBody(childType, childName) + if err != nil { + return nil, err + } + if footerType, err := zp.sectionFooter(); err != nil { + return nil, err + } else if footerType != childType { + return nil, fmt.Errorf("invalid section footer: expected type %q, got %q", childType, footerType) + } + section.Children = append(section.Children, *child) + } + case '/': + if b[1] == '/' { + if err := zp.skipComment(); err != nil { + return nil, err + } + break + } + fallthrough + default: + k, v, err := zp.keyValuePair() + if err != nil { + return nil, err + } + section.Values[k] = append(section.Values[k], v) + } + } + + return section, nil +} + +func (zp *zncParser) skipSpace() error { + for { + r, _, err := zp.readRune() + if err == io.EOF { + return nil + } else if err != nil { + return err + } + + if !unicode.IsSpace(r) { + zp.br.UnreadRune() + return nil + } + } +} + +func (zp *zncParser) skipComment() error { + if err := zp.expectRune('/'); err != nil { + return err + } + if err := zp.expectRune('/'); err != nil { + return err + } + + for { + b, err := zp.readByte() + if err == io.EOF { + return nil + } else if err != nil { + return err + } + + if b == '\n' { + return nil + } + } +} + +func (zp *zncParser) sectionHeader() (string, string, error) { + if err := zp.expectRune('<'); err != nil { + return "", "", err + } + typ, err := zp.readWord(' ') + if err != nil { + return "", "", err + } + name, err := zp.readWord('>') + return typ, name, err +} + +func (zp *zncParser) sectionFooter() (string, error) { + if err := zp.expectRune('<'); err != nil { + return "", err + } + if err := zp.expectRune('/'); err != nil { + return "", err + } + return zp.readWord('>') +} + +func (zp *zncParser) keyValuePair() (string, string, error) { + k, err := zp.readWord('=') + if err != nil { + return "", "", err + } + v, err := zp.readWord('\n') + return strings.TrimSpace(k), strings.TrimSpace(v), err +} + +func (zp *zncParser) expectRune(expected rune) error { + r, _, err := zp.readRune() + if err != nil { + return err + } else if r != expected { + return fmt.Errorf("expected %q, got %q", expected, r) + } + return nil +} + +func (zp *zncParser) readWord(delim byte) (string, error) { + var sb strings.Builder + for { + b, err := zp.readByte() + if err != nil { + return "", err + } + + if b == delim { + return sb.String(), nil + } + if b == '\n' { + return "", fmt.Errorf("expected %q before newline", delim) + } + + sb.WriteByte(b) + } +} + +type stringSetFlag map[string]bool + +func (v *stringSetFlag) String() string { + return fmt.Sprint(map[string]bool(*v)) +} + +func (v *stringSetFlag) Set(s string) error { + (*v)[s] = true + return nil +} diff --git a/branches/origin-master/cmd/suika/main.go b/branches/origin-master/cmd/suika/main.go new file mode 100644 index 0000000..9661859 --- /dev/null +++ b/branches/origin-master/cmd/suika/main.go @@ -0,0 +1,229 @@ +package main + +import ( + "context" + "crypto/tls" + "flag" + "fmt" + "log" + "net" + "net/url" + "os" + "os/signal" + "strings" + "sync/atomic" + "syscall" + "time" + + "marisa.chaotic.ninja/suika" + "marisa.chaotic.ninja/suika/config" +) + +// TCP keep-alive interval for downstream TCP connections +const downstreamKeepAlive = 1 * time.Hour + +type stringSliceFlag []string + +func (v *stringSliceFlag) String() string { + return fmt.Sprint([]string(*v)) +} + +func (v *stringSliceFlag) Set(s string) error { + *v = append(*v, s) + return nil +} + +func bumpOpenedFileLimit() error { + var rlimit syscall.Rlimit + if err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rlimit); err != nil { + return fmt.Errorf("failed to get RLIMIT_NOFILE: %v", err) + } + rlimit.Cur = rlimit.Max + if err := syscall.Setrlimit(syscall.RLIMIT_NOFILE, &rlimit); err != nil { + return fmt.Errorf("failed to set RLIMIT_NOFILE: %v", err) + } + return nil +} + +var ( + configPath string + debug bool + + tlsCert atomic.Value // *tls.Certificate +) + +func loadConfig() (*config.Server, *suika.Config, error) { + var raw *config.Server + if configPath != "" { + var err error + raw, err = config.Load(configPath) + if err != nil { + return nil, nil, fmt.Errorf("failed to load config file: %v", err) + } + } else { + raw = config.Defaults() + } + + var motd string + if raw.MOTDPath != "" { + b, err := os.ReadFile(raw.MOTDPath) + if err != nil { + return nil, nil, fmt.Errorf("failed to load MOTD: %v", err) + } + motd = strings.TrimSuffix(string(b), "\n") + } + + if raw.TLS != nil { + cert, err := tls.LoadX509KeyPair(raw.TLS.CertPath, raw.TLS.KeyPath) + if err != nil { + return nil, nil, fmt.Errorf("failed to load TLS certificate and key: %v", err) + } + tlsCert.Store(&cert) + } + + cfg := &suika.Config{ + Hostname: raw.Hostname, + Title: raw.Title, + LogPath: raw.LogPath, + MaxUserNetworks: raw.MaxUserNetworks, + MultiUpstream: raw.MultiUpstream, + UpstreamUserIPs: raw.UpstreamUserIPs, + MOTD: motd, + } + return raw, cfg, nil +} + +func main() { + var listen []string + flag.Var((*stringSliceFlag)(&listen), "listen", "listening address") + flag.StringVar(&configPath, "config", "", "path to configuration file") + flag.BoolVar(&debug, "debug", false, "enable debug logging") + flag.Parse() + + cfg, serverCfg, err := loadConfig() + if err != nil { + log.Fatal(err) + } + + cfg.Listen = append(cfg.Listen, listen...) + if len(cfg.Listen) == 0 { + cfg.Listen = []string{":6667"} + } + + if err := bumpOpenedFileLimit(); err != nil { + log.Printf("failed to bump max number of opened files: %v", err) + } + + db, err := suika.OpenDB(cfg.SQLDriver, cfg.SQLSource) + if err != nil { + log.Fatalf("failed to open database: %v", err) + } + + var tlsCfg *tls.Config + if cfg.TLS != nil { + tlsCfg = &tls.Config{ + GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) { + return tlsCert.Load().(*tls.Certificate), nil + }, + } + } + + srv := suika.NewServer(db) + srv.SetConfig(serverCfg) + srv.Logger = suika.NewLogger(log.Writer(), debug) + + for _, listen := range cfg.Listen { + listen := listen // copy + listenURI := listen + if !strings.Contains(listenURI, ":/") { + // This is a raw domain name, make it an URL with an empty scheme + listenURI = "//" + listenURI + } + u, err := url.Parse(listenURI) + if err != nil { + log.Fatalf("failed to parse listen URI %q: %v", listen, err) + } + + switch u.Scheme { + case "ircs", "": + if tlsCfg == nil { + log.Fatalf("failed to listen on %q: missing TLS configuration", listen) + } + host := u.Host + if _, _, err := net.SplitHostPort(host); err != nil { + host = host + ":6697" + } + ircsTLSCfg := tlsCfg.Clone() + ircsTLSCfg.NextProtos = []string{"irc"} + lc := net.ListenConfig{ + KeepAlive: downstreamKeepAlive, + } + l, err := lc.Listen(context.Background(), "tcp", host) + if err != nil { + log.Fatalf("failed to start TLS listener on %q: %v", listen, err) + } + ln := tls.NewListener(l, ircsTLSCfg) + go func() { + if err := srv.Serve(ln); err != nil { + log.Printf("serving %q: %v", listen, err) + } + }() + case "irc": + host := u.Host + if _, _, err := net.SplitHostPort(host); err != nil { + host = host + ":6667" + } + lc := net.ListenConfig{ + KeepAlive: downstreamKeepAlive, + } + ln, err := lc.Listen(context.Background(), "tcp", host) + if err != nil { + log.Fatalf("failed to start listener on %q: %v", listen, err) + } + go func() { + if err := srv.Serve(ln); err != nil { + log.Printf("serving %q: %v", listen, err) + } + }() + case "unix": + ln, err := net.Listen("unix", u.Path) + if err != nil { + log.Fatalf("failed to start listener on %q: %v", listen, err) + } + go func() { + if err := srv.Serve(ln); err != nil { + log.Printf("serving %q: %v", listen, err) + } + }() + default: + log.Fatalf("failed to listen on %q: unsupported scheme", listen) + } + + log.Printf("starting suika version %v\n", suika.FullVersion()) + log.Printf("server listening on %q", listen) + } + + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP) + + if err := srv.Start(); err != nil { + log.Fatal(err) + } + + for sig := range sigCh { + switch sig { + case syscall.SIGHUP: + log.Print("reloading configuration") + _, serverCfg, err := loadConfig() + if err != nil { + log.Printf("failed to reloading configuration: %v", err) + } else { + srv.SetConfig(serverCfg) + } + case syscall.SIGINT, syscall.SIGTERM: + log.Print("shutting down server") + srv.Shutdown() + return + } + } +} diff --git a/branches/origin-master/cmd/suikadb/main.go b/branches/origin-master/cmd/suikadb/main.go new file mode 100644 index 0000000..4d54481 --- /dev/null +++ b/branches/origin-master/cmd/suikadb/main.go @@ -0,0 +1,150 @@ +package main + +import ( + "bufio" + "context" + "flag" + "fmt" + "io" + "log" + "os" + + "marisa.chaotic.ninja/suika" + "marisa.chaotic.ninja/suika/config" + "golang.org/x/crypto/bcrypt" + "golang.org/x/term" +) + +const usage = `usage: suikadb [-config path] [options...] + + create-user [-admin] Create a new user + change-password Change password for a user + help Show this help message +` + +func init() { + flag.Usage = func() { + fmt.Fprintf(flag.CommandLine.Output(), usage) + } +} + +func main() { + var configPath string + flag.StringVar(&configPath, "config", "", "path to configuration file") + flag.Parse() + + var cfg *config.Server + if configPath != "" { + var err error + cfg, err = config.Load(configPath) + if err != nil { + log.Fatalf("failed to load config file: %v", err) + } + } else { + cfg = config.Defaults() + } + + db, err := suika.OpenDB(cfg.SQLDriver, cfg.SQLSource) + if err != nil { + log.Fatalf("failed to open database: %v", err) + } + + ctx := context.Background() + + switch cmd := flag.Arg(0); cmd { + case "create-user": + username := flag.Arg(1) + if username == "" { + flag.Usage() + os.Exit(1) + } + + fs := flag.NewFlagSet("", flag.ExitOnError) + admin := fs.Bool("admin", false, "make the new user admin") + fs.Parse(flag.Args()[2:]) + + password, err := readPassword() + if err != nil { + log.Fatalf("failed to read password: %v", err) + } + + hashed, err := bcrypt.GenerateFromPassword(password, bcrypt.DefaultCost) + if err != nil { + log.Fatalf("failed to hash password: %v", err) + } + + user := suika.User{ + Username: username, + Password: string(hashed), + Admin: *admin, + } + if err := db.StoreUser(ctx, &user); err != nil { + log.Fatalf("failed to create user: %v", err) + } + case "change-password": + username := flag.Arg(1) + if username == "" { + flag.Usage() + os.Exit(1) + } + + user, err := db.GetUser(ctx, username) + if err != nil { + log.Fatalf("failed to get user: %v", err) + } + + password, err := readPassword() + if err != nil { + log.Fatalf("failed to read password: %v", err) + } + + hashed, err := bcrypt.GenerateFromPassword(password, bcrypt.DefaultCost) + if err != nil { + log.Fatalf("failed to hash password: %v", err) + } + + user.Password = string(hashed) + if err := db.StoreUser(ctx, user); err != nil { + log.Fatalf("failed to update password: %v", err) + } + case "version": + fmt.Printf("%v\n", suika.FullVersion()) + default: + flag.Usage() + if cmd != "help" { + os.Exit(1) + } + } +} + +func readPassword() ([]byte, error) { + var password []byte + var err error + fd := int(os.Stdin.Fd()) + + if term.IsTerminal(fd) { + fmt.Printf("Password: ") + password, err = term.ReadPassword(int(os.Stdin.Fd())) + if err != nil { + return nil, err + } + fmt.Printf("\n") + } else { + fmt.Fprintf(os.Stderr, "Warning: Reading password from stdin.\n") + // TODO: the buffering messes up repeated calls to readPassword + scanner := bufio.NewScanner(os.Stdin) + if !scanner.Scan() { + if err := scanner.Err(); err != nil { + return nil, err + } + return nil, io.ErrUnexpectedEOF + } + password = scanner.Bytes() + + if len(password) == 0 { + return nil, fmt.Errorf("zero length password") + } + } + + return password, nil +} diff --git a/branches/origin-master/config.in b/branches/origin-master/config.in new file mode 100644 index 0000000..0b1ed4d --- /dev/null +++ b/branches/origin-master/config.in @@ -0,0 +1,2 @@ +db sqlite3 /var/lib/suika/main.db +log fs /var/lib/suika/logs/ diff --git a/branches/origin-master/config/config.go b/branches/origin-master/config/config.go new file mode 100644 index 0000000..94d4584 --- /dev/null +++ b/branches/origin-master/config/config.go @@ -0,0 +1,142 @@ +package config + +import ( + "fmt" + "net" + "os" + "strconv" + + "git.sr.ht/~emersion/go-scfg" +) + +type TLS struct { + CertPath, KeyPath string +} + +type Server struct { + Listen []string + TLS *TLS + Hostname string + Title string + MOTDPath string + + SQLDriver string + SQLSource string + LogPath string + + MaxUserNetworks int + MultiUpstream bool + UpstreamUserIPs []*net.IPNet +} + +func Defaults() *Server { + hostname, err := os.Hostname() + if err != nil { + hostname = "localhost" + } + return &Server{ + Hostname: hostname, + SQLDriver: "sqlite3", + SQLSource: "suika.db", + MaxUserNetworks: -1, + MultiUpstream: true, + } +} + +func Load(path string) (*Server, error) { + cfg, err := scfg.Load(path) + if err != nil { + return nil, err + } + return parse(cfg) +} + +func parse(cfg scfg.Block) (*Server, error) { + srv := Defaults() + for _, d := range cfg { + switch d.Name { + case "listen": + var uri string + if err := d.ParseParams(&uri); err != nil { + return nil, err + } + srv.Listen = append(srv.Listen, uri) + case "hostname": + if err := d.ParseParams(&srv.Hostname); err != nil { + return nil, err + } + case "title": + if err := d.ParseParams(&srv.Title); err != nil { + return nil, err + } + case "motd": + if err := d.ParseParams(&srv.MOTDPath); err != nil { + return nil, err + } + case "tls": + tls := &TLS{} + if err := d.ParseParams(&tls.CertPath, &tls.KeyPath); err != nil { + return nil, err + } + srv.TLS = tls + case "db": + if err := d.ParseParams(&srv.SQLDriver, &srv.SQLSource); err != nil { + return nil, err + } + case "log": + var driver string + if err := d.ParseParams(&driver, &srv.LogPath); err != nil { + return nil, err + } + if driver != "fs" { + return nil, fmt.Errorf("directive %q: unknown driver %q", d.Name, driver) + } + case "max-user-networks": + var max string + if err := d.ParseParams(&max); err != nil { + return nil, err + } + var err error + if srv.MaxUserNetworks, err = strconv.Atoi(max); err != nil { + return nil, fmt.Errorf("directive %q: %v", d.Name, err) + } + case "multi-upstream-mode": + var str string + if err := d.ParseParams(&str); err != nil { + return nil, err + } + v, err := strconv.ParseBool(str) + if err != nil { + return nil, fmt.Errorf("directive %q: %v", d.Name, err) + } + srv.MultiUpstream = v + case "upstream-user-ip": + if len(srv.UpstreamUserIPs) > 0 { + return nil, fmt.Errorf("directive %q: can only be specified once", d.Name) + } + var hasIPv4, hasIPv6 bool + for _, s := range d.Params { + _, n, err := net.ParseCIDR(s) + if err != nil { + return nil, fmt.Errorf("directive %q: failed to parse CIDR: %v", d.Name, err) + } + if n.IP.To4() == nil { + if hasIPv6 { + return nil, fmt.Errorf("directive %q: found two IPv6 CIDRs", d.Name) + } + hasIPv6 = true + } else { + if hasIPv4 { + return nil, fmt.Errorf("directive %q: found two IPv4 CIDRs", d.Name) + } + hasIPv4 = true + } + srv.UpstreamUserIPs = append(srv.UpstreamUserIPs, n) + } + default: + return nil, fmt.Errorf("unknown directive %q", d.Name) + } + } + + return srv, nil +} diff --git a/branches/origin-master/conn.go b/branches/origin-master/conn.go new file mode 100644 index 0000000..bb95e25 --- /dev/null +++ b/branches/origin-master/conn.go @@ -0,0 +1,180 @@ +package suika + +import ( + "context" + "fmt" + "io" + "net" + "sync" + "time" + + "golang.org/x/time/rate" + "gopkg.in/irc.v3" +) + +// ircConn is a generic IRC connection. It's similar to net.Conn but focuses on +// reading and writing IRC messages. +type ircConn interface { + ReadMessage() (*irc.Message, error) + WriteMessage(*irc.Message) error + Close() error + SetReadDeadline(time.Time) error + SetWriteDeadline(time.Time) error + RemoteAddr() net.Addr + LocalAddr() net.Addr +} + +func newNetIRCConn(c net.Conn) ircConn { + type netConn net.Conn + return struct { + *irc.Conn + netConn + }{irc.NewConn(c), c} +} + +type connOptions struct { + Logger Logger + RateLimitDelay time.Duration + RateLimitBurst int +} + +type conn struct { + conn ircConn + srv *Server + logger Logger + + lock sync.Mutex + outgoing chan<- *irc.Message + closed bool + closedCh chan struct{} +} + +func newConn(srv *Server, ic ircConn, options *connOptions) *conn { + outgoing := make(chan *irc.Message, 64) + c := &conn{ + conn: ic, + srv: srv, + outgoing: outgoing, + logger: options.Logger, + closedCh: make(chan struct{}), + } + + go func() { + ctx, cancel := c.NewContext(context.Background()) + defer cancel() + + rl := rate.NewLimiter(rate.Every(options.RateLimitDelay), options.RateLimitBurst) + for msg := range outgoing { + if err := rl.Wait(ctx); err != nil { + break + } + + c.logger.Debugf("sent: %v", msg) + c.conn.SetWriteDeadline(time.Now().Add(writeTimeout)) + if err := c.conn.WriteMessage(msg); err != nil { + c.logger.Printf("failed to write message: %v", err) + break + } + } + if err := c.conn.Close(); err != nil && !isErrClosed(err) { + c.logger.Printf("failed to close connection: %v", err) + } else { + c.logger.Debugf("connection closed") + } + // Drain the outgoing channel to prevent SendMessage from blocking + for range outgoing { + // This space is intentionally left blank + } + }() + + c.logger.Debugf("new connection") + return c +} + +func (c *conn) isClosed() bool { + c.lock.Lock() + defer c.lock.Unlock() + return c.closed +} + +// Close closes the connection. It is safe to call from any goroutine. +func (c *conn) Close() error { + c.lock.Lock() + defer c.lock.Unlock() + + if c.closed { + return fmt.Errorf("connection already closed") + } + + err := c.conn.Close() + c.closed = true + close(c.outgoing) + close(c.closedCh) + return err +} + +func (c *conn) ReadMessage() (*irc.Message, error) { + msg, err := c.conn.ReadMessage() + if isErrClosed(err) { + return nil, io.EOF + } else if err != nil { + return nil, err + } + + c.logger.Debugf("received: %v", msg) + return msg, nil +} + +// SendMessage queues a new outgoing message. It is safe to call from any +// goroutine. +// +// If the connection is closed before the message is sent, SendMessage silently +// drops the message. +func (c *conn) SendMessage(ctx context.Context, msg *irc.Message) { + c.lock.Lock() + defer c.lock.Unlock() + + if c.closed { + return + } + + select { + case c.outgoing <- msg: + // Success + case <-ctx.Done(): + c.logger.Printf("failed to send message: %v", ctx.Err()) + } +} + +func (c *conn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +func (c *conn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +// NewContext returns a copy of the parent context with a new Done channel. The +// returned context's Done channel is closed when the connection is closed, +// when the returned cancel function is called, or when the parent context's +// Done channel is closed, whichever happens first. +// +// Canceling this context releases resources associated with it, so code should +// call cancel as soon as the operations running in this Context complete. +func (c *conn) NewContext(parent context.Context) (context.Context, context.CancelFunc) { + ctx, cancel := context.WithCancel(parent) + + go func() { + defer cancel() + + select { + case <-ctx.Done(): + // The parent context has been cancelled, or the caller has called + // cancel() + case <-c.closedCh: + // The connection has been closed + } + }() + + return ctx, cancel +} diff --git a/branches/origin-master/contrib/casemap-logs.sh b/branches/origin-master/contrib/casemap-logs.sh new file mode 100644 index 0000000..709d581 --- /dev/null +++ b/branches/origin-master/contrib/casemap-logs.sh @@ -0,0 +1,42 @@ +#!/bin/sh -eu + +# Converts a log dir to its case-mapped form. +# +# suika needs to be stopped for this script to work properly. The script may +# re-order messages that happened within the same second interval if merging +# two daily log files is necessary. +# +# usage: casemap-logs.sh + +root="$1" + +for net_dir in "$root"/*/*; do + for chan in $(ls "$net_dir"); do + cm_chan="$(echo $chan | tr '[:upper:]' '[:lower:]')" + if [ "$chan" = "$cm_chan" ]; then + continue + fi + + if ! [ -d "$net_dir/$cm_chan" ]; then + echo >&2 "Moving case-mapped channel dir: '$net_dir/$chan' -> '$cm_chan'" + mv "$net_dir/$chan" "$net_dir/$cm_chan" + continue + fi + + echo "Merging case-mapped channel dir: '$net_dir/$chan' -> '$cm_chan'" + for day in $(ls "$net_dir/$chan"); do + if ! [ -e "$net_dir/$cm_chan/$day" ]; then + echo >&2 " Moving log file: '$day'" + mv "$net_dir/$chan/$day" "$net_dir/$cm_chan/$day" + continue + fi + + echo >&2 " Merging log file: '$day'" + sort "$net_dir/$chan/$day" "$net_dir/$cm_chan/$day" >"$net_dir/$cm_chan/$day.new" + mv "$net_dir/$cm_chan/$day.new" "$net_dir/$cm_chan/$day" + rm "$net_dir/$chan/$day" + done + + rmdir "$net_dir/$chan" + done +done diff --git a/branches/origin-master/contrib/clients.md b/branches/origin-master/contrib/clients.md new file mode 100644 index 0000000..1f8881d --- /dev/null +++ b/branches/origin-master/contrib/clients.md @@ -0,0 +1,100 @@ +# Clients + +This page describes how to configure IRC clients to better integrate with soju. + +Also see the [IRCv3 support tables] for a more general list of clients. + +# catgirl + +catgirl doesn't properly implement cap-3.2, so many capabilities will be +disabled. catgirl developers have publicly stated that supporting bouncers such +as soju is a non-goal. + +# [Emacs] + +There are two clients provided with Emacs. They require some setup to work +properly. + +## Erc + +You need to explicitly set the username, which is the defcustom +`erc-email-userid`. + +```elisp +(setq erc-email-userid "/irc.libera.chat") ;; Example with Libera.Chat +(defun run-erc () + (interactive) + (erc-tls :server "" + :port 6697 + :nick "" + :password "")) +``` + +Then run `M-x run-erc`. + +## Rcirc + +The only thing needed here is the general config: + +```elisp +(setq rcirc-server-alist + '(("" + :port 6697 + :encryption tls + :nick "" + :user-name "/irc.libera.chat" ;; Example with Libera.Chat + :password ""))) +``` + +Then run `M-x irc`. + +# [gamja] + +gamja has been designed together with soju, so should have excellent +integration. gamja supports many IRCv3 features including chat history. +gamja also provides UI to manage soju networks via the +`soju.im/bouncer-networks` extension. + +# [goguma] + +Much like gamja, goguma has been designed together with soju, so should have +excellent integration. goguma supports many IRCv3 features including chat +history. goguma should seamlessly connect to all networks configured in soju via +the `soju.im/bouncer-networks` extension. + +# [Hexchat] + +Hexchat has support for a small set of IRCv3 capabilities. To prevent +automatically reconnecting to channels parted from soju, and prevent buffering +outgoing messages: + + /set irc_reconnect_rejoin off + /set net_throttle off + +# [senpai] + +senpai is being developed with soju in mind, so should have excellent +integration. senpai supports many IRCv3 features including chat history. + +# [Weechat] + +A [Weechat script] is available to provide better integration with soju. +The script will automatically connect to all of your networks once a +single connection to soju is set up in Weechat. + +On WeeChat 3.2-, no IRCv3 capabilities are enabled by default. To enable them: + + /set irc.server_default.capabilities account-notify,away-notify,cap-notify,chghost,extended-join,invite-notify,multi-prefix,server-time,userhost-in-names + /save + /reconnect -all + +See `/help cap` for more information. + +[IRCv3 support tables]: https://ircv3.net/software/clients +[gamja]: https://sr.ht/~emersion/gamja/ +[goguma]: https://sr.ht/~emersion/goguma/ +[senpai]: https://sr.ht/~taiite/senpai/ +[Weechat]: https://weechat.org/ +[Weechat script]: https://github.com/weechat/scripts/blob/master/python/soju.py +[Hexchat]: https://hexchat.github.io/ +[Emacs]: https://www.gnu.org/software/emacs/ diff --git a/branches/origin-master/db.go b/branches/origin-master/db.go new file mode 100644 index 0000000..95ac964 --- /dev/null +++ b/branches/origin-master/db.go @@ -0,0 +1,184 @@ +package suika + +import ( + "context" + "fmt" + "net/url" + "strings" + "time" +) + +type Database interface { + Close() error + Stats(ctx context.Context) (*DatabaseStats, error) + + ListUsers(ctx context.Context) ([]User, error) + GetUser(ctx context.Context, username string) (*User, error) + StoreUser(ctx context.Context, user *User) error + DeleteUser(ctx context.Context, id int64) error + + ListNetworks(ctx context.Context, userID int64) ([]Network, error) + StoreNetwork(ctx context.Context, userID int64, network *Network) error + DeleteNetwork(ctx context.Context, id int64) error + ListChannels(ctx context.Context, networkID int64) ([]Channel, error) + StoreChannel(ctx context.Context, networKID int64, ch *Channel) error + DeleteChannel(ctx context.Context, id int64) error + + ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error) + StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error + + GetReadReceipt(ctx context.Context, networkID int64, name string) (*ReadReceipt, error) + StoreReadReceipt(ctx context.Context, networkID int64, receipt *ReadReceipt) error +} + +func OpenDB(driver, source string) (Database, error) { + switch driver { + case "sqlite3": + return OpenSqliteDB(source) + case "postgres": + return OpenPostgresDB(source) + default: + return nil, fmt.Errorf("unsupported database driver: %q", driver) + } +} + +type DatabaseStats struct { + Users int64 + Networks int64 + Channels int64 +} + +type User struct { + ID int64 + Username string + Password string // hashed + Realname string + Admin bool +} + +type SASL struct { + Mechanism string + + Plain struct { + Username string + Password string + } + + // TLS client certificate authentication. + External struct { + // X.509 certificate in DER form. + CertBlob []byte + // PKCS#8 private key in DER form. + PrivKeyBlob []byte + } +} + +type Network struct { + ID int64 + Name string + Addr string + Nick string + Username string + Realname string + Pass string + ConnectCommands []string + SASL SASL + Enabled bool +} + +func (net *Network) GetName() string { + if net.Name != "" { + return net.Name + } + return net.Addr +} + +func (net *Network) URL() (*url.URL, error) { + s := net.Addr + if !strings.Contains(s, "://") { + // This is a raw domain name, make it an URL with the default scheme + s = "ircs://" + s + } + + u, err := url.Parse(s) + if err != nil { + return nil, fmt.Errorf("failed to parse upstream server URL: %v", err) + } + + return u, nil +} + +func GetNick(user *User, net *Network) string { + if net.Nick != "" { + return net.Nick + } + return user.Username +} + +func GetUsername(user *User, net *Network) string { + if net.Username != "" { + return net.Username + } + return GetNick(user, net) +} + +func GetRealname(user *User, net *Network) string { + if net.Realname != "" { + return net.Realname + } + if user.Realname != "" { + return user.Realname + } + return GetNick(user, net) +} + +type MessageFilter int + +const ( + // TODO: use customizable user defaults for FilterDefault + FilterDefault MessageFilter = iota + FilterNone + FilterHighlight + FilterMessage +) + +func parseFilter(filter string) (MessageFilter, error) { + switch filter { + case "default": + return FilterDefault, nil + case "none": + return FilterNone, nil + case "highlight": + return FilterHighlight, nil + case "message": + return FilterMessage, nil + } + return 0, fmt.Errorf("unknown filter: %q", filter) +} + +type Channel struct { + ID int64 + Name string + Key string + + Detached bool + DetachedInternalMsgID string + + RelayDetached MessageFilter + ReattachOn MessageFilter + DetachAfter time.Duration + DetachOn MessageFilter +} + +type DeliveryReceipt struct { + ID int64 + Target string // channel or nick + Client string + InternalMsgID string +} + +type ReadReceipt struct { + ID int64 + Target string // channel or nick + Timestamp time.Time +} diff --git a/branches/origin-master/db_postgres.go b/branches/origin-master/db_postgres.go new file mode 100644 index 0000000..0849338 --- /dev/null +++ b/branches/origin-master/db_postgres.go @@ -0,0 +1,494 @@ +package suika + +import ( + "context" + "database/sql" + _ "embed" + "errors" + "fmt" + "math" + "strings" + "time" + + _ "github.com/lib/pq" +) + +const postgresQueryTimeout = 5 * time.Second + +const postgresConfigSchema = ` +CREATE TABLE IF NOT EXISTS "Config" ( + id SMALLINT PRIMARY KEY, + version INTEGER NOT NULL, + CHECK(id = 1) +); +` +//go:embed suika_psql_schema.sql +var postgresSchema string + +var postgresMigrations = []string{ + "", // migration #0 is reserved for schema initialization + `ALTER TABLE "Network" ALTER COLUMN nick DROP NOT NULL`, + ` + CREATE TYPE sasl_mechanism AS ENUM ('PLAIN', 'EXTERNAL'); + ALTER TABLE "Network" + ALTER COLUMN sasl_mechanism + TYPE sasl_mechanism + USING sasl_mechanism::sasl_mechanism; + `, + ` + CREATE TABLE IF NOT EXISTS "ReadReceipt" ( + id SERIAL PRIMARY KEY, + network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE, + target VARCHAR(255) NOT NULL, + timestamp TIMESTAMP WITH TIME ZONE NOT NULL, + UNIQUE(network, target) + ); + `, +} + +type PostgresDB struct { + db *sql.DB +} + +func OpenPostgresDB(source string) (Database, error) { + sqlPostgresDB, err := sql.Open("postgres", source) + if err != nil { + return nil, err + } + + db := &PostgresDB{db: sqlPostgresDB} + if err := db.upgrade(); err != nil { + sqlPostgresDB.Close() + return nil, err + } + + return db, nil +} + +func (db *PostgresDB) upgrade() error { + tx, err := db.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + if _, err := tx.Exec(postgresConfigSchema); err != nil { + return fmt.Errorf("failed to create Config table: %s", err) + } + + var version int + err = tx.QueryRow(`SELECT version FROM "Config"`).Scan(&version) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return fmt.Errorf("failed to query schema version: %s", err) + } + + if version == len(postgresMigrations) { + return nil + } + if version > len(postgresMigrations) { + return fmt.Errorf("suika (version %d) older than schema (version %d)", len(postgresMigrations), version) + } + + if version == 0 { + if _, err := tx.Exec(postgresSchema); err != nil { + return fmt.Errorf("failed to initialize schema: %s", err) + } + } else { + for i := version; i < len(postgresMigrations); i++ { + if _, err := tx.Exec(postgresMigrations[i]); err != nil { + return fmt.Errorf("failed to execute migration #%v: %v", i, err) + } + } + } + + _, err = tx.Exec(`INSERT INTO "Config" (id, version) VALUES (1, $1) + ON CONFLICT (id) DO UPDATE SET version = $1`, len(postgresMigrations)) + if err != nil { + return fmt.Errorf("failed to bump schema version: %v", err) + } + + return tx.Commit() +} + +func (db *PostgresDB) Close() error { + return db.db.Close() +} + +func (db *PostgresDB) Stats(ctx context.Context) (*DatabaseStats, error) { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + var stats DatabaseStats + row := db.db.QueryRowContext(ctx, `SELECT + (SELECT COUNT(*) FROM "User") AS users, + (SELECT COUNT(*) FROM "Network") AS networks, + (SELECT COUNT(*) FROM "Channel") AS channels`) + if err := row.Scan(&stats.Users, &stats.Networks, &stats.Channels); err != nil { + return nil, err + } + + return &stats, nil +} + +func (db *PostgresDB) ListUsers(ctx context.Context) ([]User, error) { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + rows, err := db.db.QueryContext(ctx, + `SELECT id, username, password, admin, realname FROM "User"`) + if err != nil { + return nil, err + } + defer rows.Close() + + var users []User + for rows.Next() { + var user User + var password, realname sql.NullString + if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &realname); err != nil { + return nil, err + } + user.Password = password.String + user.Realname = realname.String + users = append(users, user) + } + if err := rows.Err(); err != nil { + return nil, err + } + + return users, nil +} + +func (db *PostgresDB) GetUser(ctx context.Context, username string) (*User, error) { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + user := &User{Username: username} + + var password, realname sql.NullString + row := db.db.QueryRowContext(ctx, + `SELECT id, password, admin, realname FROM "User" WHERE username = $1`, + username) + if err := row.Scan(&user.ID, &password, &user.Admin, &realname); err != nil { + return nil, err + } + user.Password = password.String + user.Realname = realname.String + return user, nil +} + +func (db *PostgresDB) StoreUser(ctx context.Context, user *User) error { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + password := toNullString(user.Password) + realname := toNullString(user.Realname) + + var err error + if user.ID == 0 { + err = db.db.QueryRowContext(ctx, ` + INSERT INTO "User" (username, password, admin, realname) + VALUES ($1, $2, $3, $4) + RETURNING id`, + user.Username, password, user.Admin, realname).Scan(&user.ID) + } else { + _, err = db.db.ExecContext(ctx, ` + UPDATE "User" + SET password = $1, admin = $2, realname = $3 + WHERE id = $4`, + password, user.Admin, realname, user.ID) + } + return err +} + +func (db *PostgresDB) DeleteUser(ctx context.Context, id int64) error { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + _, err := db.db.ExecContext(ctx, `DELETE FROM "User" WHERE id = $1`, id) + return err +} + +func (db *PostgresDB) ListNetworks(ctx context.Context, userID int64) ([]Network, error) { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + rows, err := db.db.QueryContext(ctx, ` + SELECT id, name, addr, nick, username, realname, pass, connect_commands, sasl_mechanism, + sasl_plain_username, sasl_plain_password, sasl_external_cert, sasl_external_key, enabled + FROM "Network" + WHERE "user" = $1`, userID) + if err != nil { + return nil, err + } + defer rows.Close() + + var networks []Network + for rows.Next() { + var net Network + var name, nick, username, realname, pass, connectCommands sql.NullString + var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString + err := rows.Scan(&net.ID, &name, &net.Addr, &nick, &username, &realname, + &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword, + &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob, &net.Enabled) + if err != nil { + return nil, err + } + net.Name = name.String + net.Nick = nick.String + net.Username = username.String + net.Realname = realname.String + net.Pass = pass.String + if connectCommands.Valid { + net.ConnectCommands = strings.Split(connectCommands.String, "\r\n") + } + net.SASL.Mechanism = saslMechanism.String + net.SASL.Plain.Username = saslPlainUsername.String + net.SASL.Plain.Password = saslPlainPassword.String + networks = append(networks, net) + } + if err := rows.Err(); err != nil { + return nil, err + } + + return networks, nil +} + +func (db *PostgresDB) StoreNetwork(ctx context.Context, userID int64, network *Network) error { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + netName := toNullString(network.Name) + nick := toNullString(network.Nick) + netUsername := toNullString(network.Username) + realname := toNullString(network.Realname) + pass := toNullString(network.Pass) + connectCommands := toNullString(strings.Join(network.ConnectCommands, "\r\n")) + + var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString + if network.SASL.Mechanism != "" { + saslMechanism = toNullString(network.SASL.Mechanism) + switch network.SASL.Mechanism { + case "PLAIN": + saslPlainUsername = toNullString(network.SASL.Plain.Username) + saslPlainPassword = toNullString(network.SASL.Plain.Password) + network.SASL.External.CertBlob = nil + network.SASL.External.PrivKeyBlob = nil + case "EXTERNAL": + // keep saslPlain* nil + default: + return fmt.Errorf("suika: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism) + } + } + + var err error + if network.ID == 0 { + err = db.db.QueryRowContext(ctx, ` + INSERT INTO "Network" ("user", name, addr, nick, username, realname, pass, connect_commands, + sasl_mechanism, sasl_plain_username, sasl_plain_password, sasl_external_cert, + sasl_external_key, enabled) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) + RETURNING id`, + userID, netName, network.Addr, nick, netUsername, realname, pass, connectCommands, + saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob, + network.SASL.External.PrivKeyBlob, network.Enabled).Scan(&network.ID) + } else { + _, err = db.db.ExecContext(ctx, ` + UPDATE "Network" + SET name = $2, addr = $3, nick = $4, username = $5, realname = $6, pass = $7, + connect_commands = $8, sasl_mechanism = $9, sasl_plain_username = $10, + sasl_plain_password = $11, sasl_external_cert = $12, sasl_external_key = $13, + enabled = $14 + WHERE id = $1`, + network.ID, netName, network.Addr, nick, netUsername, realname, pass, connectCommands, + saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob, + network.SASL.External.PrivKeyBlob, network.Enabled) + } + return err +} + +func (db *PostgresDB) DeleteNetwork(ctx context.Context, id int64) error { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + _, err := db.db.ExecContext(ctx, `DELETE FROM "Network" WHERE id = $1`, id) + return err +} + +func (db *PostgresDB) ListChannels(ctx context.Context, networkID int64) ([]Channel, error) { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + rows, err := db.db.QueryContext(ctx, ` + SELECT id, name, key, detached, detached_internal_msgid, relay_detached, reattach_on, detach_after, + detach_on + FROM "Channel" + WHERE network = $1`, networkID) + if err != nil { + return nil, err + } + defer rows.Close() + + var channels []Channel + for rows.Next() { + var ch Channel + var key, detachedInternalMsgID sql.NullString + var detachAfter int64 + if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached, &detachedInternalMsgID, &ch.RelayDetached, &ch.ReattachOn, &detachAfter, &ch.DetachOn); err != nil { + return nil, err + } + ch.Key = key.String + ch.DetachedInternalMsgID = detachedInternalMsgID.String + ch.DetachAfter = time.Duration(detachAfter) * time.Second + channels = append(channels, ch) + } + if err := rows.Err(); err != nil { + return nil, err + } + + return channels, nil +} + +func (db *PostgresDB) StoreChannel(ctx context.Context, networkID int64, ch *Channel) error { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + key := toNullString(ch.Key) + detachAfter := int64(math.Ceil(ch.DetachAfter.Seconds())) + + var err error + if ch.ID == 0 { + err = db.db.QueryRowContext(ctx, ` + INSERT INTO "Channel" (network, name, key, detached, detached_internal_msgid, relay_detached, reattach_on, + detach_after, detach_on) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) + RETURNING id`, + networkID, ch.Name, key, ch.Detached, toNullString(ch.DetachedInternalMsgID), + ch.RelayDetached, ch.ReattachOn, detachAfter, ch.DetachOn).Scan(&ch.ID) + } else { + _, err = db.db.ExecContext(ctx, ` + UPDATE "Channel" + SET name = $2, key = $3, detached = $4, detached_internal_msgid = $5, + relay_detached = $6, reattach_on = $7, detach_after = $8, detach_on = $9 + WHERE id = $1`, + ch.ID, ch.Name, key, ch.Detached, toNullString(ch.DetachedInternalMsgID), + ch.RelayDetached, ch.ReattachOn, detachAfter, ch.DetachOn) + } + return err +} + +func (db *PostgresDB) DeleteChannel(ctx context.Context, id int64) error { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + _, err := db.db.ExecContext(ctx, `DELETE FROM "Channel" WHERE id = $1`, id) + return err +} + +func (db *PostgresDB) ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error) { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + rows, err := db.db.QueryContext(ctx, ` + SELECT id, target, client, internal_msgid + FROM "DeliveryReceipt" + WHERE network = $1`, networkID) + if err != nil { + return nil, err + } + defer rows.Close() + + var receipts []DeliveryReceipt + for rows.Next() { + var rcpt DeliveryReceipt + if err := rows.Scan(&rcpt.ID, &rcpt.Target, &rcpt.Client, &rcpt.InternalMsgID); err != nil { + return nil, err + } + receipts = append(receipts, rcpt) + } + if err := rows.Err(); err != nil { + return nil, err + } + + return receipts, nil +} + +func (db *PostgresDB) StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + tx, err := db.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + _, err = tx.ExecContext(ctx, + `DELETE FROM "DeliveryReceipt" WHERE network = $1 AND client = $2`, + networkID, client) + if err != nil { + return err + } + + stmt, err := tx.PrepareContext(ctx, ` + INSERT INTO "DeliveryReceipt" (network, target, client, internal_msgid) + VALUES ($1, $2, $3, $4) + RETURNING id`) + if err != nil { + return err + } + defer stmt.Close() + + for i := range receipts { + rcpt := &receipts[i] + err := stmt. + QueryRowContext(ctx, networkID, rcpt.Target, client, rcpt.InternalMsgID). + Scan(&rcpt.ID) + if err != nil { + return err + } + } + + return tx.Commit() +} + +func (db *PostgresDB) GetReadReceipt(ctx context.Context, networkID int64, name string) (*ReadReceipt, error) { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + receipt := &ReadReceipt{ + Target: name, + } + + row := db.db.QueryRowContext(ctx, + `SELECT id, timestamp FROM "ReadReceipt" WHERE network = $1 AND target = $2`, + networkID, name) + if err := row.Scan(&receipt.ID, &receipt.Timestamp); err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, err + } + return receipt, nil +} + +func (db *PostgresDB) StoreReadReceipt(ctx context.Context, networkID int64, receipt *ReadReceipt) error { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + var err error + if receipt.ID != 0 { + _, err = db.db.ExecContext(ctx, ` + UPDATE "ReadReceipt" + SET timestamp = $1 + WHERE id = $2`, + receipt.Timestamp, receipt.ID) + } else { + err = db.db.QueryRowContext(ctx, ` + INSERT INTO "ReadReceipt" (network, target, timestamp) + VALUES ($1, $2, $3) + RETURNING id`, + networkID, receipt.Target, receipt.Timestamp).Scan(&receipt.ID) + } + return err +} diff --git a/branches/origin-master/db_postgres_test.go b/branches/origin-master/db_postgres_test.go new file mode 100644 index 0000000..7692f31 --- /dev/null +++ b/branches/origin-master/db_postgres_test.go @@ -0,0 +1,104 @@ +package suika + +import ( + "database/sql" + "os" + "testing" +) + +// PostgreSQL version 0 schema. DO NOT EDIT. +const postgresV0Schema = ` +CREATE TABLE "Config" ( + id SMALLINT PRIMARY KEY, + version INTEGER NOT NULL, + CHECK(id = 1) +); + +INSERT INTO "Config" (id, version) VALUES (1, 1); + +CREATE TABLE "User" ( + id SERIAL PRIMARY KEY, + username VARCHAR(255) NOT NULL UNIQUE, + password VARCHAR(255), + admin BOOLEAN NOT NULL DEFAULT FALSE, + realname VARCHAR(255) +); + +CREATE TABLE "Network" ( + id SERIAL PRIMARY KEY, + name VARCHAR(255), + "user" INTEGER NOT NULL REFERENCES "User"(id) ON DELETE CASCADE, + addr VARCHAR(255) NOT NULL, + nick VARCHAR(255) NOT NULL, + username VARCHAR(255), + realname VARCHAR(255), + pass VARCHAR(255), + connect_commands VARCHAR(1023), + sasl_mechanism VARCHAR(255), + sasl_plain_username VARCHAR(255), + sasl_plain_password VARCHAR(255), + sasl_external_cert BYTEA DEFAULT NULL, + sasl_external_key BYTEA DEFAULT NULL, + enabled BOOLEAN NOT NULL DEFAULT TRUE, + UNIQUE("user", addr, nick), + UNIQUE("user", name) +); + +CREATE TABLE "Channel" ( + id SERIAL PRIMARY KEY, + network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE, + name VARCHAR(255) NOT NULL, + key VARCHAR(255), + detached BOOLEAN NOT NULL DEFAULT FALSE, + detached_internal_msgid VARCHAR(255), + relay_detached INTEGER NOT NULL DEFAULT 0, + reattach_on INTEGER NOT NULL DEFAULT 0, + detach_after INTEGER NOT NULL DEFAULT 0, + detach_on INTEGER NOT NULL DEFAULT 0, + UNIQUE(network, name) +); + +CREATE TABLE "DeliveryReceipt" ( + id SERIAL PRIMARY KEY, + network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE, + target VARCHAR(255) NOT NULL, + client VARCHAR(255) NOT NULL DEFAULT '', + internal_msgid VARCHAR(255) NOT NULL, + UNIQUE(network, target, client) +); +` + +func openTempPostgresDB(t *testing.T) *sql.DB { + source, ok := os.LookupEnv("SOJU_TEST_POSTGRES") + if !ok { + t.Skip("set SOJU_TEST_POSTGRES to a connection string to execute PostgreSQL tests") + } + + db, err := sql.Open("postgres", source) + if err != nil { + t.Fatalf("failed to connect to PostgreSQL: %v", err) + } + + // Store all tables in a temporary schema which will be dropped when the + // connection to PostgreSQL is closed. + db.SetMaxOpenConns(1) + if _, err := db.Exec("SET search_path TO pg_temp"); err != nil { + t.Fatalf("failed to set PostgreSQL search_path: %v", err) + } + + return db +} + +func TestPostgresMigrations(t *testing.T) { + sqlDB := openTempPostgresDB(t) + if _, err := sqlDB.Exec(postgresV0Schema); err != nil { + t.Fatalf("DB.Exec() failed for v0 schema: %v", err) + } + + db := &PostgresDB{db: sqlDB} + defer db.Close() + + if err := db.upgrade(); err != nil { + t.Fatalf("PostgresDB.Upgrade() failed: %v", err) + } +} diff --git a/branches/origin-master/db_sqlite.go b/branches/origin-master/db_sqlite.go new file mode 100644 index 0000000..7ebebb0 --- /dev/null +++ b/branches/origin-master/db_sqlite.go @@ -0,0 +1,755 @@ +package suika + +import ( + "context" + "database/sql" + _ "embed" + "fmt" + "math" + "strings" + "sync" + "time" + + _ "modernc.org/sqlite" +) + +const sqliteQueryTimeout = 5 * time.Second + +//go:embed suika_sqlite_schema.sql +var sqliteSchema string + +var sqliteMigrations = []string{ + "", // migration #0 is reserved for schema initialization + "ALTER TABLE Network ADD COLUMN connect_commands VARCHAR(1023)", + "ALTER TABLE Channel ADD COLUMN detached INTEGER NOT NULL DEFAULT 0", + "ALTER TABLE Network ADD COLUMN sasl_external_cert BLOB DEFAULT NULL", + "ALTER TABLE Network ADD COLUMN sasl_external_key BLOB DEFAULT NULL", + "ALTER TABLE User ADD COLUMN admin INTEGER NOT NULL DEFAULT 0", + ` + CREATE TABLE IF NOT EXISTS UserNew ( + id INTEGER PRIMARY KEY, + username VARCHAR(255) NOT NULL UNIQUE, + password VARCHAR(255), + admin INTEGER NOT NULL DEFAULT 0 + ); + INSERT INTO UserNew SELECT rowid, username, password, admin FROM User; + DROP TABLE User; + ALTER TABLE UserNew RENAME TO User; + `, + ` + CREATE TABLE IF NOT EXISTS NetworkNew ( + id INTEGER PRIMARY KEY, + name VARCHAR(255), + user INTEGER NOT NULL, + addr VARCHAR(255) NOT NULL, + nick VARCHAR(255) NOT NULL, + username VARCHAR(255), + realname VARCHAR(255), + pass VARCHAR(255), + connect_commands VARCHAR(1023), + sasl_mechanism VARCHAR(255), + sasl_plain_username VARCHAR(255), + sasl_plain_password VARCHAR(255), + sasl_external_cert BLOB DEFAULT NULL, + sasl_external_key BLOB DEFAULT NULL, + FOREIGN KEY(user) REFERENCES User(id), + UNIQUE(user, addr, nick), + UNIQUE(user, name) + ); + INSERT INTO NetworkNew + SELECT Network.id, name, User.id as user, addr, nick, + Network.username, realname, pass, connect_commands, + sasl_mechanism, sasl_plain_username, sasl_plain_password, + sasl_external_cert, sasl_external_key + FROM Network + JOIN User ON Network.user = User.username; + DROP TABLE Network; + ALTER TABLE NetworkNew RENAME TO Network; + `, + ` + ALTER TABLE Channel ADD COLUMN relay_detached INTEGER NOT NULL DEFAULT 0; + ALTER TABLE Channel ADD COLUMN reattach_on INTEGER NOT NULL DEFAULT 0; + ALTER TABLE Channel ADD COLUMN detach_after INTEGER NOT NULL DEFAULT 0; + ALTER TABLE Channel ADD COLUMN detach_on INTEGER NOT NULL DEFAULT 0; + `, + ` + CREATE TABLE IF NOT EXISTS DeliveryReceipt ( + id INTEGER PRIMARY KEY, + network INTEGER NOT NULL, + target VARCHAR(255) NOT NULL, + client VARCHAR(255), + internal_msgid VARCHAR(255) NOT NULL, + FOREIGN KEY(network) REFERENCES Network(id), + UNIQUE(network, target, client) + ); + `, + "ALTER TABLE Channel ADD COLUMN detached_internal_msgid VARCHAR(255)", + "ALTER TABLE Network ADD COLUMN enabled INTEGER NOT NULL DEFAULT 1", + "ALTER TABLE User ADD COLUMN realname VARCHAR(255)", + ` + CREATE TABLE IF NOT EXISTS NetworkNew ( + id INTEGER PRIMARY KEY, + name TEXT, + user INTEGER NOT NULL, + addr TEXT NOT NULL, + nick TEXT, + username TEXT, + realname TEXT, + pass TEXT, + connect_commands TEXT, + sasl_mechanism TEXT, + sasl_plain_username TEXT, + sasl_plain_password TEXT, + sasl_external_cert BLOB, + sasl_external_key BLOB, + enabled INTEGER NOT NULL DEFAULT 1, + FOREIGN KEY(user) REFERENCES User(id), + UNIQUE(user, addr, nick), + UNIQUE(user, name) + ); + INSERT INTO NetworkNew + SELECT id, name, user, addr, nick, username, realname, pass, + connect_commands, sasl_mechanism, sasl_plain_username, + sasl_plain_password, sasl_external_cert, sasl_external_key, + enabled + FROM Network; + DROP TABLE Network; + ALTER TABLE NetworkNew RENAME TO Network; + `, + ` + CREATE TABLE IF NOT EXISTS ReadReceipt ( + id INTEGER PRIMARY KEY, + network INTEGER NOT NULL, + target TEXT NOT NULL, + timestamp TEXT NOT NULL, + FOREIGN KEY(network) REFERENCES Network(id), + UNIQUE(network, target) + ); + `, +} + +type SqliteDB struct { + lock sync.RWMutex + db *sql.DB +} + +func OpenSqliteDB(source string) (Database, error) { + sqlSqliteDB, err := sql.Open("sqlite", source) + if err != nil { + return nil, err + } + + db := &SqliteDB{db: sqlSqliteDB} + if err := db.upgrade(); err != nil { + sqlSqliteDB.Close() + return nil, err + } + + return db, nil +} + +func (db *SqliteDB) Close() error { + db.lock.Lock() + defer db.lock.Unlock() + return db.db.Close() +} + +func (db *SqliteDB) upgrade() error { + db.lock.Lock() + defer db.lock.Unlock() + + var version int + if err := db.db.QueryRow("PRAGMA user_version").Scan(&version); err != nil { + return fmt.Errorf("failed to query schema version: %v", err) + } + + if version == len(sqliteMigrations) { + return nil + } else if version > len(sqliteMigrations) { + return fmt.Errorf("suika (version %d) older than schema (version %d)", len(sqliteMigrations), version) + } + + tx, err := db.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + if version == 0 { + if _, err := tx.Exec(sqliteSchema); err != nil { + return fmt.Errorf("failed to initialize schema: %v", err) + } + } else { + for i := version; i < len(sqliteMigrations); i++ { + if _, err := tx.Exec(sqliteMigrations[i]); err != nil { + return fmt.Errorf("failed to execute migration #%v: %v", i, err) + } + } + } + + // For some reason prepared statements don't work here + _, err = tx.Exec(fmt.Sprintf("PRAGMA user_version = %d", len(sqliteMigrations))) + if err != nil { + return fmt.Errorf("failed to bump schema version: %v", err) + } + + return tx.Commit() +} + +func (db *SqliteDB) Stats(ctx context.Context) (*DatabaseStats, error) { + db.lock.RLock() + defer db.lock.RUnlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + var stats DatabaseStats + row := db.db.QueryRowContext(ctx, `SELECT + (SELECT COUNT(*) FROM User) AS users, + (SELECT COUNT(*) FROM Network) AS networks, + (SELECT COUNT(*) FROM Channel) AS channels`) + if err := row.Scan(&stats.Users, &stats.Networks, &stats.Channels); err != nil { + return nil, err + } + + return &stats, nil +} + +func toNullString(s string) sql.NullString { + return sql.NullString{ + String: s, + Valid: s != "", + } +} + +func (db *SqliteDB) ListUsers(ctx context.Context) ([]User, error) { + db.lock.RLock() + defer db.lock.RUnlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + rows, err := db.db.QueryContext(ctx, + "SELECT id, username, password, admin, realname FROM User") + if err != nil { + return nil, err + } + defer rows.Close() + + var users []User + for rows.Next() { + var user User + var password, realname sql.NullString + if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &realname); err != nil { + return nil, err + } + user.Password = password.String + user.Realname = realname.String + users = append(users, user) + } + if err := rows.Err(); err != nil { + return nil, err + } + + return users, nil +} + +func (db *SqliteDB) GetUser(ctx context.Context, username string) (*User, error) { + db.lock.RLock() + defer db.lock.RUnlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + user := &User{Username: username} + + var password, realname sql.NullString + row := db.db.QueryRowContext(ctx, + "SELECT id, password, admin, realname FROM User WHERE username = ?", + username) + if err := row.Scan(&user.ID, &password, &user.Admin, &realname); err != nil { + return nil, err + } + user.Password = password.String + user.Realname = realname.String + return user, nil +} + +func (db *SqliteDB) StoreUser(ctx context.Context, user *User) error { + db.lock.Lock() + defer db.lock.Unlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + args := []interface{}{ + sql.Named("username", user.Username), + sql.Named("password", toNullString(user.Password)), + sql.Named("admin", user.Admin), + sql.Named("realname", toNullString(user.Realname)), + } + + var err error + if user.ID != 0 { + _, err = db.db.ExecContext(ctx, ` + UPDATE User SET password = :password, admin = :admin, + realname = :realname WHERE username = :username`, + args...) + } else { + var res sql.Result + res, err = db.db.ExecContext(ctx, ` + INSERT INTO + User(username, password, admin, realname) + VALUES (:username, :password, :admin, :realname)`, + args...) + if err != nil { + return err + } + user.ID, err = res.LastInsertId() + } + + return err +} + +func (db *SqliteDB) DeleteUser(ctx context.Context, id int64) error { + db.lock.Lock() + defer db.lock.Unlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + tx, err := db.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + _, err = tx.ExecContext(ctx, `DELETE FROM DeliveryReceipt + WHERE id IN ( + SELECT DeliveryReceipt.id + FROM DeliveryReceipt + JOIN Network ON DeliveryReceipt.network = Network.id + WHERE Network.user = ? + )`, id) + if err != nil { + return err + } + + _, err = tx.ExecContext(ctx, `DELETE FROM ReadReceipt + WHERE id IN ( + SELECT ReadReceipt.id + FROM ReadReceipt + JOIN Network ON ReadReceipt.network = Network.id + WHERE Network.user = ? + )`, id) + if err != nil { + return err + } + + _, err = tx.ExecContext(ctx, `DELETE FROM Channel + WHERE id IN ( + SELECT Channel.id + FROM Channel + JOIN Network ON Channel.network = Network.id + WHERE Network.user = ? + )`, id) + if err != nil { + return err + } + + _, err = tx.ExecContext(ctx, "DELETE FROM Network WHERE user = ?", id) + if err != nil { + return err + } + + _, err = tx.ExecContext(ctx, "DELETE FROM User WHERE id = ?", id) + if err != nil { + return err + } + + return tx.Commit() +} + +func (db *SqliteDB) ListNetworks(ctx context.Context, userID int64) ([]Network, error) { + db.lock.RLock() + defer db.lock.RUnlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + rows, err := db.db.QueryContext(ctx, ` + SELECT id, name, addr, nick, username, realname, pass, + connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password, + sasl_external_cert, sasl_external_key, enabled + FROM Network + WHERE user = ?`, + userID) + if err != nil { + return nil, err + } + defer rows.Close() + + var networks []Network + for rows.Next() { + var net Network + var name, nick, username, realname, pass, connectCommands sql.NullString + var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString + err := rows.Scan(&net.ID, &name, &net.Addr, &nick, &username, &realname, + &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword, + &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob, &net.Enabled) + if err != nil { + return nil, err + } + net.Name = name.String + net.Nick = nick.String + net.Username = username.String + net.Realname = realname.String + net.Pass = pass.String + if connectCommands.Valid { + net.ConnectCommands = strings.Split(connectCommands.String, "\r\n") + } + net.SASL.Mechanism = saslMechanism.String + net.SASL.Plain.Username = saslPlainUsername.String + net.SASL.Plain.Password = saslPlainPassword.String + networks = append(networks, net) + } + if err := rows.Err(); err != nil { + return nil, err + } + + return networks, nil +} + +func (db *SqliteDB) StoreNetwork(ctx context.Context, userID int64, network *Network) error { + db.lock.Lock() + defer db.lock.Unlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString + if network.SASL.Mechanism != "" { + saslMechanism = toNullString(network.SASL.Mechanism) + switch network.SASL.Mechanism { + case "PLAIN": + saslPlainUsername = toNullString(network.SASL.Plain.Username) + saslPlainPassword = toNullString(network.SASL.Plain.Password) + network.SASL.External.CertBlob = nil + network.SASL.External.PrivKeyBlob = nil + case "EXTERNAL": + // keep saslPlain* nil + default: + return fmt.Errorf("suika: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism) + } + } + + args := []interface{}{ + sql.Named("name", toNullString(network.Name)), + sql.Named("addr", network.Addr), + sql.Named("nick", toNullString(network.Nick)), + sql.Named("username", toNullString(network.Username)), + sql.Named("realname", toNullString(network.Realname)), + sql.Named("pass", toNullString(network.Pass)), + sql.Named("connect_commands", toNullString(strings.Join(network.ConnectCommands, "\r\n"))), + sql.Named("sasl_mechanism", saslMechanism), + sql.Named("sasl_plain_username", saslPlainUsername), + sql.Named("sasl_plain_password", saslPlainPassword), + sql.Named("sasl_external_cert", network.SASL.External.CertBlob), + sql.Named("sasl_external_key", network.SASL.External.PrivKeyBlob), + sql.Named("enabled", network.Enabled), + + sql.Named("id", network.ID), // only for UPDATE + sql.Named("user", userID), // only for INSERT + } + + var err error + if network.ID != 0 { + _, err = db.db.ExecContext(ctx, ` + UPDATE Network + SET name = :name, addr = :addr, nick = :nick, username = :username, + realname = :realname, pass = :pass, connect_commands = :connect_commands, + sasl_mechanism = :sasl_mechanism, sasl_plain_username = :sasl_plain_username, sasl_plain_password = :sasl_plain_password, + sasl_external_cert = :sasl_external_cert, sasl_external_key = :sasl_external_key, + enabled = :enabled + WHERE id = :id`, args...) + } else { + var res sql.Result + res, err = db.db.ExecContext(ctx, ` + INSERT INTO Network(user, name, addr, nick, username, realname, pass, + connect_commands, sasl_mechanism, sasl_plain_username, + sasl_plain_password, sasl_external_cert, sasl_external_key, enabled) + VALUES (:user, :name, :addr, :nick, :username, :realname, :pass, + :connect_commands, :sasl_mechanism, :sasl_plain_username, + :sasl_plain_password, :sasl_external_cert, :sasl_external_key, :enabled)`, + args...) + if err != nil { + return err + } + network.ID, err = res.LastInsertId() + } + return err +} + +func (db *SqliteDB) DeleteNetwork(ctx context.Context, id int64) error { + db.lock.Lock() + defer db.lock.Unlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + tx, err := db.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + _, err = tx.ExecContext(ctx, "DELETE FROM DeliveryReceipt WHERE network = ?", id) + if err != nil { + return err + } + + _, err = tx.ExecContext(ctx, "DELETE FROM ReadReceipt WHERE network = ?", id) + if err != nil { + return err + } + + _, err = tx.ExecContext(ctx, "DELETE FROM Channel WHERE network = ?", id) + if err != nil { + return err + } + + _, err = tx.ExecContext(ctx, "DELETE FROM Network WHERE id = ?", id) + if err != nil { + return err + } + + return tx.Commit() +} + +func (db *SqliteDB) ListChannels(ctx context.Context, networkID int64) ([]Channel, error) { + db.lock.RLock() + defer db.lock.RUnlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + rows, err := db.db.QueryContext(ctx, `SELECT + id, name, key, detached, detached_internal_msgid, + relay_detached, reattach_on, detach_after, detach_on + FROM Channel + WHERE network = ?`, networkID) + if err != nil { + return nil, err + } + defer rows.Close() + + var channels []Channel + for rows.Next() { + var ch Channel + var key, detachedInternalMsgID sql.NullString + var detachAfter int64 + if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached, &detachedInternalMsgID, &ch.RelayDetached, &ch.ReattachOn, &detachAfter, &ch.DetachOn); err != nil { + return nil, err + } + ch.Key = key.String + ch.DetachedInternalMsgID = detachedInternalMsgID.String + ch.DetachAfter = time.Duration(detachAfter) * time.Second + channels = append(channels, ch) + } + if err := rows.Err(); err != nil { + return nil, err + } + + return channels, nil +} + +func (db *SqliteDB) StoreChannel(ctx context.Context, networkID int64, ch *Channel) error { + db.lock.Lock() + defer db.lock.Unlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + args := []interface{}{ + sql.Named("network", networkID), + sql.Named("name", ch.Name), + sql.Named("key", toNullString(ch.Key)), + sql.Named("detached", ch.Detached), + sql.Named("detached_internal_msgid", toNullString(ch.DetachedInternalMsgID)), + sql.Named("relay_detached", ch.RelayDetached), + sql.Named("reattach_on", ch.ReattachOn), + sql.Named("detach_after", int64(math.Ceil(ch.DetachAfter.Seconds()))), + sql.Named("detach_on", ch.DetachOn), + + sql.Named("id", ch.ID), // only for UPDATE + } + + var err error + if ch.ID != 0 { + _, err = db.db.ExecContext(ctx, `UPDATE Channel + SET network = :network, name = :name, key = :key, detached = :detached, + detached_internal_msgid = :detached_internal_msgid, relay_detached = :relay_detached, + reattach_on = :reattach_on, detach_after = :detach_after, detach_on = :detach_on + WHERE id = :id`, args...) + } else { + var res sql.Result + res, err = db.db.ExecContext(ctx, `INSERT INTO Channel(network, name, key, detached, detached_internal_msgid, relay_detached, reattach_on, detach_after, detach_on) + VALUES (:network, :name, :key, :detached, :detached_internal_msgid, :relay_detached, :reattach_on, :detach_after, :detach_on)`, args...) + if err != nil { + return err + } + ch.ID, err = res.LastInsertId() + } + return err +} + +func (db *SqliteDB) DeleteChannel(ctx context.Context, id int64) error { + db.lock.Lock() + defer db.lock.Unlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + _, err := db.db.ExecContext(ctx, "DELETE FROM Channel WHERE id = ?", id) + return err +} + +func (db *SqliteDB) ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error) { + db.lock.RLock() + defer db.lock.RUnlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + rows, err := db.db.QueryContext(ctx, ` + SELECT id, target, client, internal_msgid + FROM DeliveryReceipt + WHERE network = ?`, networkID) + if err != nil { + return nil, err + } + defer rows.Close() + + var receipts []DeliveryReceipt + for rows.Next() { + var rcpt DeliveryReceipt + var client sql.NullString + if err := rows.Scan(&rcpt.ID, &rcpt.Target, &client, &rcpt.InternalMsgID); err != nil { + return nil, err + } + rcpt.Client = client.String + receipts = append(receipts, rcpt) + } + if err := rows.Err(); err != nil { + return nil, err + } + + return receipts, nil +} + +func (db *SqliteDB) StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error { + db.lock.Lock() + defer db.lock.Unlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + tx, err := db.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + _, err = tx.ExecContext(ctx, "DELETE FROM DeliveryReceipt WHERE network = ? AND client IS ?", + networkID, toNullString(client)) + if err != nil { + return err + } + + for i := range receipts { + rcpt := &receipts[i] + + res, err := tx.ExecContext(ctx, ` + INSERT INTO DeliveryReceipt(network, target, client, internal_msgid) + VALUES (:network, :target, :client, :internal_msgid)`, + sql.Named("network", networkID), + sql.Named("target", rcpt.Target), + sql.Named("client", toNullString(client)), + sql.Named("internal_msgid", rcpt.InternalMsgID)) + if err != nil { + return err + } + rcpt.ID, err = res.LastInsertId() + if err != nil { + return err + } + } + + return tx.Commit() +} + +func (db *SqliteDB) GetReadReceipt(ctx context.Context, networkID int64, name string) (*ReadReceipt, error) { + db.lock.RLock() + defer db.lock.RUnlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + receipt := &ReadReceipt{ + Target: name, + } + + row := db.db.QueryRowContext(ctx, ` + SELECT id, timestamp FROM ReadReceipt WHERE network = :network AND target = :target`, + sql.Named("network", networkID), + sql.Named("target", name), + ) + var timestamp string + if err := row.Scan(&receipt.ID, ×tamp); err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, err + } + if t, err := time.Parse(serverTimeLayout, timestamp); err != nil { + return nil, err + } else { + receipt.Timestamp = t + } + return receipt, nil +} + +func (db *SqliteDB) StoreReadReceipt(ctx context.Context, networkID int64, receipt *ReadReceipt) error { + db.lock.Lock() + defer db.lock.Unlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + args := []interface{}{ + sql.Named("id", receipt.ID), + sql.Named("timestamp", formatServerTime(receipt.Timestamp)), + sql.Named("network", networkID), + sql.Named("target", receipt.Target), + } + + var err error + if receipt.ID != 0 { + _, err = db.db.ExecContext(ctx, ` + UPDATE ReadReceipt SET timestamp = :timestamp WHERE id = :id`, + args...) + } else { + var res sql.Result + res, err = db.db.ExecContext(ctx, ` + INSERT INTO + ReadReceipt(network, target, timestamp) + VALUES (:network, :target, :timestamp)`, + args...) + if err != nil { + return err + } + receipt.ID, err = res.LastInsertId() + } + + return err +} diff --git a/branches/origin-master/db_sqlite_test.go b/branches/origin-master/db_sqlite_test.go new file mode 100644 index 0000000..a7c9794 --- /dev/null +++ b/branches/origin-master/db_sqlite_test.go @@ -0,0 +1,59 @@ +package suika + +import ( + "database/sql" + "testing" +) + +// SQLite version 0 schema. DO NOT EDIT. +const sqliteV0Schema = ` +CREATE TABLE User ( + username VARCHAR(255) NOT NULL UNIQUE, + password VARCHAR(255) +); + +CREATE TABLE Network ( + id INTEGER PRIMARY KEY, + name VARCHAR(255), + user VARCHAR(255) NOT NULL, + addr VARCHAR(255) NOT NULL, + nick VARCHAR(255) NOT NULL, + username VARCHAR(255), + realname VARCHAR(255), + pass VARCHAR(255), + sasl_mechanism VARCHAR(255), + sasl_plain_username VARCHAR(255), + sasl_plain_password VARCHAR(255), + UNIQUE(user, addr, nick), + UNIQUE(user, name) +); + +CREATE TABLE Channel ( + id INTEGER PRIMARY KEY, + network INTEGER NOT NULL, + name VARCHAR(255) NOT NULL, + key VARCHAR(255), + FOREIGN KEY(network) REFERENCES Network(id), + UNIQUE(network, name) +); + +PRAGMA user_version = 1; +` + +func TestSqliteMigrations(t *testing.T) { + sqlDB, err := sql.Open("sqlite", ":memory:") + if err != nil { + t.Fatalf("failed to create temporary SQLite database: %v", err) + } + + if _, err := sqlDB.Exec(sqliteV0Schema); err != nil { + t.Fatalf("DB.Exec() failed for v0 schema: %v", err) + } + + db := &SqliteDB{db: sqlDB} + defer db.Close() + + if err := db.upgrade(); err != nil { + t.Fatalf("SqliteDB.Upgrade() failed: %v", err) + } +} diff --git a/branches/origin-master/doc.go b/branches/origin-master/doc.go new file mode 100644 index 0000000..3f7af51 --- /dev/null +++ b/branches/origin-master/doc.go @@ -0,0 +1,20 @@ +// Package suika is a hard-fork of the 0.3 series of soju, an user-friendly IRC bouncer in Go. +// +// # Copyright (C) 2020 The soju Contributors +// # Copyright (C) 2023-present Izuru Yakumo et al. +// +// suika is covered by the AGPLv3 license: +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . +package suika diff --git a/branches/origin-master/doc/suika-config.5 b/branches/origin-master/doc/suika-config.5 new file mode 100644 index 0000000..ac0e9ef --- /dev/null +++ b/branches/origin-master/doc/suika-config.5 @@ -0,0 +1,70 @@ +.Dd $Mdocdate$ +.Dt SUIKA-CONFIG 5 +.Os +.Sh NAME +.Nm suika-config +.Nd Configuration file for the IRC bouncer +.Sh SYNOPSIS +.Bk -words +listen ircs:// +.Pp +tls cert.pem key.pem +.Pp +hostname example.org +.Ek +.Sh DESCRIPTION +This document describes the format of the configuration +file used by +.Xr suika 1 +.Sh OPTIONS +.Bl -tag -width Ds +.It listen Ar uri +With this you can control on what +ports/protocols +.Xr suika 1 +listens on, it supports +irc (cleartext IRC), ircs (IRC with TLS), and unix +(IRC over Unix domain sockets) +.It hostname Ar hostname +Server hostname, if unset, the system one is used. +.It title Ar title +Server title, this will be sent as the ISUPPORT NETWORK value when +clients don't select a specific network. +.It tls Ar cert Ar key +Enable TLS support, the certificate and key files must be +PEM-encoded. +.It db Ar driver Ar path +Set the database driver for user, network and channel storage. +By default a SQLite 3 database is opened in +.Pa ./suika.db +Supported drivers are sqlite and postgres, the former +expects a path to the database file, and the latter +a space-separated list of key=value parameters, +e.g. host=localhost dbname=suika +.It log fs Ar path +Path to the bouncer logs directory, or empty to disable +logging. +By default, logging is disabled. +.It max-user-networks Ar limit +Maximum number of networks per user, by default +there is no limit. +.It motd Ar path +Path to the MOTD file, its contents are sent to clients +which aren't bound to a particular network. +By default, no MOTD is sent. +.It multi-upstream-mode Ar bool +Globally enable or disable multi-upstream mode. +By default, it is enabled. +.It upstream-user-ip Ar cidr +Enable per-user-IP addresses. +One IPv4 and/or one IPv6 range can be specified in CIDR notation. +One IP address per range will be assigned to each user as the +source address when connecting to an upstream network. +This can be useful to avoid having the whole bouncer banned from +an upstream network because of one malicious user. +.El +.Sh AUTHORS +.An Simon Ser Aq Mt contact@emersion.fr +.An The soju Contributors +.Sh MAINTAINERS +.An Izuru Yakumo Aq Mt yakumo.izuru@chaotic.ninja diff --git a/branches/origin-master/doc/suika-znc-import.1 b/branches/origin-master/doc/suika-znc-import.1 new file mode 100644 index 0000000..7acc618 --- /dev/null +++ b/branches/origin-master/doc/suika-znc-import.1 @@ -0,0 +1,34 @@ +.Dd $Mdocdate$ +.Dt SUIKA-ZNC-IMPORT 1 +.Os +.Sh NAME +.Nm suika-znc-import +.Nd Migration utility for moving from ZNC +.Sh SYNOPSIS +.Nm +.Op Fl config Ar suika config file +.Op Fl user Ar username +.Op Fl network Ar name +.Sh DESCRIPTION +Imports configuration from a ZNC file. +Users and networks are merged if they already exist in the +.Xr suika 1 +database. +ZNC settings overwrite existing +.Xr suika 1 +settings +.Sh OPTIONS +.Bl -tag -width Ds +.It config Ar suika config file +Path to +.Xr suika-config 5 +.It user Ar username +Limit import to username, may be specified multiple times. +It network Ar name +Limit import to network, may be specified multiple times. +.El +.Sh AUTHORS +.An Simon Ser Aq Mt contact@emersion.fr +.An The soju Contributors +.Sh MAINTAINERS +.An Izuru Yakumo Aq Mt yakumo.izuru@chaotic.ninja diff --git a/branches/origin-master/doc/suika.1 b/branches/origin-master/doc/suika.1 new file mode 100644 index 0000000..afe6607 --- /dev/null +++ b/branches/origin-master/doc/suika.1 @@ -0,0 +1,68 @@ +.Dd $Mdocdate$ +.Dt SUIKA 1 +.Os +.Sh NAME +.Nm suika +.Nd A drunk as hell IRC bouncer, named after Suika Ibuki from Touhou Project +.Sh SYNOPSIS +.Nm +.Op Fl config Ar path +.Op Fl debug +.Op Fl listen Ar uri +.Sh DESCRIPTION +.Nm +is an user-friendly IRC bouncer, it connects to upstream +IRC servers on behalf of the user to provide extra features. +.Bl -tag -width 6n +.It Multiple separate users sharing the same bouncer +.It Clients connecting to multiple upstream servers (via a single connection) +.It Sending the backlog with per-client buffers +.El +.Pp +When joining a channel, the channel will be saved +and automatically joined on the next connection. +When registering or authenticating with NickServ, the credentials will be saved +and automatically used on the next connection if the server supports SASL. +When parting a channel with the reason "detach", the channel will be +detached instead of being left. +When all clients are disconnected from the bouncer, +the user is automatically marked as away. +.Pp +.Nm +supports two connection modes: +.Bl -tag -width 6n +.It Single upstream mode +One downstream connection maps to one upstream connection +.Pp +To enable this mode, connect to the bouncer +with the username "/". +.Pp +If the bouncer isn't connected to the upstream server, +it will get automatically added. +.Pp +Then channels can be joined and parted as if +you were directly connected to the upstream server. +.It Multiple upstream mode +One downstream connection maps to multiple upstream connections. +Channels and nicks are suffixed with the network name. +To join a channel, you need to use the suffix too: /join #channel/network. +Same applies to messages sent to users. +.El +.Pp +For per-client history to work, clients need to indicate their name. +This can be done by adding a "@" suffix to the username. +.Pp +.Nm +will reload the configuration file, the TLS certificate/key and +the MOTD file when it receives the HUP signal. +The configuration options listen, db and log cannot be reloaded. +.Pp +Administrators can broadcast a message to all bouncer users via +/notice $ , or via /notice $ in multi-upstream mode. +All currently connected bouncer users will receive the message +from the special BouncerServ service. +.Sh AUTHORS +.An Simon Ser Aq Mt contact@emersion.fr +.An The soju Contributors +.Sh MAINTAINERS +.An Izuru Yakumo Aq Mt yakumo.izuru@chaotic.ninja diff --git a/branches/origin-master/doc/suikadb.1 b/branches/origin-master/doc/suikadb.1 new file mode 100644 index 0000000..5ff4769 --- /dev/null +++ b/branches/origin-master/doc/suikadb.1 @@ -0,0 +1,16 @@ +.Dd $Mdocdate$ +.Dt SUIKADB 1 +.Os +.Sh NAME +.Nm suikadb +.Nd Basic user manipulation for +.Xr suika 1 +.Sh SYNOPSIS +.Nm +.Op create-user +.Op change-password +.Sh AUTHORS +.An Simon Ser Aq Mt contact@emersion.fr +.An The soju Contributors +.Sh MAINTAINERS +.An Izuru Yakumo Aq Mt yakumo.izuru@chaotic.ninja diff --git a/branches/origin-master/downstream.go b/branches/origin-master/downstream.go new file mode 100644 index 0000000..3b9adac --- /dev/null +++ b/branches/origin-master/downstream.go @@ -0,0 +1,3047 @@ +package suika + +import ( + "bytes" + "context" + "crypto/tls" + "encoding/base64" + "errors" + "fmt" + "io" + "net" + "strconv" + "strings" + "time" + + "github.com/emersion/go-sasl" + "golang.org/x/crypto/bcrypt" + "gopkg.in/irc.v3" +) + +type ircError struct { + Message *irc.Message +} + +func (err ircError) Error() string { + return err.Message.String() +} + +func newUnknownCommandError(cmd string) ircError { + return ircError{&irc.Message{ + Command: irc.ERR_UNKNOWNCOMMAND, + Params: []string{ + "*", + cmd, + "Unknown command", + }, + }} +} + +func newNeedMoreParamsError(cmd string) ircError { + return ircError{&irc.Message{ + Command: irc.ERR_NEEDMOREPARAMS, + Params: []string{ + "*", + cmd, + "Not enough parameters", + }, + }} +} + +func newChatHistoryError(subcommand string, target string) ircError { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"CHATHISTORY", "MESSAGE_ERROR", subcommand, target, "Messages could not be retrieved"}, + }} +} + +// authError is an authentication error. +type authError struct { + // Internal error cause. This will not be revealed to the user. + err error + // Error cause which can safely be sent to the user without compromising + // security. + reason string +} + +func (err *authError) Error() string { + return err.err.Error() +} + +func (err *authError) Unwrap() error { + return err.err +} + +// authErrorReason returns the user-friendly reason of an authentication +// failure. +func authErrorReason(err error) string { + if authErr, ok := err.(*authError); ok { + return authErr.reason + } else { + return "Authentication failed" + } +} + +func newInvalidUsernameOrPasswordError(err error) error { + return &authError{ + err: err, + reason: "Invalid username or password", + } +} + +func parseBouncerNetID(subcommand, s string) (int64, error) { + id, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return 0, ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"BOUNCER", "INVALID_NETID", subcommand, s, "Invalid network ID"}, + }} + } + return id, nil +} + +func fillNetworkAddrAttrs(attrs irc.Tags, network *Network) { + u, err := network.URL() + if err != nil { + return + } + + hasHostPort := true + switch u.Scheme { + case "ircs": + attrs["tls"] = irc.TagValue("1") + case "irc": + attrs["tls"] = irc.TagValue("0") + default: // e.g. unix:// + hasHostPort = false + } + if host, port, err := net.SplitHostPort(u.Host); err == nil && hasHostPort { + attrs["host"] = irc.TagValue(host) + attrs["port"] = irc.TagValue(port) + } else if hasHostPort { + attrs["host"] = irc.TagValue(u.Host) + } +} + +func getNetworkAttrs(network *network) irc.Tags { + state := "disconnected" + if uc := network.conn; uc != nil { + state = "connected" + } + + attrs := irc.Tags{ + "name": irc.TagValue(network.GetName()), + "state": irc.TagValue(state), + "nickname": irc.TagValue(GetNick(&network.user.User, &network.Network)), + } + + if network.Username != "" { + attrs["username"] = irc.TagValue(network.Username) + } + if realname := GetRealname(&network.user.User, &network.Network); realname != "" { + attrs["realname"] = irc.TagValue(realname) + } + + fillNetworkAddrAttrs(attrs, &network.Network) + + return attrs +} + +func networkAddrFromAttrs(attrs irc.Tags) string { + host, ok := attrs.GetTag("host") + if !ok { + return "" + } + + addr := host + if port, ok := attrs.GetTag("port"); ok { + addr += ":" + port + } + + if tlsStr, ok := attrs.GetTag("tls"); ok && tlsStr == "0" { + addr = "irc://" + tlsStr + } + + return addr +} + +func updateNetworkAttrs(record *Network, attrs irc.Tags, subcommand string) error { + addrAttrs := irc.Tags{} + fillNetworkAddrAttrs(addrAttrs, record) + + updateAddr := false + for k, v := range attrs { + s := string(v) + switch k { + case "host", "port", "tls": + updateAddr = true + addrAttrs[k] = v + case "name": + record.Name = s + case "nickname": + record.Nick = s + case "username": + record.Username = s + case "realname": + record.Realname = s + case "pass": + record.Pass = s + default: + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"BOUNCER", "UNKNOWN_ATTRIBUTE", subcommand, k, "Unknown attribute"}, + }} + } + } + + if updateAddr { + record.Addr = networkAddrFromAttrs(addrAttrs) + if record.Addr == "" { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"BOUNCER", "NEED_ATTRIBUTE", subcommand, "host", "Missing required host attribute"}, + }} + } + } + + return nil +} + +// illegalNickChars is the list of characters forbidden in a nickname. +// +// ' ' and ':' break the IRC message wire format +// '@' and '!' break prefixes +// '*' breaks masks and is the reserved nickname for registration +// '?' breaks masks +// '$' breaks server masks in PRIVMSG/NOTICE +// ',' breaks lists +// '.' is reserved for server names +const illegalNickChars = " :@!*?$,." + +// permanentDownstreamCaps is the list of always-supported downstream +// capabilities. +var permanentDownstreamCaps = map[string]string{ + "batch": "", + "cap-notify": "", + "echo-message": "", + "invite-notify": "", + "message-tags": "", + "server-time": "", + "setname": "", + + "soju.im/bouncer-networks": "", + "soju.im/bouncer-networks-notify": "", + "soju.im/read": "", +} + +// needAllDownstreamCaps is the list of downstream capabilities that +// require support from all upstreams to be enabled +var needAllDownstreamCaps = map[string]string{ + "account-notify": "", + "account-tag": "", + "away-notify": "", + "extended-join": "", + "multi-prefix": "", + + "draft/extended-monitor": "", +} + +// passthroughIsupport is the set of ISUPPORT tokens that are directly passed +// through from the upstream server to downstream clients. +// +// This is only effective in single-upstream mode. +var passthroughIsupport = map[string]bool{ + "AWAYLEN": true, + "BOT": true, + "CHANLIMIT": true, + "CHANMODES": true, + "CHANNELLEN": true, + "CHANTYPES": true, + "CLIENTTAGDENY": true, + "ELIST": true, + "EXCEPTS": true, + "EXTBAN": true, + "HOSTLEN": true, + "INVEX": true, + "KICKLEN": true, + "MAXLIST": true, + "MAXTARGETS": true, + "MODES": true, + "MONITOR": true, + "NAMELEN": true, + "NETWORK": true, + "NICKLEN": true, + "PREFIX": true, + "SAFELIST": true, + "TARGMAX": true, + "TOPICLEN": true, + "USERLEN": true, + "UTF8ONLY": true, + "WHOX": true, +} + +type downstreamSASL struct { + server sasl.Server + plainUsername, plainPassword string + pendingResp bytes.Buffer +} + +type downstreamConn struct { + conn + + id uint64 + + registered bool + user *user + nick string + nickCM string + rawUsername string + networkName string + clientName string + realname string + hostname string + account string // RPL_LOGGEDIN/OUT state + password string // empty after authentication + network *network // can be nil + isMultiUpstream bool + + negotiatingCaps bool + capVersion int + supportedCaps map[string]string + caps map[string]bool + sasl *downstreamSASL + + lastBatchRef uint64 + + monitored casemapMap +} + +func newDownstreamConn(srv *Server, ic ircConn, id uint64) *downstreamConn { + remoteAddr := ic.RemoteAddr().String() + logger := &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", remoteAddr)} + options := connOptions{Logger: logger} + dc := &downstreamConn{ + conn: *newConn(srv, ic, &options), + id: id, + nick: "*", + nickCM: "*", + supportedCaps: make(map[string]string), + caps: make(map[string]bool), + monitored: newCasemapMap(0), + } + dc.hostname = remoteAddr + if host, _, err := net.SplitHostPort(dc.hostname); err == nil { + dc.hostname = host + } + for k, v := range permanentDownstreamCaps { + dc.supportedCaps[k] = v + } + dc.supportedCaps["sasl"] = "PLAIN" + // TODO: this is racy, we should only enable chathistory after + // authentication and then check that user.msgStore implements + // chatHistoryMessageStore + if srv.Config().LogPath != "" { + dc.supportedCaps["draft/chathistory"] = "" + } + return dc +} + +func (dc *downstreamConn) prefix() *irc.Prefix { + return &irc.Prefix{ + Name: dc.nick, + User: dc.user.Username, + Host: dc.hostname, + } +} + +func (dc *downstreamConn) forEachNetwork(f func(*network)) { + if dc.network != nil { + f(dc.network) + } else if dc.isMultiUpstream { + for _, network := range dc.user.networks { + f(network) + } + } +} + +func (dc *downstreamConn) forEachUpstream(f func(*upstreamConn)) { + if dc.network == nil && !dc.isMultiUpstream { + return + } + dc.user.forEachUpstream(func(uc *upstreamConn) { + if dc.network != nil && uc.network != dc.network { + return + } + f(uc) + }) +} + +// upstream returns the upstream connection, if any. If there are zero or if +// there are multiple upstream connections, it returns nil. +func (dc *downstreamConn) upstream() *upstreamConn { + if dc.network == nil { + return nil + } + return dc.network.conn +} + +func isOurNick(net *network, nick string) bool { + // TODO: this doesn't account for nick changes + if net.conn != nil { + return net.casemap(nick) == net.conn.nickCM + } + // We're not currently connected to the upstream connection, so we don't + // know whether this name is our nickname. Best-effort: use the network's + // configured nickname and hope it was the one being used when we were + // connected. + return net.casemap(nick) == net.casemap(GetNick(&net.user.User, &net.Network)) +} + +// marshalEntity converts an upstream entity name (ie. channel or nick) into a +// downstream entity name. +// +// This involves adding a "/" suffix if the entity isn't the current +// user. +func (dc *downstreamConn) marshalEntity(net *network, name string) string { + if isOurNick(net, name) { + return dc.nick + } + name = partialCasemap(net.casemap, name) + if dc.network != nil { + if dc.network != net { + panic("suika: tried to marshal an entity for another network") + } + return name + } + return name + "/" + net.GetName() +} + +func (dc *downstreamConn) marshalUserPrefix(net *network, prefix *irc.Prefix) *irc.Prefix { + if isOurNick(net, prefix.Name) { + return dc.prefix() + } + prefix.Name = partialCasemap(net.casemap, prefix.Name) + if dc.network != nil { + if dc.network != net { + panic("suika: tried to marshal a user prefix for another network") + } + return prefix + } + return &irc.Prefix{ + Name: prefix.Name + "/" + net.GetName(), + User: prefix.User, + Host: prefix.Host, + } +} + +// unmarshalEntityNetwork converts a downstream entity name (ie. channel or +// nick) into an upstream entity name. +// +// This involves removing the "/" suffix. +func (dc *downstreamConn) unmarshalEntityNetwork(name string) (*network, string, error) { + if dc.network != nil { + return dc.network, name, nil + } + if !dc.isMultiUpstream { + return nil, "", ircError{&irc.Message{ + Command: irc.ERR_NOSUCHCHANNEL, + Params: []string{dc.nick, name, "Cannot interact with channels and users on the bouncer connection. Did you mean to use a specific network?"}, + }} + } + + var net *network + if i := strings.LastIndexByte(name, '/'); i >= 0 { + network := name[i+1:] + name = name[:i] + + for _, n := range dc.user.networks { + if network == n.GetName() { + net = n + break + } + } + } + + if net == nil { + return nil, "", ircError{&irc.Message{ + Command: irc.ERR_NOSUCHCHANNEL, + Params: []string{dc.nick, name, "Missing network suffix in name"}, + }} + } + + return net, name, nil +} + +// unmarshalEntity is the same as unmarshalEntityNetwork, but returns the +// upstream connection and fails if the upstream is disconnected. +func (dc *downstreamConn) unmarshalEntity(name string) (*upstreamConn, string, error) { + net, name, err := dc.unmarshalEntityNetwork(name) + if err != nil { + return nil, "", err + } + + if net.conn == nil { + return nil, "", ircError{&irc.Message{ + Command: irc.ERR_NOSUCHCHANNEL, + Params: []string{dc.nick, name, "Disconnected from upstream network"}, + }} + } + + return net.conn, name, nil +} + +func (dc *downstreamConn) unmarshalText(uc *upstreamConn, text string) string { + if dc.upstream() != nil { + return text + } + // TODO: smarter parsing that ignores URLs + return strings.ReplaceAll(text, "/"+uc.network.GetName(), "") +} + +func (dc *downstreamConn) ReadMessage() (*irc.Message, error) { + msg, err := dc.conn.ReadMessage() + if err != nil { + return nil, err + } + return msg, nil +} + +func (dc *downstreamConn) readMessages(ch chan<- event) error { + for { + msg, err := dc.ReadMessage() + if errors.Is(err, io.EOF) { + break + } else if err != nil { + return fmt.Errorf("failed to read IRC command: %v", err) + } + + ch <- eventDownstreamMessage{msg, dc} + } + + return nil +} + +// SendMessage sends an outgoing message. +// +// This can only called from the user goroutine. +func (dc *downstreamConn) SendMessage(msg *irc.Message) { + if !dc.caps["message-tags"] { + if msg.Command == "TAGMSG" { + return + } + msg = msg.Copy() + for name := range msg.Tags { + supported := false + switch name { + case "time": + supported = dc.caps["server-time"] + case "account": + supported = dc.caps["account"] + } + if !supported { + delete(msg.Tags, name) + } + } + } + if !dc.caps["batch"] && msg.Tags["batch"] != "" { + msg = msg.Copy() + delete(msg.Tags, "batch") + } + if msg.Command == "JOIN" && !dc.caps["extended-join"] { + msg.Params = msg.Params[:1] + } + if msg.Command == "SETNAME" && !dc.caps["setname"] { + return + } + if msg.Command == "AWAY" && !dc.caps["away-notify"] { + return + } + if msg.Command == "ACCOUNT" && !dc.caps["account-notify"] { + return + } + if msg.Command == "READ" && !dc.caps["soju.im/read"] { + return + } + + dc.conn.SendMessage(context.TODO(), msg) +} + +func (dc *downstreamConn) SendBatch(typ string, params []string, tags irc.Tags, f func(batchRef irc.TagValue)) { + dc.lastBatchRef++ + ref := fmt.Sprintf("%v", dc.lastBatchRef) + + if dc.caps["batch"] { + dc.SendMessage(&irc.Message{ + Tags: tags, + Prefix: dc.srv.prefix(), + Command: "BATCH", + Params: append([]string{"+" + ref, typ}, params...), + }) + } + + f(irc.TagValue(ref)) + + if dc.caps["batch"] { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "BATCH", + Params: []string{"-" + ref}, + }) + } +} + +// sendMessageWithID sends an outgoing message with the specified internal ID. +func (dc *downstreamConn) sendMessageWithID(msg *irc.Message, id string) { + dc.SendMessage(msg) + + if id == "" || !dc.messageSupportsBacklog(msg) { + return + } + + dc.sendPing(id) +} + +// advanceMessageWithID advances history to the specified message ID without +// sending a message. This is useful e.g. for self-messages when echo-message +// isn't enabled. +func (dc *downstreamConn) advanceMessageWithID(msg *irc.Message, id string) { + if id == "" || !dc.messageSupportsBacklog(msg) { + return + } + + dc.sendPing(id) +} + +// ackMsgID acknowledges that a message has been received. +func (dc *downstreamConn) ackMsgID(id string) { + netID, entity, err := parseMsgID(id, nil) + if err != nil { + dc.logger.Printf("failed to ACK message ID %q: %v", id, err) + return + } + + network := dc.user.getNetworkByID(netID) + if network == nil { + return + } + + network.delivered.StoreID(entity, dc.clientName, id) +} + +func (dc *downstreamConn) sendPing(msgID string) { + token := "suika-msgid-" + msgID + dc.SendMessage(&irc.Message{ + Command: "PING", + Params: []string{token}, + }) +} + +func (dc *downstreamConn) handlePong(token string) { + if !strings.HasPrefix(token, "suika-msgid-") { + dc.logger.Printf("received unrecognized PONG token %q", token) + return + } + msgID := strings.TrimPrefix(token, "suika-msgid-") + dc.ackMsgID(msgID) +} + +// marshalMessage re-formats a message coming from an upstream connection so +// that it's suitable for being sent on this downstream connection. Only +// messages that may appear in logs are supported, except MODE messages which +// may only appear in single-upstream mode. +func (dc *downstreamConn) marshalMessage(msg *irc.Message, net *network) *irc.Message { + msg = msg.Copy() + msg.Prefix = dc.marshalUserPrefix(net, msg.Prefix) + + if dc.network != nil { + return msg + } + + switch msg.Command { + case "PRIVMSG", "NOTICE", "TAGMSG": + msg.Params[0] = dc.marshalEntity(net, msg.Params[0]) + case "NICK": + // Nick change for another user + msg.Params[0] = dc.marshalEntity(net, msg.Params[0]) + case "JOIN", "PART": + msg.Params[0] = dc.marshalEntity(net, msg.Params[0]) + case "KICK": + msg.Params[0] = dc.marshalEntity(net, msg.Params[0]) + msg.Params[1] = dc.marshalEntity(net, msg.Params[1]) + case "TOPIC": + msg.Params[0] = dc.marshalEntity(net, msg.Params[0]) + case "QUIT", "SETNAME": + // This space is intentionally left blank + default: + panic(fmt.Sprintf("unexpected %q message", msg.Command)) + } + + return msg +} + +func (dc *downstreamConn) handleMessage(ctx context.Context, msg *irc.Message) error { + ctx, cancel := dc.conn.NewContext(ctx) + defer cancel() + + ctx, cancel = context.WithTimeout(ctx, handleDownstreamMessageTimeout) + defer cancel() + + switch msg.Command { + case "QUIT": + return dc.Close() + default: + if dc.registered { + return dc.handleMessageRegistered(ctx, msg) + } else { + return dc.handleMessageUnregistered(ctx, msg) + } + } +} + +func (dc *downstreamConn) handleMessageUnregistered(ctx context.Context, msg *irc.Message) error { + switch msg.Command { + case "NICK": + var nick string + if err := parseMessageParams(msg, &nick); err != nil { + return err + } + if nick == "" || strings.ContainsAny(nick, illegalNickChars) { + return ircError{&irc.Message{ + Command: irc.ERR_ERRONEUSNICKNAME, + Params: []string{dc.nick, nick, "contains illegal characters"}, + }} + } + nickCM := casemapASCII(nick) + if nickCM == serviceNickCM { + return ircError{&irc.Message{ + Command: irc.ERR_NICKNAMEINUSE, + Params: []string{dc.nick, nick, "Nickname reserved for bouncer service"}, + }} + } + dc.nick = nick + dc.nickCM = nickCM + case "USER": + if err := parseMessageParams(msg, &dc.rawUsername, nil, nil, &dc.realname); err != nil { + return err + } + case "PASS": + if err := parseMessageParams(msg, &dc.password); err != nil { + return err + } + case "CAP": + var subCmd string + if err := parseMessageParams(msg, &subCmd); err != nil { + return err + } + if err := dc.handleCapCommand(subCmd, msg.Params[1:]); err != nil { + return err + } + case "AUTHENTICATE": + credentials, err := dc.handleAuthenticateCommand(msg) + if err != nil { + return err + } else if credentials == nil { + break + } + + if err := dc.authenticate(ctx, credentials.plainUsername, credentials.plainPassword); err != nil { + dc.logger.Printf("SASL authentication error for user %q: %v", credentials.plainUsername, err) + dc.endSASL(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_SASLFAIL, + Params: []string{dc.nick, authErrorReason(err)}, + }) + break + } + + // Technically we should send RPL_LOGGEDIN here. However we use + // RPL_LOGGEDIN to mirror the upstream connection status. Let's + // see how many clients that breaks. See: + // https://github.com/ircv3/ircv3-specifications/pull/476 + dc.endSASL(nil) + case "BOUNCER": + var subcommand string + if err := parseMessageParams(msg, &subcommand); err != nil { + return err + } + + switch strings.ToUpper(subcommand) { + case "BIND": + var idStr string + if err := parseMessageParams(msg, nil, &idStr); err != nil { + return err + } + + if dc.user == nil { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"BOUNCER", "ACCOUNT_REQUIRED", "BIND", "Authentication needed to bind to bouncer network"}, + }} + } + + id, err := parseBouncerNetID(subcommand, idStr) + if err != nil { + return err + } + + var match *network + for _, net := range dc.user.networks { + if net.ID == id { + match = net + break + } + } + if match == nil { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"BOUNCER", "INVALID_NETID", idStr, "Unknown network ID"}, + }} + } + + dc.networkName = match.GetName() + } + default: + dc.logger.Printf("unhandled message: %v", msg) + return newUnknownCommandError(msg.Command) + } + if dc.rawUsername != "" && dc.nick != "*" && !dc.negotiatingCaps { + return dc.register(ctx) + } + return nil +} + +func (dc *downstreamConn) handleCapCommand(cmd string, args []string) error { + cmd = strings.ToUpper(cmd) + + switch cmd { + case "LS": + if len(args) > 0 { + var err error + if dc.capVersion, err = strconv.Atoi(args[0]); err != nil { + return err + } + } + if !dc.registered && dc.capVersion >= 302 { + // Let downstream show everything it supports, and trim + // down the available capabilities when upstreams are + // known. + for k, v := range needAllDownstreamCaps { + dc.supportedCaps[k] = v + } + } + + caps := make([]string, 0, len(dc.supportedCaps)) + for k, v := range dc.supportedCaps { + if dc.capVersion >= 302 && v != "" { + caps = append(caps, k+"="+v) + } else { + caps = append(caps, k) + } + } + + // TODO: multi-line replies + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "CAP", + Params: []string{dc.nick, "LS", strings.Join(caps, " ")}, + }) + + if dc.capVersion >= 302 { + // CAP version 302 implicitly enables cap-notify + dc.caps["cap-notify"] = true + } + + if !dc.registered { + dc.negotiatingCaps = true + } + case "LIST": + var caps []string + for name, enabled := range dc.caps { + if enabled { + caps = append(caps, name) + } + } + + // TODO: multi-line replies + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "CAP", + Params: []string{dc.nick, "LIST", strings.Join(caps, " ")}, + }) + case "REQ": + if len(args) == 0 { + return ircError{&irc.Message{ + Command: err_invalidcapcmd, + Params: []string{dc.nick, cmd, "Missing argument in CAP REQ command"}, + }} + } + + // TODO: atomically ack/nak the whole capability set + caps := strings.Fields(args[0]) + ack := true + for _, name := range caps { + name = strings.ToLower(name) + enable := !strings.HasPrefix(name, "-") + if !enable { + name = strings.TrimPrefix(name, "-") + } + + if enable == dc.caps[name] { + continue + } + + _, ok := dc.supportedCaps[name] + if !ok { + ack = false + break + } + + if name == "cap-notify" && dc.capVersion >= 302 && !enable { + // cap-notify cannot be disabled with CAP version 302 + ack = false + break + } + + dc.caps[name] = enable + } + + reply := "NAK" + if ack { + reply = "ACK" + } + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "CAP", + Params: []string{dc.nick, reply, args[0]}, + }) + + if !dc.registered { + dc.negotiatingCaps = true + } + case "END": + dc.negotiatingCaps = false + default: + return ircError{&irc.Message{ + Command: err_invalidcapcmd, + Params: []string{dc.nick, cmd, "Unknown CAP command"}, + }} + } + return nil +} + +func (dc *downstreamConn) handleAuthenticateCommand(msg *irc.Message) (result *downstreamSASL, err error) { + defer func() { + if err != nil { + dc.sasl = nil + } + }() + + if !dc.caps["sasl"] { + return nil, ircError{&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_SASLFAIL, + Params: []string{dc.nick, "AUTHENTICATE requires the \"sasl\" capability to be enabled"}, + }} + } + if len(msg.Params) == 0 { + return nil, ircError{&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_SASLFAIL, + Params: []string{dc.nick, "Missing AUTHENTICATE argument"}, + }} + } + if msg.Params[0] == "*" { + return nil, ircError{&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_SASLABORTED, + Params: []string{dc.nick, "SASL authentication aborted"}, + }} + } + + var resp []byte + if dc.sasl == nil { + mech := strings.ToUpper(msg.Params[0]) + var server sasl.Server + switch mech { + case "PLAIN": + server = sasl.NewPlainServer(sasl.PlainAuthenticator(func(identity, username, password string) error { + dc.sasl.plainUsername = username + dc.sasl.plainPassword = password + return nil + })) + default: + return nil, ircError{&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_SASLFAIL, + Params: []string{dc.nick, fmt.Sprintf("Unsupported SASL mechanism %q", mech)}, + }} + } + + dc.sasl = &downstreamSASL{server: server} + } else { + chunk := msg.Params[0] + if chunk == "+" { + chunk = "" + } + + if dc.sasl.pendingResp.Len()+len(chunk) > 10*1024 { + return nil, ircError{&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_SASLFAIL, + Params: []string{dc.nick, "Response too long"}, + }} + } + + dc.sasl.pendingResp.WriteString(chunk) + + if len(chunk) == maxSASLLength { + return nil, nil // Multi-line response, wait for the next command + } + + resp, err = base64.StdEncoding.DecodeString(dc.sasl.pendingResp.String()) + if err != nil { + return nil, ircError{&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_SASLFAIL, + Params: []string{dc.nick, "Invalid base64-encoded response"}, + }} + } + + dc.sasl.pendingResp.Reset() + } + + challenge, done, err := dc.sasl.server.Next(resp) + if err != nil { + return nil, err + } else if done { + return dc.sasl, nil + } else { + challengeStr := "+" + if len(challenge) > 0 { + challengeStr = base64.StdEncoding.EncodeToString(challenge) + } + + // TODO: multi-line messages + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "AUTHENTICATE", + Params: []string{challengeStr}, + }) + return nil, nil + } +} + +func (dc *downstreamConn) endSASL(msg *irc.Message) { + if dc.sasl == nil { + return + } + + dc.sasl = nil + + if msg != nil { + dc.SendMessage(msg) + } else { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_SASLSUCCESS, + Params: []string{dc.nick, "SASL authentication successful"}, + }) + } +} + +func (dc *downstreamConn) setSupportedCap(name, value string) { + prevValue, hasPrev := dc.supportedCaps[name] + changed := !hasPrev || prevValue != value + dc.supportedCaps[name] = value + + if !dc.caps["cap-notify"] || !changed { + return + } + + cap := name + if value != "" && dc.capVersion >= 302 { + cap = name + "=" + value + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "CAP", + Params: []string{dc.nick, "NEW", cap}, + }) +} + +func (dc *downstreamConn) unsetSupportedCap(name string) { + _, hasPrev := dc.supportedCaps[name] + delete(dc.supportedCaps, name) + delete(dc.caps, name) + + if !dc.caps["cap-notify"] || !hasPrev { + return + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "CAP", + Params: []string{dc.nick, "DEL", name}, + }) +} + +func (dc *downstreamConn) updateSupportedCaps() { + supportedCaps := make(map[string]bool) + for cap := range needAllDownstreamCaps { + supportedCaps[cap] = true + } + dc.forEachUpstream(func(uc *upstreamConn) { + for cap, supported := range supportedCaps { + supportedCaps[cap] = supported && uc.caps[cap] + } + }) + + for cap, supported := range supportedCaps { + if supported { + dc.setSupportedCap(cap, needAllDownstreamCaps[cap]) + } else { + dc.unsetSupportedCap(cap) + } + } + + if uc := dc.upstream(); uc != nil && uc.supportsSASL("PLAIN") { + dc.setSupportedCap("sasl", "PLAIN") + } else if dc.network != nil { + dc.unsetSupportedCap("sasl") + } + + if uc := dc.upstream(); uc != nil && uc.caps["draft/account-registration"] { + // Strip "before-connect", because we require downstreams to be fully + // connected before attempting account registration. + values := strings.Split(uc.supportedCaps["draft/account-registration"], ",") + for i, v := range values { + if v == "before-connect" { + values = append(values[:i], values[i+1:]...) + break + } + } + dc.setSupportedCap("draft/account-registration", strings.Join(values, ",")) + } else { + dc.unsetSupportedCap("draft/account-registration") + } + + if _, ok := dc.user.msgStore.(chatHistoryMessageStore); ok && dc.network != nil { + dc.setSupportedCap("draft/event-playback", "") + } else { + dc.unsetSupportedCap("draft/event-playback") + } +} + +func (dc *downstreamConn) updateNick() { + if uc := dc.upstream(); uc != nil && uc.nick != dc.nick { + dc.SendMessage(&irc.Message{ + Prefix: dc.prefix(), + Command: "NICK", + Params: []string{uc.nick}, + }) + dc.nick = uc.nick + dc.nickCM = casemapASCII(dc.nick) + } +} + +func (dc *downstreamConn) updateRealname() { + if uc := dc.upstream(); uc != nil && uc.realname != dc.realname && dc.caps["setname"] { + dc.SendMessage(&irc.Message{ + Prefix: dc.prefix(), + Command: "SETNAME", + Params: []string{uc.realname}, + }) + dc.realname = uc.realname + } +} + +func (dc *downstreamConn) updateAccount() { + var account string + if dc.network == nil { + account = dc.user.Username + } else if uc := dc.upstream(); uc != nil { + account = uc.account + } else { + return + } + + if dc.account == account || !dc.caps["sasl"] { + return + } + + if account != "" { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_LOGGEDIN, + Params: []string{dc.nick, dc.prefix().String(), account, "You are logged in as " + account}, + }) + } else { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_LOGGEDOUT, + Params: []string{dc.nick, dc.prefix().String(), "You are logged out"}, + }) + } + + dc.account = account +} + +func sanityCheckServer(ctx context.Context, addr string) error { + ctx, cancel := context.WithTimeout(ctx, 15*time.Second) + defer cancel() + + conn, err := new(tls.Dialer).DialContext(ctx, "tcp", addr) + if err != nil { + return err + } + + return conn.Close() +} + +func unmarshalUsername(rawUsername string) (username, client, network string) { + username = rawUsername + + i := strings.IndexAny(username, "/@") + j := strings.LastIndexAny(username, "/@") + if i >= 0 { + username = rawUsername[:i] + } + if j >= 0 { + if rawUsername[j] == '@' { + client = rawUsername[j+1:] + } else { + network = rawUsername[j+1:] + } + } + if i >= 0 && j >= 0 && i < j { + if rawUsername[i] == '@' { + client = rawUsername[i+1 : j] + } else { + network = rawUsername[i+1 : j] + } + } + + return username, client, network +} + +func (dc *downstreamConn) authenticate(ctx context.Context, username, password string) error { + username, clientName, networkName := unmarshalUsername(username) + + u, err := dc.srv.db.GetUser(ctx, username) + if err != nil { + return newInvalidUsernameOrPasswordError(fmt.Errorf("user not found: %w", err)) + } + + // Password auth disabled + if u.Password == "" { + return newInvalidUsernameOrPasswordError(fmt.Errorf("password auth disabled")) + } + + err = bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(password)) + if err != nil { + return newInvalidUsernameOrPasswordError(fmt.Errorf("wrong password")) + } + + dc.user = dc.srv.getUser(username) + if dc.user == nil { + return fmt.Errorf("user not active") + } + dc.clientName = clientName + dc.networkName = networkName + return nil +} + +func (dc *downstreamConn) register(ctx context.Context) error { + if dc.registered { + panic("tried to register twice") + } + + if dc.sasl != nil { + dc.endSASL(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_SASLABORTED, + Params: []string{dc.nick, "SASL authentication aborted"}, + }) + } + + password := dc.password + dc.password = "" + if dc.user == nil { + if password == "" { + if dc.caps["sasl"] { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"*", "ACCOUNT_REQUIRED", "Authentication required"}, + }} + } else { + return ircError{&irc.Message{ + Command: irc.ERR_PASSWDMISMATCH, + Params: []string{dc.nick, "Authentication required"}, + }} + } + } + + if err := dc.authenticate(ctx, dc.rawUsername, password); err != nil { + dc.logger.Printf("PASS authentication error for user %q: %v", dc.rawUsername, err) + return ircError{&irc.Message{ + Command: irc.ERR_PASSWDMISMATCH, + Params: []string{dc.nick, authErrorReason(err)}, + }} + } + } + + _, fallbackClientName, fallbackNetworkName := unmarshalUsername(dc.rawUsername) + if dc.clientName == "" { + dc.clientName = fallbackClientName + } else if fallbackClientName != "" && dc.clientName != fallbackClientName { + return ircError{&irc.Message{ + Command: irc.ERR_ERRONEUSNICKNAME, + Params: []string{dc.nick, "Client name mismatch in usernames"}, + }} + } + if dc.networkName == "" { + dc.networkName = fallbackNetworkName + } else if fallbackNetworkName != "" && dc.networkName != fallbackNetworkName { + return ircError{&irc.Message{ + Command: irc.ERR_ERRONEUSNICKNAME, + Params: []string{dc.nick, "Network name mismatch in usernames"}, + }} + } + + dc.registered = true + dc.logger.Printf("registration complete for user %q", dc.user.Username) + return nil +} + +func (dc *downstreamConn) loadNetwork(ctx context.Context) error { + if dc.networkName == "" { + return nil + } + + network := dc.user.getNetwork(dc.networkName) + if network == nil { + addr := dc.networkName + if !strings.ContainsRune(addr, ':') { + addr = addr + ":6697" + } + + dc.logger.Printf("trying to connect to new network %q", addr) + if err := sanityCheckServer(ctx, addr); err != nil { + dc.logger.Printf("failed to connect to %q: %v", addr, err) + return ircError{&irc.Message{ + Command: irc.ERR_PASSWDMISMATCH, + Params: []string{dc.nick, fmt.Sprintf("Failed to connect to %q", dc.networkName)}, + }} + } + + // Some clients only allow specifying the nickname (and use the + // nickname as a username too). Strip the network name from the + // nickname when auto-saving networks. + nick, _, _ := unmarshalUsername(dc.nick) + + dc.logger.Printf("auto-saving network %q", dc.networkName) + var err error + network, err = dc.user.createNetwork(ctx, &Network{ + Addr: dc.networkName, + Nick: nick, + Enabled: true, + }) + if err != nil { + return err + } + } + + dc.network = network + return nil +} + +func (dc *downstreamConn) welcome(ctx context.Context) error { + if dc.user == nil || !dc.registered { + panic("tried to welcome an unregistered connection") + } + + remoteAddr := dc.conn.RemoteAddr().String() + dc.logger = &prefixLogger{dc.srv.Logger, fmt.Sprintf("user %q: downstream %q: ", dc.user.Username, remoteAddr)} + + // TODO: doing this might take some time. We should do it in dc.register + // instead, but we'll potentially be adding a new network and this must be + // done in the user goroutine. + if err := dc.loadNetwork(ctx); err != nil { + return err + } + + if dc.network == nil && !dc.caps["soju.im/bouncer-networks"] && dc.srv.Config().MultiUpstream { + dc.isMultiUpstream = true + } + + dc.updateSupportedCaps() + + isupport := []string{ + fmt.Sprintf("CHATHISTORY=%v", chatHistoryLimit), + "CASEMAPPING=ascii", + } + + if dc.network != nil { + isupport = append(isupport, fmt.Sprintf("BOUNCER_NETID=%v", dc.network.ID)) + } + if title := dc.srv.Config().Title; dc.network == nil && title != "" { + isupport = append(isupport, "NETWORK="+encodeISUPPORT(title)) + } + if dc.network == nil && !dc.isMultiUpstream { + isupport = append(isupport, "WHOX") + } + + if uc := dc.upstream(); uc != nil { + for k := range passthroughIsupport { + v, ok := uc.isupport[k] + if !ok { + continue + } + if v != nil { + isupport = append(isupport, fmt.Sprintf("%v=%v", k, *v)) + } else { + isupport = append(isupport, k) + } + } + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_WELCOME, + Params: []string{dc.nick, "Welcome to suika, " + dc.nick}, + }) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_YOURHOST, + Params: []string{dc.nick, "Your host is " + dc.srv.Config().Hostname}, + }) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_MYINFO, + Params: []string{dc.nick, dc.srv.Config().Hostname, "suika", "aiwroO", "OovaimnqpsrtklbeI"}, + }) + for _, msg := range generateIsupport(dc.srv.prefix(), dc.nick, isupport) { + dc.SendMessage(msg) + } + if uc := dc.upstream(); uc != nil { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_UMODEIS, + Params: []string{dc.nick, "+" + string(uc.modes)}, + }) + } + if dc.network == nil && !dc.isMultiUpstream && dc.user.Admin { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_UMODEIS, + Params: []string{dc.nick, "+o"}, + }) + } + + dc.updateNick() + dc.updateRealname() + dc.updateAccount() + + if motd := dc.user.srv.Config().MOTD; motd != "" && dc.network == nil { + for _, msg := range generateMOTD(dc.srv.prefix(), dc.nick, motd) { + dc.SendMessage(msg) + } + } else { + motdHint := "No MOTD" + if dc.network != nil { + motdHint = "Use /motd to read the message of the day" + } + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_NOMOTD, + Params: []string{dc.nick, motdHint}, + }) + } + + if dc.caps["soju.im/bouncer-networks-notify"] { + dc.SendBatch("soju.im/bouncer-networks", nil, nil, func(batchRef irc.TagValue) { + for _, network := range dc.user.networks { + idStr := fmt.Sprintf("%v", network.ID) + attrs := getNetworkAttrs(network) + dc.SendMessage(&irc.Message{ + Tags: irc.Tags{"batch": batchRef}, + Prefix: dc.srv.prefix(), + Command: "BOUNCER", + Params: []string{"NETWORK", idStr, attrs.String()}, + }) + } + }) + } + + dc.forEachUpstream(func(uc *upstreamConn) { + for _, entry := range uc.channels.innerMap { + ch := entry.value.(*upstreamChannel) + if !ch.complete { + continue + } + record := uc.network.channels.Value(ch.Name) + if record != nil && record.Detached { + continue + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.prefix(), + Command: "JOIN", + Params: []string{dc.marshalEntity(ch.conn.network, ch.Name)}, + }) + + forwardChannel(ctx, dc, ch) + } + }) + + dc.forEachNetwork(func(net *network) { + if dc.caps["draft/chathistory"] || dc.user.msgStore == nil { + return + } + + // Only send history if we're the first connected client with that name + // for the network + firstClient := true + dc.user.forEachDownstream(func(c *downstreamConn) { + if c != dc && c.clientName == dc.clientName && c.network == dc.network { + firstClient = false + } + }) + if firstClient { + net.delivered.ForEachTarget(func(target string) { + lastDelivered := net.delivered.LoadID(target, dc.clientName) + if lastDelivered == "" { + return + } + + dc.sendTargetBacklog(ctx, net, target, lastDelivered) + + // Fast-forward history to last message + targetCM := net.casemap(target) + lastID, err := dc.user.msgStore.LastMsgID(&net.Network, targetCM, time.Now()) + if err != nil { + dc.logger.Printf("failed to get last message ID: %v", err) + return + } + net.delivered.StoreID(target, dc.clientName, lastID) + }) + } + }) + + return nil +} + +// messageSupportsBacklog checks whether the provided message can be sent as +// part of an history batch. +func (dc *downstreamConn) messageSupportsBacklog(msg *irc.Message) bool { + // Don't replay all messages, because that would mess up client + // state. For instance we just sent the list of users, sending + // PART messages for one of these users would be incorrect. + switch msg.Command { + case "PRIVMSG", "NOTICE": + return true + } + return false +} + +func (dc *downstreamConn) sendTargetBacklog(ctx context.Context, net *network, target, msgID string) { + if dc.caps["draft/chathistory"] || dc.user.msgStore == nil { + return + } + + ch := net.channels.Value(target) + + ctx, cancel := context.WithTimeout(ctx, backlogTimeout) + defer cancel() + + targetCM := net.casemap(target) + history, err := dc.user.msgStore.LoadLatestID(ctx, &net.Network, targetCM, msgID, backlogLimit) + if err != nil { + dc.logger.Printf("failed to send backlog for %q: %v", target, err) + return + } + + dc.SendBatch("chathistory", []string{dc.marshalEntity(net, target)}, nil, func(batchRef irc.TagValue) { + for _, msg := range history { + if ch != nil && ch.Detached { + if net.detachedMessageNeedsRelay(ch, msg) { + dc.relayDetachedMessage(net, msg) + } + } else { + msg.Tags["batch"] = batchRef + dc.SendMessage(dc.marshalMessage(msg, net)) + } + } + }) +} + +func (dc *downstreamConn) relayDetachedMessage(net *network, msg *irc.Message) { + if msg.Command != "PRIVMSG" && msg.Command != "NOTICE" { + return + } + + sender := msg.Prefix.Name + target, text := msg.Params[0], msg.Params[1] + if net.isHighlight(msg) { + sendServiceNOTICE(dc, fmt.Sprintf("highlight in %v: <%v> %v", dc.marshalEntity(net, target), sender, text)) + } else { + sendServiceNOTICE(dc, fmt.Sprintf("message in %v: <%v> %v", dc.marshalEntity(net, target), sender, text)) + } +} + +func (dc *downstreamConn) runUntilRegistered() error { + ctx, cancel := context.WithTimeout(context.TODO(), downstreamRegisterTimeout) + defer cancel() + + // Close the connection with an error if the deadline is exceeded + go func() { + <-ctx.Done() + if err := ctx.Err(); err == context.DeadlineExceeded { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "ERROR", + Params: []string{"Connection registration timed out"}, + }) + dc.Close() + } + }() + + for !dc.registered { + msg, err := dc.ReadMessage() + if err != nil { + return fmt.Errorf("failed to read IRC command: %w", err) + } + + err = dc.handleMessage(ctx, msg) + if ircErr, ok := err.(ircError); ok { + ircErr.Message.Prefix = dc.srv.prefix() + dc.SendMessage(ircErr.Message) + } else if err != nil { + return fmt.Errorf("failed to handle IRC command %q: %v", msg, err) + } + } + + return nil +} + +func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.Message) error { + switch msg.Command { + case "CAP": + var subCmd string + if err := parseMessageParams(msg, &subCmd); err != nil { + return err + } + if err := dc.handleCapCommand(subCmd, msg.Params[1:]); err != nil { + return err + } + case "PING": + var source, destination string + if err := parseMessageParams(msg, &source); err != nil { + return err + } + if len(msg.Params) > 1 { + destination = msg.Params[1] + } + hostname := dc.srv.Config().Hostname + if destination != "" && destination != hostname { + return ircError{&irc.Message{ + Command: irc.ERR_NOSUCHSERVER, + Params: []string{dc.nick, destination, "No such server"}, + }} + } + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "PONG", + Params: []string{hostname, source}, + }) + return nil + case "PONG": + if len(msg.Params) == 0 { + return newNeedMoreParamsError(msg.Command) + } + token := msg.Params[len(msg.Params)-1] + dc.handlePong(token) + case "USER": + return ircError{&irc.Message{ + Command: irc.ERR_ALREADYREGISTERED, + Params: []string{dc.nick, "You may not reregister"}, + }} + case "NICK": + var rawNick string + if err := parseMessageParams(msg, &rawNick); err != nil { + return err + } + + nick := rawNick + var upstream *upstreamConn + if dc.upstream() == nil { + uc, unmarshaledNick, err := dc.unmarshalEntity(nick) + if err == nil { // NICK nick/network: NICK only on a specific upstream + upstream = uc + nick = unmarshaledNick + } + } + + if nick == "" || strings.ContainsAny(nick, illegalNickChars) { + return ircError{&irc.Message{ + Command: irc.ERR_ERRONEUSNICKNAME, + Params: []string{dc.nick, rawNick, "contains illegal characters"}, + }} + } + if casemapASCII(nick) == serviceNickCM { + return ircError{&irc.Message{ + Command: irc.ERR_NICKNAMEINUSE, + Params: []string{dc.nick, rawNick, "Nickname reserved for bouncer service"}, + }} + } + + var err error + dc.forEachNetwork(func(n *network) { + if err != nil || (upstream != nil && upstream.network != n) { + return + } + n.Nick = nick + err = dc.srv.db.StoreNetwork(ctx, dc.user.ID, &n.Network) + }) + if err != nil { + return err + } + + dc.forEachUpstream(func(uc *upstreamConn) { + if upstream != nil && upstream != uc { + return + } + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + Command: "NICK", + Params: []string{nick}, + }) + }) + + if dc.upstream() == nil && upstream == nil && dc.nick != nick { + dc.SendMessage(&irc.Message{ + Prefix: dc.prefix(), + Command: "NICK", + Params: []string{nick}, + }) + dc.nick = nick + dc.nickCM = casemapASCII(dc.nick) + } + case "SETNAME": + var realname string + if err := parseMessageParams(msg, &realname); err != nil { + return err + } + + // If the client just resets to the default, just wipe the per-network + // preference + storeRealname := realname + if realname == dc.user.Realname { + storeRealname = "" + } + + var storeErr error + var needUpdate []Network + dc.forEachNetwork(func(n *network) { + // We only need to call updateNetwork for upstreams that don't + // support setname + if uc := n.conn; uc != nil && uc.caps["setname"] { + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + Command: "SETNAME", + Params: []string{realname}, + }) + + n.Realname = storeRealname + if err := dc.srv.db.StoreNetwork(ctx, dc.user.ID, &n.Network); err != nil { + dc.logger.Printf("failed to store network realname: %v", err) + storeErr = err + } + return + } + + record := n.Network // copy network record because we'll mutate it + record.Realname = storeRealname + needUpdate = append(needUpdate, record) + }) + + // Walk the network list as a second step, because updateNetwork + // mutates the original list + for _, record := range needUpdate { + if _, err := dc.user.updateNetwork(ctx, &record); err != nil { + dc.logger.Printf("failed to update network realname: %v", err) + storeErr = err + } + } + if storeErr != nil { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"SETNAME", "CANNOT_CHANGE_REALNAME", "Failed to update realname"}, + }} + } + + if dc.upstream() == nil { + dc.SendMessage(&irc.Message{ + Prefix: dc.prefix(), + Command: "SETNAME", + Params: []string{realname}, + }) + } + case "JOIN": + var namesStr string + if err := parseMessageParams(msg, &namesStr); err != nil { + return err + } + + var keys []string + if len(msg.Params) > 1 { + keys = strings.Split(msg.Params[1], ",") + } + + for i, name := range strings.Split(namesStr, ",") { + uc, upstreamName, err := dc.unmarshalEntity(name) + if err != nil { + return err + } + + var key string + if len(keys) > i { + key = keys[i] + } + + if !uc.isChannel(upstreamName) { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_NOSUCHCHANNEL, + Params: []string{name, "Not a channel name"}, + }) + continue + } + + // Most servers ignore duplicate JOIN messages. We ignore them here + // because some clients automatically send JOIN messages in bulk + // when reconnecting to the bouncer. We don't want to flood the + // upstream connection with these. + if !uc.channels.Has(upstreamName) { + params := []string{upstreamName} + if key != "" { + params = append(params, key) + } + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + Command: "JOIN", + Params: params, + }) + } + + ch := uc.network.channels.Value(upstreamName) + if ch != nil { + // Don't clear the channel key if there's one set + // TODO: add a way to unset the channel key + if key != "" { + ch.Key = key + } + uc.network.attach(ctx, ch) + } else { + ch = &Channel{ + Name: upstreamName, + Key: key, + } + uc.network.channels.SetValue(upstreamName, ch) + } + if err := dc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil { + dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err) + } + } + case "PART": + var namesStr string + if err := parseMessageParams(msg, &namesStr); err != nil { + return err + } + + var reason string + if len(msg.Params) > 1 { + reason = msg.Params[1] + } + + for _, name := range strings.Split(namesStr, ",") { + uc, upstreamName, err := dc.unmarshalEntity(name) + if err != nil { + return err + } + + if strings.EqualFold(reason, "detach") { + ch := uc.network.channels.Value(upstreamName) + if ch != nil { + uc.network.detach(ch) + } else { + ch = &Channel{ + Name: name, + Detached: true, + } + uc.network.channels.SetValue(upstreamName, ch) + } + if err := dc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil { + dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err) + } + } else { + params := []string{upstreamName} + if reason != "" { + params = append(params, reason) + } + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + Command: "PART", + Params: params, + }) + + if err := uc.network.deleteChannel(ctx, upstreamName); err != nil { + dc.logger.Printf("failed to delete channel %q: %v", upstreamName, err) + } + } + } + case "KICK": + var channelStr, userStr string + if err := parseMessageParams(msg, &channelStr, &userStr); err != nil { + return err + } + + channels := strings.Split(channelStr, ",") + users := strings.Split(userStr, ",") + + var reason string + if len(msg.Params) > 2 { + reason = msg.Params[2] + } + + if len(channels) != 1 && len(channels) != len(users) { + return ircError{&irc.Message{ + Command: irc.ERR_BADCHANMASK, + Params: []string{dc.nick, channelStr, "Bad channel mask"}, + }} + } + + for i, user := range users { + var channel string + if len(channels) == 1 { + channel = channels[0] + } else { + channel = channels[i] + } + + ucChannel, upstreamChannel, err := dc.unmarshalEntity(channel) + if err != nil { + return err + } + + ucUser, upstreamUser, err := dc.unmarshalEntity(user) + if err != nil { + return err + } + + if ucChannel != ucUser { + return ircError{&irc.Message{ + Command: irc.ERR_USERNOTINCHANNEL, + Params: []string{dc.nick, user, channel, "They are on another network"}, + }} + } + uc := ucChannel + + params := []string{upstreamChannel, upstreamUser} + if reason != "" { + params = append(params, reason) + } + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + Command: "KICK", + Params: params, + }) + } + case "MODE": + var name string + if err := parseMessageParams(msg, &name); err != nil { + return err + } + + var modeStr string + if len(msg.Params) > 1 { + modeStr = msg.Params[1] + } + + if casemapASCII(name) == dc.nickCM { + if modeStr != "" { + if uc := dc.upstream(); uc != nil { + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + Command: "MODE", + Params: []string{uc.nick, modeStr}, + }) + } else { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_UMODEUNKNOWNFLAG, + Params: []string{dc.nick, "Cannot change user mode in multi-upstream mode"}, + }) + } + } else { + var userMode string + if uc := dc.upstream(); uc != nil { + userMode = string(uc.modes) + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_UMODEIS, + Params: []string{dc.nick, "+" + userMode}, + }) + } + return nil + } + + uc, upstreamName, err := dc.unmarshalEntity(name) + if err != nil { + return err + } + + if !uc.isChannel(upstreamName) { + return ircError{&irc.Message{ + Command: irc.ERR_USERSDONTMATCH, + Params: []string{dc.nick, "Cannot change mode for other users"}, + }} + } + + if modeStr != "" { + params := []string{upstreamName, modeStr} + params = append(params, msg.Params[2:]...) + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + Command: "MODE", + Params: params, + }) + } else { + ch := uc.channels.Value(upstreamName) + if ch == nil { + return ircError{&irc.Message{ + Command: irc.ERR_NOSUCHCHANNEL, + Params: []string{dc.nick, name, "No such channel"}, + }} + } + + if ch.modes == nil { + // we haven't received the initial RPL_CHANNELMODEIS yet + // ignore the request, we will broadcast the modes later when we receive RPL_CHANNELMODEIS + return nil + } + + modeStr, modeParams := ch.modes.Format() + params := []string{dc.nick, name, modeStr} + params = append(params, modeParams...) + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_CHANNELMODEIS, + Params: params, + }) + if ch.creationTime != "" { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: rpl_creationtime, + Params: []string{dc.nick, name, ch.creationTime}, + }) + } + } + case "TOPIC": + var channel string + if err := parseMessageParams(msg, &channel); err != nil { + return err + } + + uc, upstreamName, err := dc.unmarshalEntity(channel) + if err != nil { + return err + } + + if len(msg.Params) > 1 { // setting topic + topic := msg.Params[1] + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + Command: "TOPIC", + Params: []string{upstreamName, topic}, + }) + } else { // getting topic + ch := uc.channels.Value(upstreamName) + if ch == nil { + return ircError{&irc.Message{ + Command: irc.ERR_NOSUCHCHANNEL, + Params: []string{dc.nick, upstreamName, "No such channel"}, + }} + } + sendTopic(dc, ch) + } + case "LIST": + network := dc.network + if network == nil && len(msg.Params) > 0 { + var err error + network, msg.Params[0], err = dc.unmarshalEntityNetwork(msg.Params[0]) + if err != nil { + return err + } + } + if network == nil { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_LISTEND, + Params: []string{dc.nick, "LIST without a network suffix is not supported in multi-upstream mode"}, + }) + return nil + } + + uc := network.conn + if uc == nil { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_LISTEND, + Params: []string{dc.nick, "Disconnected from upstream server"}, + }) + return nil + } + + uc.enqueueCommand(dc, msg) + case "NAMES": + if len(msg.Params) == 0 { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_ENDOFNAMES, + Params: []string{dc.nick, "*", "End of /NAMES list"}, + }) + return nil + } + + channels := strings.Split(msg.Params[0], ",") + for _, channel := range channels { + uc, upstreamName, err := dc.unmarshalEntity(channel) + if err != nil { + return err + } + + ch := uc.channels.Value(upstreamName) + if ch != nil { + sendNames(dc, ch) + } else { + // NAMES on a channel we have not joined, ask upstream + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + Command: "NAMES", + Params: []string{upstreamName}, + }) + } + } + // For WHOX docs, see: + // - http://faerion.sourceforge.net/doc/irc/whox.var + // - https://github.com/quakenet/snircd/blob/master/doc/readme.who + // Note, many features aren't widely implemented, such as flags and mask2 + case "WHO": + if len(msg.Params) == 0 { + // TODO: support WHO without parameters + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_ENDOFWHO, + Params: []string{dc.nick, "*", "End of /WHO list"}, + }) + return nil + } + + // Clients will use the first mask to match RPL_ENDOFWHO + endOfWhoToken := msg.Params[0] + + // TODO: add support for WHOX mask2 + mask := msg.Params[0] + var options string + if len(msg.Params) > 1 { + options = msg.Params[1] + } + + optionsParts := strings.SplitN(options, "%", 2) + // TODO: add support for WHOX flags in optionsParts[0] + var fields, whoxToken string + if len(optionsParts) == 2 { + optionsParts := strings.SplitN(optionsParts[1], ",", 2) + fields = strings.ToLower(optionsParts[0]) + if len(optionsParts) == 2 && strings.Contains(fields, "t") { + whoxToken = optionsParts[1] + } + } + + // TODO: support mixed bouncer/upstream WHO queries + maskCM := casemapASCII(mask) + if dc.network == nil && maskCM == dc.nickCM { + // TODO: support AWAY (H/G) in self WHO reply + flags := "H" + if dc.user.Admin { + flags += "*" + } + info := whoxInfo{ + Token: whoxToken, + Username: dc.user.Username, + Hostname: dc.hostname, + Server: dc.srv.Config().Hostname, + Nickname: dc.nick, + Flags: flags, + Account: dc.user.Username, + Realname: dc.realname, + } + dc.SendMessage(generateWHOXReply(dc.srv.prefix(), dc.nick, fields, &info)) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_ENDOFWHO, + Params: []string{dc.nick, endOfWhoToken, "End of /WHO list"}, + }) + return nil + } + if maskCM == serviceNickCM { + info := whoxInfo{ + Token: whoxToken, + Username: servicePrefix.User, + Hostname: servicePrefix.Host, + Server: dc.srv.Config().Hostname, + Nickname: serviceNick, + Flags: "H*", + Account: serviceNick, + Realname: serviceRealname, + } + dc.SendMessage(generateWHOXReply(dc.srv.prefix(), dc.nick, fields, &info)) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_ENDOFWHO, + Params: []string{dc.nick, endOfWhoToken, "End of /WHO list"}, + }) + return nil + } + + // TODO: properly support WHO masks + uc, upstreamMask, err := dc.unmarshalEntity(mask) + if err != nil { + return err + } + + params := []string{upstreamMask} + if options != "" { + params = append(params, options) + } + + uc.enqueueCommand(dc, &irc.Message{ + Command: "WHO", + Params: params, + }) + case "WHOIS": + if len(msg.Params) == 0 { + return ircError{&irc.Message{ + Command: irc.ERR_NONICKNAMEGIVEN, + Params: []string{dc.nick, "No nickname given"}, + }} + } + + var target, mask string + if len(msg.Params) == 1 { + target = "" + mask = msg.Params[0] + } else { + target = msg.Params[0] + mask = msg.Params[1] + } + // TODO: support multiple WHOIS users + if i := strings.IndexByte(mask, ','); i >= 0 { + mask = mask[:i] + } + + if dc.network == nil && casemapASCII(mask) == dc.nickCM { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_WHOISUSER, + Params: []string{dc.nick, dc.nick, dc.user.Username, dc.hostname, "*", dc.realname}, + }) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_WHOISSERVER, + Params: []string{dc.nick, dc.nick, dc.srv.Config().Hostname, "suika"}, + }) + if dc.user.Admin { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_WHOISOPERATOR, + Params: []string{dc.nick, dc.nick, "is a bouncer administrator"}, + }) + } + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: rpl_whoisaccount, + Params: []string{dc.nick, dc.nick, dc.user.Username, "is logged in as"}, + }) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_ENDOFWHOIS, + Params: []string{dc.nick, dc.nick, "End of /WHOIS list"}, + }) + return nil + } + if casemapASCII(mask) == serviceNickCM { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_WHOISUSER, + Params: []string{dc.nick, serviceNick, servicePrefix.User, servicePrefix.Host, "*", serviceRealname}, + }) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_WHOISSERVER, + Params: []string{dc.nick, serviceNick, dc.srv.Config().Hostname, "suika"}, + }) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_WHOISOPERATOR, + Params: []string{dc.nick, serviceNick, "is the bouncer service"}, + }) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: rpl_whoisaccount, + Params: []string{dc.nick, serviceNick, serviceNick, "is logged in as"}, + }) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_ENDOFWHOIS, + Params: []string{dc.nick, serviceNick, "End of /WHOIS list"}, + }) + return nil + } + + // TODO: support WHOIS masks + uc, upstreamNick, err := dc.unmarshalEntity(mask) + if err != nil { + return err + } + + var params []string + if target != "" { + if target == mask { // WHOIS nick nick + params = []string{upstreamNick, upstreamNick} + } else { + params = []string{target, upstreamNick} + } + } else { + params = []string{upstreamNick} + } + + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + Command: "WHOIS", + Params: params, + }) + case "PRIVMSG", "NOTICE": + var targetsStr, text string + if err := parseMessageParams(msg, &targetsStr, &text); err != nil { + return err + } + tags := copyClientTags(msg.Tags) + + for _, name := range strings.Split(targetsStr, ",") { + if name == "$"+dc.srv.Config().Hostname || (name == "$*" && dc.network == nil) { + // "$" means a server mask follows. If it's the bouncer's + // hostname, broadcast the message to all bouncer users. + if !dc.user.Admin { + return ircError{&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_BADMASK, + Params: []string{dc.nick, name, "Permission denied to broadcast message to all bouncer users"}, + }} + } + + dc.logger.Printf("broadcasting bouncer-wide %v: %v", msg.Command, text) + + broadcastTags := tags.Copy() + broadcastTags["time"] = irc.TagValue(formatServerTime(time.Now())) + broadcastMsg := &irc.Message{ + Tags: broadcastTags, + Prefix: servicePrefix, + Command: msg.Command, + Params: []string{name, text}, + } + dc.srv.forEachUser(func(u *user) { + u.events <- eventBroadcast{broadcastMsg} + }) + continue + } + + if dc.network == nil && casemapASCII(name) == dc.nickCM { + dc.SendMessage(&irc.Message{ + Tags: msg.Tags.Copy(), + Prefix: dc.prefix(), + Command: msg.Command, + Params: []string{name, text}, + }) + continue + } + + if msg.Command == "PRIVMSG" && casemapASCII(name) == serviceNickCM { + if dc.caps["echo-message"] { + echoTags := tags.Copy() + echoTags["time"] = irc.TagValue(formatServerTime(time.Now())) + dc.SendMessage(&irc.Message{ + Tags: echoTags, + Prefix: dc.prefix(), + Command: msg.Command, + Params: []string{name, text}, + }) + } + handleServicePRIVMSG(ctx, dc, text) + continue + } + + uc, upstreamName, err := dc.unmarshalEntity(name) + if err != nil { + return err + } + + if msg.Command == "PRIVMSG" && uc.network.casemap(upstreamName) == "nickserv" { + dc.handleNickServPRIVMSG(ctx, uc, text) + } + + unmarshaledText := text + if uc.isChannel(upstreamName) { + unmarshaledText = dc.unmarshalText(uc, text) + } + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + Tags: tags, + Command: msg.Command, + Params: []string{upstreamName, unmarshaledText}, + }) + + echoTags := tags.Copy() + echoTags["time"] = irc.TagValue(formatServerTime(time.Now())) + if uc.account != "" { + echoTags["account"] = irc.TagValue(uc.account) + } + echoMsg := &irc.Message{ + Tags: echoTags, + Prefix: &irc.Prefix{Name: uc.nick}, + Command: msg.Command, + Params: []string{upstreamName, text}, + } + uc.produce(upstreamName, echoMsg, dc) + + uc.updateChannelAutoDetach(upstreamName) + } + case "TAGMSG": + var targetsStr string + if err := parseMessageParams(msg, &targetsStr); err != nil { + return err + } + tags := copyClientTags(msg.Tags) + + for _, name := range strings.Split(targetsStr, ",") { + if dc.network == nil && casemapASCII(name) == dc.nickCM { + dc.SendMessage(&irc.Message{ + Tags: msg.Tags.Copy(), + Prefix: dc.prefix(), + Command: "TAGMSG", + Params: []string{name}, + }) + continue + } + + if casemapASCII(name) == serviceNickCM { + continue + } + + uc, upstreamName, err := dc.unmarshalEntity(name) + if err != nil { + return err + } + if _, ok := uc.caps["message-tags"]; !ok { + continue + } + + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + Tags: tags, + Command: "TAGMSG", + Params: []string{upstreamName}, + }) + + echoTags := tags.Copy() + echoTags["time"] = irc.TagValue(formatServerTime(time.Now())) + if uc.account != "" { + echoTags["account"] = irc.TagValue(uc.account) + } + echoMsg := &irc.Message{ + Tags: echoTags, + Prefix: &irc.Prefix{Name: uc.nick}, + Command: "TAGMSG", + Params: []string{upstreamName}, + } + uc.produce(upstreamName, echoMsg, dc) + + uc.updateChannelAutoDetach(upstreamName) + } + case "INVITE": + var user, channel string + if err := parseMessageParams(msg, &user, &channel); err != nil { + return err + } + + ucChannel, upstreamChannel, err := dc.unmarshalEntity(channel) + if err != nil { + return err + } + + ucUser, upstreamUser, err := dc.unmarshalEntity(user) + if err != nil { + return err + } + + if ucChannel != ucUser { + return ircError{&irc.Message{ + Command: irc.ERR_USERNOTINCHANNEL, + Params: []string{dc.nick, user, channel, "They are on another network"}, + }} + } + uc := ucChannel + + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + Command: "INVITE", + Params: []string{upstreamUser, upstreamChannel}, + }) + case "AUTHENTICATE": + // Post-connection-registration AUTHENTICATE is unsupported in + // multi-upstream mode, or if the upstream doesn't support SASL + uc := dc.upstream() + if uc == nil || !uc.caps["sasl"] { + return ircError{&irc.Message{ + Command: irc.ERR_SASLFAIL, + Params: []string{dc.nick, "Upstream network authentication not supported"}, + }} + } + + credentials, err := dc.handleAuthenticateCommand(msg) + if err != nil { + return err + } + + if credentials != nil { + if uc.saslClient != nil { + dc.endSASL(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_SASLFAIL, + Params: []string{dc.nick, "Another authentication attempt is already in progress"}, + }) + return nil + } + + uc.logger.Printf("starting post-registration SASL PLAIN authentication with username %q", credentials.plainUsername) + uc.saslClient = sasl.NewPlainClient("", credentials.plainUsername, credentials.plainPassword) + uc.enqueueCommand(dc, &irc.Message{ + Command: "AUTHENTICATE", + Params: []string{"PLAIN"}, + }) + } + case "REGISTER", "VERIFY": + // Check number of params here, since we'll use that to save the + // credentials on command success + if (msg.Command == "REGISTER" && len(msg.Params) < 3) || (msg.Command == "VERIFY" && len(msg.Params) < 2) { + return newNeedMoreParamsError(msg.Command) + } + + uc := dc.upstream() + if uc == nil || !uc.caps["draft/account-registration"] { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{msg.Command, "TEMPORARILY_UNAVAILABLE", "*", "Upstream network account registration not supported"}, + }} + } + + uc.logger.Printf("starting %v with account name %v", msg.Command, msg.Params[0]) + uc.enqueueCommand(dc, msg) + case "MONITOR": + // MONITOR is unsupported in multi-upstream mode + uc := dc.upstream() + if uc == nil { + return newUnknownCommandError(msg.Command) + } + if _, ok := uc.isupport["MONITOR"]; !ok { + return newUnknownCommandError(msg.Command) + } + + var subcommand string + if err := parseMessageParams(msg, &subcommand); err != nil { + return err + } + + switch strings.ToUpper(subcommand) { + case "+", "-": + var targets string + if err := parseMessageParams(msg, nil, &targets); err != nil { + return err + } + for _, target := range strings.Split(targets, ",") { + if subcommand == "+" { + // Hard limit, just to avoid having downstreams fill our map + if len(dc.monitored.innerMap) >= 1000 { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_MONLISTFULL, + Params: []string{dc.nick, "1000", target, "Bouncer monitor list is full"}, + }) + continue + } + + dc.monitored.SetValue(target, nil) + + if uc.monitored.Has(target) { + cmd := irc.RPL_MONOFFLINE + if online := uc.monitored.Value(target); online { + cmd = irc.RPL_MONONLINE + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: cmd, + Params: []string{dc.nick, target}, + }) + } + } else { + dc.monitored.Delete(target) + } + } + uc.updateMonitor() + case "C": // clear + dc.monitored = newCasemapMap(0) + uc.updateMonitor() + case "L": // list + // TODO: be less lazy and pack the list + for _, entry := range dc.monitored.innerMap { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_MONLIST, + Params: []string{dc.nick, entry.originalKey}, + }) + } + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_ENDOFMONLIST, + Params: []string{dc.nick, "End of MONITOR list"}, + }) + case "S": // status + // TODO: be less lazy and pack the lists + for _, entry := range dc.monitored.innerMap { + target := entry.originalKey + + cmd := irc.RPL_MONOFFLINE + if online := uc.monitored.Value(target); online { + cmd = irc.RPL_MONONLINE + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: cmd, + Params: []string{dc.nick, target}, + }) + } + } + case "CHATHISTORY": + var subcommand string + if err := parseMessageParams(msg, &subcommand); err != nil { + return err + } + var target, limitStr string + var boundsStr [2]string + switch subcommand { + case "AFTER", "BEFORE", "LATEST": + if err := parseMessageParams(msg, nil, &target, &boundsStr[0], &limitStr); err != nil { + return err + } + case "BETWEEN": + if err := parseMessageParams(msg, nil, &target, &boundsStr[0], &boundsStr[1], &limitStr); err != nil { + return err + } + case "TARGETS": + if dc.network == nil { + // Either an unbound bouncer network, in which case we should return no targets, + // or a multi-upstream downstream, but we don't support CHATHISTORY TARGETS for those yet. + dc.SendBatch("draft/chathistory-targets", nil, nil, func(batchRef irc.TagValue) {}) + return nil + } + if err := parseMessageParams(msg, nil, &boundsStr[0], &boundsStr[1], &limitStr); err != nil { + return err + } + default: + // TODO: support AROUND + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"CHATHISTORY", "INVALID_PARAMS", subcommand, "Unknown command"}, + }} + } + + // We don't save history for our service + if casemapASCII(target) == serviceNickCM { + dc.SendBatch("chathistory", []string{target}, nil, func(batchRef irc.TagValue) {}) + return nil + } + + store, ok := dc.user.msgStore.(chatHistoryMessageStore) + if !ok { + return ircError{&irc.Message{ + Command: irc.ERR_UNKNOWNCOMMAND, + Params: []string{dc.nick, "CHATHISTORY", "Unknown command"}, + }} + } + + network, entity, err := dc.unmarshalEntityNetwork(target) + if err != nil { + return err + } + entity = network.casemap(entity) + + // TODO: support msgid criteria + var bounds [2]time.Time + bounds[0] = parseChatHistoryBound(boundsStr[0]) + if subcommand == "LATEST" && boundsStr[0] == "*" { + bounds[0] = time.Now() + } else if bounds[0].IsZero() { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"CHATHISTORY", "INVALID_PARAMS", subcommand, boundsStr[0], "Invalid first bound"}, + }} + } + + if boundsStr[1] != "" { + bounds[1] = parseChatHistoryBound(boundsStr[1]) + if bounds[1].IsZero() { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"CHATHISTORY", "INVALID_PARAMS", subcommand, boundsStr[1], "Invalid second bound"}, + }} + } + } + + limit, err := strconv.Atoi(limitStr) + if err != nil || limit < 0 || limit > chatHistoryLimit { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"CHATHISTORY", "INVALID_PARAMS", subcommand, limitStr, "Invalid limit"}, + }} + } + + eventPlayback := dc.caps["draft/event-playback"] + + var history []*irc.Message + switch subcommand { + case "BEFORE", "LATEST": + history, err = store.LoadBeforeTime(ctx, &network.Network, entity, bounds[0], time.Time{}, limit, eventPlayback) + case "AFTER": + history, err = store.LoadAfterTime(ctx, &network.Network, entity, bounds[0], time.Now(), limit, eventPlayback) + case "BETWEEN": + if bounds[0].Before(bounds[1]) { + history, err = store.LoadAfterTime(ctx, &network.Network, entity, bounds[0], bounds[1], limit, eventPlayback) + } else { + history, err = store.LoadBeforeTime(ctx, &network.Network, entity, bounds[0], bounds[1], limit, eventPlayback) + } + case "TARGETS": + // TODO: support TARGETS in multi-upstream mode + targets, err := store.ListTargets(ctx, &network.Network, bounds[0], bounds[1], limit, eventPlayback) + if err != nil { + dc.logger.Printf("failed fetching targets for chathistory: %v", err) + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"CHATHISTORY", "MESSAGE_ERROR", subcommand, "Failed to retrieve targets"}, + }} + } + + dc.SendBatch("draft/chathistory-targets", nil, nil, func(batchRef irc.TagValue) { + for _, target := range targets { + if ch := network.channels.Value(target.Name); ch != nil && ch.Detached { + continue + } + + dc.SendMessage(&irc.Message{ + Tags: irc.Tags{"batch": batchRef}, + Prefix: dc.srv.prefix(), + Command: "CHATHISTORY", + Params: []string{"TARGETS", target.Name, formatServerTime(target.LatestMessage)}, + }) + } + }) + + return nil + } + if err != nil { + dc.logger.Printf("failed fetching %q messages for chathistory: %v", target, err) + return newChatHistoryError(subcommand, target) + } + + dc.SendBatch("chathistory", []string{target}, nil, func(batchRef irc.TagValue) { + for _, msg := range history { + msg.Tags["batch"] = batchRef + dc.SendMessage(dc.marshalMessage(msg, network)) + } + }) + case "READ": + var target, criteria string + if err := parseMessageParams(msg, &target); err != nil { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"READ", "NEED_MORE_PARAMS", "Missing parameters"}, + }} + } + if len(msg.Params) > 1 { + criteria = msg.Params[1] + } + + // We don't save read receipts for our service + if casemapASCII(target) == serviceNickCM { + dc.SendMessage(&irc.Message{ + Prefix: dc.prefix(), + Command: "READ", + Params: []string{target, "*"}, + }) + return nil + } + + uc, entity, err := dc.unmarshalEntity(target) + if err != nil { + return err + } + entityCM := uc.network.casemap(entity) + + r, err := dc.srv.db.GetReadReceipt(ctx, uc.network.ID, entityCM) + if err != nil { + dc.logger.Printf("failed to get the read receipt for %q: %v", entity, err) + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"READ", "INTERNAL_ERROR", target, "Internal error"}, + }} + } else if r == nil { + r = &ReadReceipt{ + Target: entityCM, + } + } + + broadcast := false + if len(criteria) > 0 { + // TODO: support msgid criteria + criteriaParts := strings.SplitN(criteria, "=", 2) + if len(criteriaParts) != 2 || criteriaParts[0] != "timestamp" { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"READ", "INVALID_PARAMS", criteria, "Unknown criteria"}, + }} + } + + timestamp, err := time.Parse(serverTimeLayout, criteriaParts[1]) + if err != nil { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"READ", "INVALID_PARAMS", criteria, "Invalid criteria"}, + }} + } + now := time.Now() + if timestamp.After(now) { + timestamp = now + } + if r.Timestamp.Before(timestamp) { + r.Timestamp = timestamp + if err := dc.srv.db.StoreReadReceipt(ctx, uc.network.ID, r); err != nil { + dc.logger.Printf("failed to store receipt for %q: %v", entity, err) + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"READ", "INTERNAL_ERROR", target, "Internal error"}, + }} + } + broadcast = true + } + } + + timestampStr := "*" + if !r.Timestamp.IsZero() { + timestampStr = fmt.Sprintf("timestamp=%s", formatServerTime(r.Timestamp)) + } + uc.forEachDownstream(func(d *downstreamConn) { + if broadcast || dc.id == d.id { + d.SendMessage(&irc.Message{ + Prefix: d.prefix(), + Command: "READ", + Params: []string{d.marshalEntity(uc.network, entity), timestampStr}, + }) + } + }) + case "BOUNCER": + var subcommand string + if err := parseMessageParams(msg, &subcommand); err != nil { + return err + } + + switch strings.ToUpper(subcommand) { + case "BIND": + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"BOUNCER", "REGISTRATION_IS_COMPLETED", "BIND", "Cannot bind to a network after registration"}, + }} + case "LISTNETWORKS": + dc.SendBatch("soju.im/bouncer-networks", nil, nil, func(batchRef irc.TagValue) { + for _, network := range dc.user.networks { + idStr := fmt.Sprintf("%v", network.ID) + attrs := getNetworkAttrs(network) + dc.SendMessage(&irc.Message{ + Tags: irc.Tags{"batch": batchRef}, + Prefix: dc.srv.prefix(), + Command: "BOUNCER", + Params: []string{"NETWORK", idStr, attrs.String()}, + }) + } + }) + case "ADDNETWORK": + var attrsStr string + if err := parseMessageParams(msg, nil, &attrsStr); err != nil { + return err + } + attrs := irc.ParseTags(attrsStr) + + record := &Network{Nick: dc.nick, Enabled: true} + if err := updateNetworkAttrs(record, attrs, subcommand); err != nil { + return err + } + + if record.Nick == dc.user.Username { + record.Nick = "" + } + if record.Realname == dc.user.Realname { + record.Realname = "" + } + + network, err := dc.user.createNetwork(ctx, record) + if err != nil { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"BOUNCER", "UNKNOWN_ERROR", subcommand, fmt.Sprintf("Failed to create network: %v", err)}, + }} + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "BOUNCER", + Params: []string{"ADDNETWORK", fmt.Sprintf("%v", network.ID)}, + }) + case "CHANGENETWORK": + var idStr, attrsStr string + if err := parseMessageParams(msg, nil, &idStr, &attrsStr); err != nil { + return err + } + id, err := parseBouncerNetID(subcommand, idStr) + if err != nil { + return err + } + attrs := irc.ParseTags(attrsStr) + + net := dc.user.getNetworkByID(id) + if net == nil { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"BOUNCER", "INVALID_NETID", subcommand, idStr, "Invalid network ID"}, + }} + } + + record := net.Network // copy network record because we'll mutate it + if err := updateNetworkAttrs(&record, attrs, subcommand); err != nil { + return err + } + + if record.Nick == dc.user.Username { + record.Nick = "" + } + if record.Realname == dc.user.Realname { + record.Realname = "" + } + + _, err = dc.user.updateNetwork(ctx, &record) + if err != nil { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"BOUNCER", "UNKNOWN_ERROR", subcommand, fmt.Sprintf("Failed to update network: %v", err)}, + }} + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "BOUNCER", + Params: []string{"CHANGENETWORK", idStr}, + }) + case "DELNETWORK": + var idStr string + if err := parseMessageParams(msg, nil, &idStr); err != nil { + return err + } + id, err := parseBouncerNetID(subcommand, idStr) + if err != nil { + return err + } + + net := dc.user.getNetworkByID(id) + if net == nil { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"BOUNCER", "INVALID_NETID", subcommand, idStr, "Invalid network ID"}, + }} + } + + if err := dc.user.deleteNetwork(ctx, net.ID); err != nil { + return err + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "BOUNCER", + Params: []string{"DELNETWORK", idStr}, + }) + default: + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"BOUNCER", "UNKNOWN_COMMAND", subcommand, "Unknown subcommand"}, + }} + } + default: + dc.logger.Printf("unhandled message: %v", msg) + + // Only forward unknown commands in single-upstream mode + uc := dc.upstream() + if uc == nil { + return newUnknownCommandError(msg.Command) + } + + uc.SendMessageLabeled(ctx, dc.id, msg) + } + return nil +} + +func (dc *downstreamConn) handleNickServPRIVMSG(ctx context.Context, uc *upstreamConn, text string) { + username, password, ok := parseNickServCredentials(text, uc.nick) + if ok { + uc.network.autoSaveSASLPlain(ctx, username, password) + } +} + +func parseNickServCredentials(text, nick string) (username, password string, ok bool) { + fields := strings.Fields(text) + if len(fields) < 2 { + return "", "", false + } + cmd := strings.ToUpper(fields[0]) + params := fields[1:] + switch cmd { + case "REGISTER": + username = nick + password = params[0] + case "IDENTIFY": + if len(params) == 1 { + username = nick + password = params[0] + } else { + username = params[0] + password = params[1] + } + case "SET": + if len(params) == 2 && strings.EqualFold(params[0], "PASSWORD") { + username = nick + password = params[1] + } + default: + return "", "", false + } + return username, password, true +} diff --git a/branches/origin-master/go.mod b/branches/origin-master/go.mod new file mode 100644 index 0000000..cc08167 --- /dev/null +++ b/branches/origin-master/go.mod @@ -0,0 +1,39 @@ +module marisa.chaotic.ninja/suika + +go 1.20 + +require ( + git.sr.ht/~emersion/go-scfg v0.0.0-20211215104734-c2c7a15d6c99 + git.sr.ht/~sircmpwn/go-bare v0.0.0-20210406120253-ab86bc2846d9 + github.com/emersion/go-sasl v0.0.0-20220912192320-0145f2c60ead + github.com/lib/pq v1.10.7 + golang.org/x/crypto v0.7.0 + golang.org/x/term v0.6.0 + golang.org/x/time v0.3.0 + gopkg.in/irc.v3 v3.1.4 + modernc.org/sqlite v1.21.0 +) + +require ( + github.com/dustin/go-humanize v1.0.0 // indirect + github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect + github.com/google/uuid v1.3.0 // indirect + github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect + github.com/mattn/go-isatty v0.0.16 // indirect + github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + github.com/stretchr/testify v1.8.0 // indirect + golang.org/x/mod v0.3.0 // indirect + golang.org/x/sys v0.6.0 // indirect + golang.org/x/tools v0.0.0-20201124115921-2c860bdd6e78 // indirect + golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect + lukechampine.com/uint128 v1.2.0 // indirect + modernc.org/cc/v3 v3.40.0 // indirect + modernc.org/ccgo/v3 v3.16.13 // indirect + modernc.org/libc v1.22.3 // indirect + modernc.org/mathutil v1.5.0 // indirect + modernc.org/memory v1.5.0 // indirect + modernc.org/opt v0.1.3 // indirect + modernc.org/strutil v1.1.3 // indirect + modernc.org/token v1.0.1 // indirect +) diff --git a/branches/origin-master/go.sum b/branches/origin-master/go.sum new file mode 100644 index 0000000..090246d --- /dev/null +++ b/branches/origin-master/go.sum @@ -0,0 +1,105 @@ +git.sr.ht/~emersion/go-scfg v0.0.0-20211215104734-c2c7a15d6c99 h1:1s8n5uisqkR+BzPgaum6xxIjKmzGrTykJdh+Y3f5Xao= +git.sr.ht/~emersion/go-scfg v0.0.0-20211215104734-c2c7a15d6c99/go.mod h1:t+Ww6SR24yYnXzEWiNlOY0AFo5E9B73X++10lrSpp4U= +git.sr.ht/~sircmpwn/getopt v0.0.0-20191230200459-23622cc906b3/go.mod h1:wMEGFFFNuPos7vHmWXfszqImLppbc0wEhh6JBfJIUgw= +git.sr.ht/~sircmpwn/go-bare v0.0.0-20210406120253-ab86bc2846d9 h1:Ahny8Ud1LjVMMAlt8utUFKhhxJtwBAualvsbc/Sk7cE= +git.sr.ht/~sircmpwn/go-bare v0.0.0-20210406120253-ab86bc2846d9/go.mod h1:BVJwbDfVjCjoFiKrhkei6NdGcZYpkDkdyCdg1ukytRA= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo= +github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= +github.com/emersion/go-sasl v0.0.0-20220912192320-0145f2c60ead h1:fI1Jck0vUrXT8bnphprS1EoVRe2Q5CKCX8iDlpqjQ/Y= +github.com/emersion/go-sasl v0.0.0-20220912192320-0145f2c60ead/go.mod h1:iL2twTeMvZnrg54ZoPDNfJaJaqy0xIQFuBdrLsmspwQ= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ= +github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4= +github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs= +github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= +github.com/lib/pq v1.10.7 h1:p7ZhMD+KsSRozJr34udlUrhboJwWAgCg34+/ZZNvZZw= +github.com/lib/pq v1.10.7/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/mattn/go-isatty v0.0.16 h1:bq3VjFmv/sOjHtdEhmkEV4x1AJtvUvOJ2PFAZ5+peKQ= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A= +golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= +golang.org/x/mod v0.3.0 h1:RM4zey1++hCTbCVQfnWeKs9/IEsaBLA8vTkd0WVtmH4= +golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.6.0 h1:clScbb1cHjoCkyRbWwBEUZ5H/tIFu5TAXIqaZD0Gcjw= +golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= +golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20201124115921-2c860bdd6e78 h1:M8tBwCtWD/cZV9DZpFYRUgaymAYAr+aIUTWzDaM3uPs= +golang.org/x/tools v0.0.0-20201124115921-2c860bdd6e78/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/irc.v3 v3.1.4 h1:DYGMRFbtseXEh+NadmMUFzMraqyuUj4I3iWYFEzDZPc= +gopkg.in/irc.v3 v3.1.4/go.mod h1:shO2gz8+PVeS+4E6GAny88Z0YVVQSxQghdrMVGQsR9s= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +lukechampine.com/uint128 v1.2.0 h1:mBi/5l91vocEN8otkC5bDLhi2KdCticRiwbdB0O+rjI= +lukechampine.com/uint128 v1.2.0/go.mod h1:c4eWIwlEGaxC/+H1VguhU4PHXNWDCDMUlWdIWl2j1gk= +modernc.org/cc/v3 v3.40.0 h1:P3g79IUS/93SYhtoeaHW+kRCIrYaxJ27MFPv+7kaTOw= +modernc.org/cc/v3 v3.40.0/go.mod h1:/bTg4dnWkSXowUO6ssQKnOV0yMVxDYNIsIrzqTFDGH0= +modernc.org/ccgo/v3 v3.16.13 h1:Mkgdzl46i5F/CNR/Kj80Ri59hC8TKAhZrYSaqvkwzUw= +modernc.org/ccgo/v3 v3.16.13/go.mod h1:2Quk+5YgpImhPjv2Qsob1DnZ/4som1lJTodubIcoUkY= +modernc.org/ccorpus v1.11.6 h1:J16RXiiqiCgua6+ZvQot4yUuUy8zxgqbqEEUuGPlISk= +modernc.org/httpfs v1.0.6 h1:AAgIpFZRXuYnkjftxTAZwMIiwEqAfk8aVB2/oA6nAeM= +modernc.org/libc v1.22.3 h1:D/g6O5ftAfavceqlLOFwaZuA5KYafKwmr30A6iSqoyY= +modernc.org/libc v1.22.3/go.mod h1:MQrloYP209xa2zHome2a8HLiLm6k0UT8CoHpV74tOFw= +modernc.org/mathutil v1.5.0 h1:rV0Ko/6SfM+8G+yKiyI830l3Wuz1zRutdslNoQ0kfiQ= +modernc.org/mathutil v1.5.0/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E= +modernc.org/memory v1.5.0 h1:N+/8c5rE6EqugZwHii4IFsaJ7MUhoWX07J5tC/iI5Ds= +modernc.org/memory v1.5.0/go.mod h1:PkUhL0Mugw21sHPeskwZW4D6VscE/GQJOnIpCnW6pSU= +modernc.org/opt v0.1.3 h1:3XOZf2yznlhC+ibLltsDGzABUGVx8J6pnFMS3E4dcq4= +modernc.org/opt v0.1.3/go.mod h1:WdSiB5evDcignE70guQKxYUl14mgWtbClRi5wmkkTX0= +modernc.org/sqlite v1.21.0 h1:4aP4MdUf15i3R3M2mx6Q90WHKz3nZLoz96zlB6tNdow= +modernc.org/sqlite v1.21.0/go.mod h1:XwQ0wZPIh1iKb5mkvCJ3szzbhk+tykC8ZWqTRTgYRwI= +modernc.org/strutil v1.1.3 h1:fNMm+oJklMGYfU9Ylcywl0CO5O6nTfaowNsh2wpPjzY= +modernc.org/strutil v1.1.3/go.mod h1:MEHNA7PdEnEwLvspRMtWTNnp2nnyvMfkimT1NKNAGbw= +modernc.org/tcl v1.15.1 h1:mOQwiEK4p7HruMZcwKTZPw/aqtGM4aY00uzWhlKKYws= +modernc.org/token v1.0.1 h1:A3qvTqOwexpfZZeyI0FeGPDlSWX5pjZu9hF4lU+EKWg= +modernc.org/token v1.0.1/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= +modernc.org/z v1.7.0 h1:xkDw/KepgEjeizO2sNco+hqYkU12taxQFqPEmgm1GWE= diff --git a/branches/origin-master/irc.go b/branches/origin-master/irc.go new file mode 100644 index 0000000..1799e32 --- /dev/null +++ b/branches/origin-master/irc.go @@ -0,0 +1,811 @@ +package suika + +import ( + "fmt" + "sort" + "strings" + "time" + "unicode" + "unicode/utf8" + + "gopkg.in/irc.v3" +) + +const ( + rpl_statsping = "246" + rpl_localusers = "265" + rpl_globalusers = "266" + rpl_creationtime = "329" + rpl_topicwhotime = "333" + rpl_whospcrpl = "354" + rpl_whoisaccount = "330" + err_invalidcapcmd = "410" +) + +const ( + maxMessageLength = 512 + maxMessageParams = 15 + maxSASLLength = 400 +) + +// The server-time layout, as defined in the IRCv3 spec. +const serverTimeLayout = "2006-01-02T15:04:05.000Z" + +func formatServerTime(t time.Time) string { + return t.UTC().Format(serverTimeLayout) +} + +type userModes string + +func (ms userModes) Has(c byte) bool { + return strings.IndexByte(string(ms), c) >= 0 +} + +func (ms *userModes) Add(c byte) { + if !ms.Has(c) { + *ms += userModes(c) + } +} + +func (ms *userModes) Del(c byte) { + i := strings.IndexByte(string(*ms), c) + if i >= 0 { + *ms = (*ms)[:i] + (*ms)[i+1:] + } +} + +func (ms *userModes) Apply(s string) error { + var plusMinus byte + for i := 0; i < len(s); i++ { + switch c := s[i]; c { + case '+', '-': + plusMinus = c + default: + switch plusMinus { + case '+': + ms.Add(c) + case '-': + ms.Del(c) + default: + return fmt.Errorf("malformed modestring %q: missing plus/minus", s) + } + } + } + return nil +} + +type channelModeType byte + +// standard channel mode types, as explained in https://modern.ircdocs.horse/#mode-message +const ( + // modes that add or remove an address to or from a list + modeTypeA channelModeType = iota + // modes that change a setting on a channel, and must always have a parameter + modeTypeB + // modes that change a setting on a channel, and must have a parameter when being set, and no parameter when being unset + modeTypeC + // modes that change a setting on a channel, and must not have a parameter + modeTypeD +) + +var stdChannelModes = map[byte]channelModeType{ + 'b': modeTypeA, // ban list + 'e': modeTypeA, // ban exception list + 'I': modeTypeA, // invite exception list + 'k': modeTypeB, // channel key + 'l': modeTypeC, // channel user limit + 'i': modeTypeD, // channel is invite-only + 'm': modeTypeD, // channel is moderated + 'n': modeTypeD, // channel has no external messages + 's': modeTypeD, // channel is secret + 't': modeTypeD, // channel has protected topic +} + +type channelModes map[byte]string + +// applyChannelModes parses a mode string and mode arguments from a MODE message, +// and applies the corresponding channel mode and user membership changes on that channel. +// +// If ch.modes is nil, channel modes are not updated. +// +// needMarshaling is a list of indexes of mode arguments that represent entities +// that must be marshaled when sent downstream. +func applyChannelModes(ch *upstreamChannel, modeStr string, arguments []string) (needMarshaling map[int]struct{}, err error) { + needMarshaling = make(map[int]struct{}, len(arguments)) + nextArgument := 0 + var plusMinus byte +outer: + for i := 0; i < len(modeStr); i++ { + mode := modeStr[i] + if mode == '+' || mode == '-' { + plusMinus = mode + continue + } + if plusMinus != '+' && plusMinus != '-' { + return nil, fmt.Errorf("malformed modestring %q: missing plus/minus", modeStr) + } + + for _, membership := range ch.conn.availableMemberships { + if membership.Mode == mode { + if nextArgument >= len(arguments) { + return nil, fmt.Errorf("malformed modestring %q: missing mode argument for %c%c", modeStr, plusMinus, mode) + } + member := arguments[nextArgument] + m := ch.Members.Value(member) + if m != nil { + if plusMinus == '+' { + m.Add(ch.conn.availableMemberships, membership) + } else { + // TODO: for upstreams without multi-prefix, query the user modes again + m.Remove(membership) + } + } + needMarshaling[nextArgument] = struct{}{} + nextArgument++ + continue outer + } + } + + mt, ok := ch.conn.availableChannelModes[mode] + if !ok { + continue + } + if mt == modeTypeA { + nextArgument++ + } else if mt == modeTypeB || (mt == modeTypeC && plusMinus == '+') { + if plusMinus == '+' { + var argument string + // some sentitive arguments (such as channel keys) can be omitted for privacy + // (this will only happen for RPL_CHANNELMODEIS, never for MODE messages) + if nextArgument < len(arguments) { + argument = arguments[nextArgument] + } + if ch.modes != nil { + ch.modes[mode] = argument + } + } else { + delete(ch.modes, mode) + } + nextArgument++ + } else if mt == modeTypeC || mt == modeTypeD { + if plusMinus == '+' { + if ch.modes != nil { + ch.modes[mode] = "" + } + } else { + delete(ch.modes, mode) + } + } + } + return needMarshaling, nil +} + +func (cm channelModes) Format() (modeString string, parameters []string) { + var modesWithValues strings.Builder + var modesWithoutValues strings.Builder + parameters = make([]string, 0, 16) + for mode, value := range cm { + if value != "" { + modesWithValues.WriteString(string(mode)) + parameters = append(parameters, value) + } else { + modesWithoutValues.WriteString(string(mode)) + } + } + modeString = "+" + modesWithValues.String() + modesWithoutValues.String() + return +} + +const stdChannelTypes = "#&+!" + +type channelStatus byte + +const ( + channelPublic channelStatus = '=' + channelSecret channelStatus = '@' + channelPrivate channelStatus = '*' +) + +func parseChannelStatus(s string) (channelStatus, error) { + if len(s) > 1 { + return 0, fmt.Errorf("invalid channel status %q: more than one character", s) + } + switch cs := channelStatus(s[0]); cs { + case channelPublic, channelSecret, channelPrivate: + return cs, nil + default: + return 0, fmt.Errorf("invalid channel status %q: unknown status", s) + } +} + +type membership struct { + Mode byte + Prefix byte +} + +var stdMemberships = []membership{ + {'q', '~'}, // founder + {'a', '&'}, // protected + {'o', '@'}, // operator + {'h', '%'}, // halfop + {'v', '+'}, // voice +} + +// memberships always sorted by descending membership rank +type memberships []membership + +func (m *memberships) Add(availableMemberships []membership, newMembership membership) { + l := *m + i := 0 + for _, availableMembership := range availableMemberships { + if i >= len(l) { + break + } + if l[i] == availableMembership { + if availableMembership == newMembership { + // we already have this membership + return + } + i++ + continue + } + if availableMembership == newMembership { + break + } + } + // insert newMembership at i + l = append(l, membership{}) + copy(l[i+1:], l[i:]) + l[i] = newMembership + *m = l +} + +func (m *memberships) Remove(oldMembership membership) { + l := *m + for i, currentMembership := range l { + if currentMembership == oldMembership { + *m = append(l[:i], l[i+1:]...) + return + } + } +} + +func (m memberships) Format(dc *downstreamConn) string { + if !dc.caps["multi-prefix"] { + if len(m) == 0 { + return "" + } + return string(m[0].Prefix) + } + prefixes := make([]byte, len(m)) + for i, membership := range m { + prefixes[i] = membership.Prefix + } + return string(prefixes) +} + +func parseMessageParams(msg *irc.Message, out ...*string) error { + if len(msg.Params) < len(out) { + return newNeedMoreParamsError(msg.Command) + } + for i := range out { + if out[i] != nil { + *out[i] = msg.Params[i] + } + } + return nil +} + +func copyClientTags(tags irc.Tags) irc.Tags { + t := make(irc.Tags, len(tags)) + for k, v := range tags { + if strings.HasPrefix(k, "+") { + t[k] = v + } + } + return t +} + +type batch struct { + Type string + Params []string + Outer *batch // if not-nil, this batch is nested in Outer + Label string +} + +func join(channels, keys []string) []*irc.Message { + // Put channels with a key first + js := joinSorter{channels, keys} + sort.Sort(&js) + + // Two spaces because there are three words (JOIN, channels and keys) + maxLength := maxMessageLength - (len("JOIN") + 2) + + var msgs []*irc.Message + var channelsBuf, keysBuf strings.Builder + for i, channel := range channels { + key := keys[i] + + n := channelsBuf.Len() + keysBuf.Len() + 1 + len(channel) + if key != "" { + n += 1 + len(key) + } + + if channelsBuf.Len() > 0 && n > maxLength { + // No room for the new channel in this message + params := []string{channelsBuf.String()} + if keysBuf.Len() > 0 { + params = append(params, keysBuf.String()) + } + msgs = append(msgs, &irc.Message{Command: "JOIN", Params: params}) + channelsBuf.Reset() + keysBuf.Reset() + } + + if channelsBuf.Len() > 0 { + channelsBuf.WriteByte(',') + } + channelsBuf.WriteString(channel) + if key != "" { + if keysBuf.Len() > 0 { + keysBuf.WriteByte(',') + } + keysBuf.WriteString(key) + } + } + if channelsBuf.Len() > 0 { + params := []string{channelsBuf.String()} + if keysBuf.Len() > 0 { + params = append(params, keysBuf.String()) + } + msgs = append(msgs, &irc.Message{Command: "JOIN", Params: params}) + } + + return msgs +} + +func generateIsupport(prefix *irc.Prefix, nick string, tokens []string) []*irc.Message { + maxTokens := maxMessageParams - 2 // 2 reserved params: nick + text + + var msgs []*irc.Message + for len(tokens) > 0 { + var msgTokens []string + if len(tokens) > maxTokens { + msgTokens = tokens[:maxTokens] + tokens = tokens[maxTokens:] + } else { + msgTokens = tokens + tokens = nil + } + + msgs = append(msgs, &irc.Message{ + Prefix: prefix, + Command: irc.RPL_ISUPPORT, + Params: append(append([]string{nick}, msgTokens...), "are supported"), + }) + } + + return msgs +} + +func generateMOTD(prefix *irc.Prefix, nick string, motd string) []*irc.Message { + var msgs []*irc.Message + msgs = append(msgs, &irc.Message{ + Prefix: prefix, + Command: irc.RPL_MOTDSTART, + Params: []string{nick, fmt.Sprintf("- Message of the Day -")}, + }) + + for _, l := range strings.Split(motd, "\n") { + msgs = append(msgs, &irc.Message{ + Prefix: prefix, + Command: irc.RPL_MOTD, + Params: []string{nick, l}, + }) + } + + msgs = append(msgs, &irc.Message{ + Prefix: prefix, + Command: irc.RPL_ENDOFMOTD, + Params: []string{nick, "End of /MOTD command."}, + }) + + return msgs +} + +func generateMonitor(subcmd string, targets []string) []*irc.Message { + maxLength := maxMessageLength - len("MONITOR "+subcmd+" ") + + var msgs []*irc.Message + var buf []string + n := 0 + for _, target := range targets { + if n+len(target)+1 > maxLength { + msgs = append(msgs, &irc.Message{ + Command: "MONITOR", + Params: []string{subcmd, strings.Join(buf, ",")}, + }) + buf = buf[:0] + n = 0 + } + + buf = append(buf, target) + n += len(target) + 1 + } + + if len(buf) > 0 { + msgs = append(msgs, &irc.Message{ + Command: "MONITOR", + Params: []string{subcmd, strings.Join(buf, ",")}, + }) + } + + return msgs +} + +type joinSorter struct { + channels []string + keys []string +} + +func (js *joinSorter) Len() int { + return len(js.channels) +} + +func (js *joinSorter) Less(i, j int) bool { + if (js.keys[i] != "") != (js.keys[j] != "") { + // Only one of the channels has a key + return js.keys[i] != "" + } + return js.channels[i] < js.channels[j] +} + +func (js *joinSorter) Swap(i, j int) { + js.channels[i], js.channels[j] = js.channels[j], js.channels[i] + js.keys[i], js.keys[j] = js.keys[j], js.keys[i] +} + +// parseCTCPMessage parses a CTCP message. CTCP is defined in +// https://tools.ietf.org/html/draft-oakley-irc-ctcp-02 +func parseCTCPMessage(msg *irc.Message) (cmd string, params string, ok bool) { + if (msg.Command != "PRIVMSG" && msg.Command != "NOTICE") || len(msg.Params) < 2 { + return "", "", false + } + text := msg.Params[1] + + if !strings.HasPrefix(text, "\x01") { + return "", "", false + } + text = strings.Trim(text, "\x01") + + words := strings.SplitN(text, " ", 2) + cmd = strings.ToUpper(words[0]) + if len(words) > 1 { + params = words[1] + } + + return cmd, params, true +} + +type casemapping func(string) string + +func casemapNone(name string) string { + return name +} + +// CasemapASCII of name is the canonical representation of name according to the +// ascii casemapping. +func casemapASCII(name string) string { + nameBytes := []byte(name) + for i, r := range nameBytes { + if 'A' <= r && r <= 'Z' { + nameBytes[i] = r + 'a' - 'A' + } + } + return string(nameBytes) +} + +// casemapRFC1459 of name is the canonical representation of name according to the +// rfc1459 casemapping. +func casemapRFC1459(name string) string { + nameBytes := []byte(name) + for i, r := range nameBytes { + if 'A' <= r && r <= 'Z' { + nameBytes[i] = r + 'a' - 'A' + } else if r == '{' { + nameBytes[i] = '[' + } else if r == '}' { + nameBytes[i] = ']' + } else if r == '\\' { + nameBytes[i] = '|' + } else if r == '~' { + nameBytes[i] = '^' + } + } + return string(nameBytes) +} + +// casemapRFC1459Strict of name is the canonical representation of name +// according to the rfc1459-strict casemapping. +func casemapRFC1459Strict(name string) string { + nameBytes := []byte(name) + for i, r := range nameBytes { + if 'A' <= r && r <= 'Z' { + nameBytes[i] = r + 'a' - 'A' + } else if r == '{' { + nameBytes[i] = '[' + } else if r == '}' { + nameBytes[i] = ']' + } else if r == '\\' { + nameBytes[i] = '|' + } + } + return string(nameBytes) +} + +func parseCasemappingToken(tokenValue string) (casemap casemapping, ok bool) { + switch tokenValue { + case "ascii": + casemap = casemapASCII + case "rfc1459": + casemap = casemapRFC1459 + case "rfc1459-strict": + casemap = casemapRFC1459Strict + default: + return nil, false + } + return casemap, true +} + +func partialCasemap(higher casemapping, name string) string { + nameFullyCM := []byte(higher(name)) + nameBytes := []byte(name) + for i, r := range nameBytes { + if !('A' <= r && r <= 'Z') && !('a' <= r && r <= 'z') { + nameBytes[i] = nameFullyCM[i] + } + } + return string(nameBytes) +} + +type casemapMap struct { + innerMap map[string]casemapEntry + casemap casemapping +} + +type casemapEntry struct { + originalKey string + value interface{} +} + +func newCasemapMap(size int) casemapMap { + return casemapMap{ + innerMap: make(map[string]casemapEntry, size), + casemap: casemapNone, + } +} + +func (cm *casemapMap) OriginalKey(name string) (key string, ok bool) { + entry, ok := cm.innerMap[cm.casemap(name)] + if !ok { + return "", false + } + return entry.originalKey, true +} + +func (cm *casemapMap) Has(name string) bool { + _, ok := cm.innerMap[cm.casemap(name)] + return ok +} + +func (cm *casemapMap) Len() int { + return len(cm.innerMap) +} + +func (cm *casemapMap) SetValue(name string, value interface{}) { + nameCM := cm.casemap(name) + entry, ok := cm.innerMap[nameCM] + if !ok { + cm.innerMap[nameCM] = casemapEntry{ + originalKey: name, + value: value, + } + return + } + entry.value = value + cm.innerMap[nameCM] = entry +} + +func (cm *casemapMap) Delete(name string) { + delete(cm.innerMap, cm.casemap(name)) +} + +func (cm *casemapMap) SetCasemapping(newCasemap casemapping) { + cm.casemap = newCasemap + newInnerMap := make(map[string]casemapEntry, len(cm.innerMap)) + for _, entry := range cm.innerMap { + newInnerMap[cm.casemap(entry.originalKey)] = entry + } + cm.innerMap = newInnerMap +} + +type upstreamChannelCasemapMap struct{ casemapMap } + +func (cm *upstreamChannelCasemapMap) Value(name string) *upstreamChannel { + entry, ok := cm.innerMap[cm.casemap(name)] + if !ok { + return nil + } + return entry.value.(*upstreamChannel) +} + +type channelCasemapMap struct{ casemapMap } + +func (cm *channelCasemapMap) Value(name string) *Channel { + entry, ok := cm.innerMap[cm.casemap(name)] + if !ok { + return nil + } + return entry.value.(*Channel) +} + +type membershipsCasemapMap struct{ casemapMap } + +func (cm *membershipsCasemapMap) Value(name string) *memberships { + entry, ok := cm.innerMap[cm.casemap(name)] + if !ok { + return nil + } + return entry.value.(*memberships) +} + +type deliveredCasemapMap struct{ casemapMap } + +func (cm *deliveredCasemapMap) Value(name string) deliveredClientMap { + entry, ok := cm.innerMap[cm.casemap(name)] + if !ok { + return nil + } + return entry.value.(deliveredClientMap) +} + +type monitorCasemapMap struct{ casemapMap } + +func (cm *monitorCasemapMap) Value(name string) (online bool) { + entry, ok := cm.innerMap[cm.casemap(name)] + if !ok { + return false + } + return entry.value.(bool) +} + +func isWordBoundary(r rune) bool { + switch r { + case '-', '_', '|': // inspired from weechat.look.highlight_regex + return false + default: + return !unicode.IsLetter(r) && !unicode.IsNumber(r) + } +} + +func isHighlight(text, nick string) bool { + for { + i := strings.Index(text, nick) + if i < 0 { + return false + } + + left, _ := utf8.DecodeLastRuneInString(text[:i]) + right, _ := utf8.DecodeRuneInString(text[i+len(nick):]) + if isWordBoundary(left) && isWordBoundary(right) { + return true + } + + text = text[i+len(nick):] + } +} + +// parseChatHistoryBound parses the given CHATHISTORY parameter as a bound. +// The zero time is returned on error. +func parseChatHistoryBound(param string) time.Time { + parts := strings.SplitN(param, "=", 2) + if len(parts) != 2 { + return time.Time{} + } + switch parts[0] { + case "timestamp": + timestamp, err := time.Parse(serverTimeLayout, parts[1]) + if err != nil { + return time.Time{} + } + return timestamp + default: + return time.Time{} + } +} + +// whoxFields is the list of all WHOX field letters, by order of appearance in +// RPL_WHOSPCRPL messages. +var whoxFields = []byte("tcuihsnfdlaor") + +type whoxInfo struct { + Token string + Username string + Hostname string + Server string + Nickname string + Flags string + Account string + Realname string +} + +func (info *whoxInfo) get(field byte) string { + switch field { + case 't': + return info.Token + case 'c': + return "*" + case 'u': + return info.Username + case 'i': + return "255.255.255.255" + case 'h': + return info.Hostname + case 's': + return info.Server + case 'n': + return info.Nickname + case 'f': + return info.Flags + case 'd': + return "0" + case 'l': // idle time + return "0" + case 'a': + account := "0" // WHOX uses "0" to mean "no account" + if info.Account != "" && info.Account != "*" { + account = info.Account + } + return account + case 'o': + return "0" + case 'r': + return info.Realname + } + return "" +} + +func generateWHOXReply(prefix *irc.Prefix, nick, fields string, info *whoxInfo) *irc.Message { + if fields == "" { + return &irc.Message{ + Prefix: prefix, + Command: irc.RPL_WHOREPLY, + Params: []string{nick, "*", info.Username, info.Hostname, info.Server, info.Nickname, info.Flags, "0 " + info.Realname}, + } + } + + fieldSet := make(map[byte]bool) + for i := 0; i < len(fields); i++ { + fieldSet[fields[i]] = true + } + + var values []string + for _, field := range whoxFields { + if !fieldSet[field] { + continue + } + values = append(values, info.get(field)) + } + + return &irc.Message{ + Prefix: prefix, + Command: rpl_whospcrpl, + Params: append([]string{nick}, values...), + } +} + +var isupportEncoder = strings.NewReplacer(" ", "\\x20", "\\", "\\x5C") + +func encodeISUPPORT(s string) string { + return isupportEncoder.Replace(s) +} diff --git a/branches/origin-master/irc_test.go b/branches/origin-master/irc_test.go new file mode 100644 index 0000000..8757bc8 --- /dev/null +++ b/branches/origin-master/irc_test.go @@ -0,0 +1,34 @@ +package suika + +import ( + "testing" +) + +func TestIsHighlight(t *testing.T) { + nick := "SojuUser" + testCases := []struct { + name string + text string + hl bool + }{ + {"noContains", "hi there Soju User!", false}, + {"middle", "hi there SojuUser!", true}, + {"start", "SojuUser: how are you doing?", true}, + {"end", "maybe ask SojuUser", true}, + {"inWord", "but OtherSojuUserSan is a different nick", false}, + {"startWord", "and OtherSojuUser is another different nick", false}, + {"endWord", "and SojuUserSan is yet a different nick", false}, + {"underscore", "and SojuUser_san has nothing to do with me", false}, + {"zeroWidthSpace", "writing S\u200BojuUser shouldn't trigger a highlight", false}, + } + + for _, tc := range testCases { + tc := tc // capture range variable + t.Run(tc.name, func(t *testing.T) { + hl := isHighlight(tc.text, nick) + if hl != tc.hl { + t.Errorf("isHighlight(%q, %q) = %v, but want %v", tc.text, nick, hl, tc.hl) + } + }) + } +} diff --git a/branches/origin-master/msgstore.go b/branches/origin-master/msgstore.go new file mode 100644 index 0000000..ed57429 --- /dev/null +++ b/branches/origin-master/msgstore.go @@ -0,0 +1,123 @@ +package suika + +import ( + "bytes" + "context" + "encoding/base64" + "fmt" + "time" + + "git.sr.ht/~sircmpwn/go-bare" + "gopkg.in/irc.v3" +) + +// messageStore is a per-user store for IRC messages. +type messageStore interface { + Close() error + // LastMsgID queries the last message ID for the given network, entity and + // date. The message ID returned may not refer to a valid message, but can be + // used in history queries. + LastMsgID(network *Network, entity string, t time.Time) (string, error) + // LoadLatestID queries the latest non-event messages for the given network, + // entity and date, up to a count of limit messages, sorted from oldest to newest. + LoadLatestID(ctx context.Context, network *Network, entity, id string, limit int) ([]*irc.Message, error) + Append(network *Network, entity string, msg *irc.Message) (id string, err error) +} + +type chatHistoryTarget struct { + Name string + LatestMessage time.Time +} + +// chatHistoryMessageStore is a message store that supports chat history +// operations. +type chatHistoryMessageStore interface { + messageStore + + // ListTargets lists channels and nicknames by time of the latest message. + // It returns up to limit targets, starting from start and ending on end, + // both excluded. end may be before or after start. + // If events is false, only PRIVMSG/NOTICE messages are considered. + ListTargets(ctx context.Context, network *Network, start, end time.Time, limit int, events bool) ([]chatHistoryTarget, error) + // LoadBeforeTime loads up to limit messages before start down to end. The + // returned messages must be between and excluding the provided bounds. + // end is before start. + // If events is false, only PRIVMSG/NOTICE messages are considered. + LoadBeforeTime(ctx context.Context, network *Network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error) + // LoadBeforeTime loads up to limit messages after start up to end. The + // returned messages must be between and excluding the provided bounds. + // end is after start. + // If events is false, only PRIVMSG/NOTICE messages are considered. + LoadAfterTime(ctx context.Context, network *Network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error) +} + +type msgIDType uint + +const ( + msgIDNone msgIDType = iota + msgIDMemory + msgIDFS +) + +const msgIDVersion uint = 0 + +type msgIDHeader struct { + Version uint + Network bare.Int + Target string + Type msgIDType +} + +type msgIDBody interface { + msgIDType() msgIDType +} + +func formatMsgID(netID int64, target string, body msgIDBody) string { + var buf bytes.Buffer + w := bare.NewWriter(&buf) + + header := msgIDHeader{ + Version: msgIDVersion, + Network: bare.Int(netID), + Target: target, + Type: body.msgIDType(), + } + if err := bare.MarshalWriter(w, &header); err != nil { + panic(err) + } + if err := bare.MarshalWriter(w, body); err != nil { + panic(err) + } + return base64.RawURLEncoding.EncodeToString(buf.Bytes()) +} + +func parseMsgID(s string, body msgIDBody) (netID int64, target string, err error) { + b, err := base64.RawURLEncoding.DecodeString(s) + if err != nil { + return 0, "", fmt.Errorf("invalid internal message ID: %v", err) + } + + r := bare.NewReader(bytes.NewReader(b)) + + var header msgIDHeader + if err := bare.UnmarshalBareReader(r, &header); err != nil { + return 0, "", fmt.Errorf("invalid internal message ID: %v", err) + } + + if header.Version != msgIDVersion { + return 0, "", fmt.Errorf("invalid internal message ID: got version %v, want %v", header.Version, msgIDVersion) + } + + if body != nil { + typ := body.msgIDType() + if header.Type != typ { + return 0, "", fmt.Errorf("invalid internal message ID: got type %v, want %v", header.Type, typ) + } + + if err := bare.UnmarshalBareReader(r, body); err != nil { + return 0, "", fmt.Errorf("invalid internal message ID: %v", err) + } + } + + return int64(header.Network), header.Target, nil +} diff --git a/branches/origin-master/msgstore_fs.go b/branches/origin-master/msgstore_fs.go new file mode 100644 index 0000000..90bd76a --- /dev/null +++ b/branches/origin-master/msgstore_fs.go @@ -0,0 +1,698 @@ +package suika + +import ( + "bufio" + "context" + "fmt" + "io" + "os" + "path/filepath" + "sort" + "strings" + "time" + + "git.sr.ht/~sircmpwn/go-bare" + "gopkg.in/irc.v3" +) + +const ( + fsMessageStoreMaxFiles = 20 + fsMessageStoreMaxTries = 100 +) + +func escapeFilename(unsafe string) (safe string) { + if unsafe == "." { + return "-" + } else if unsafe == ".." { + return "--" + } else { + return strings.NewReplacer("/", "-", "\\", "-").Replace(unsafe) + } +} + +type date struct { + Year, Month, Day int +} + +func newDate(t time.Time) date { + year, month, day := t.Date() + return date{year, int(month), day} +} + +func (d date) Time() time.Time { + return time.Date(d.Year, time.Month(d.Month), d.Day, 0, 0, 0, 0, time.Local) +} + +type fsMsgID struct { + Date date + Offset bare.Int +} + +func (fsMsgID) msgIDType() msgIDType { + return msgIDFS +} + +func parseFSMsgID(s string) (netID int64, entity string, t time.Time, offset int64, err error) { + var id fsMsgID + netID, entity, err = parseMsgID(s, &id) + if err != nil { + return 0, "", time.Time{}, 0, err + } + return netID, entity, id.Date.Time(), int64(id.Offset), nil +} + +func formatFSMsgID(netID int64, entity string, t time.Time, offset int64) string { + id := fsMsgID{ + Date: newDate(t), + Offset: bare.Int(offset), + } + return formatMsgID(netID, entity, &id) +} + +type fsMessageStoreFile struct { + *os.File + lastUse time.Time +} + +// fsMessageStore is a per-user on-disk store for IRC messages. +// +// It mimicks the ZNC log layout and format. See the ZNC source: +// https://github.com/znc/znc/blob/master/modules/log.cpp +type fsMessageStore struct { + root string + user *User + + // Write-only files used by Append + files map[string]*fsMessageStoreFile // indexed by entity +} + +var _ messageStore = (*fsMessageStore)(nil) +var _ chatHistoryMessageStore = (*fsMessageStore)(nil) + +func newFSMessageStore(root string, user *User) *fsMessageStore { + return &fsMessageStore{ + root: filepath.Join(root, escapeFilename(user.Username)), + user: user, + files: make(map[string]*fsMessageStoreFile), + } +} + +func (ms *fsMessageStore) logPath(network *Network, entity string, t time.Time) string { + year, month, day := t.Date() + filename := fmt.Sprintf("%04d-%02d-%02d.log", year, month, day) + return filepath.Join(ms.root, escapeFilename(network.GetName()), escapeFilename(entity), filename) +} + +// nextMsgID queries the message ID for the next message to be written to f. +func nextFSMsgID(network *Network, entity string, t time.Time, f *os.File) (string, error) { + offset, err := f.Seek(0, io.SeekEnd) + if err != nil { + return "", fmt.Errorf("failed to query next FS message ID: %v", err) + } + return formatFSMsgID(network.ID, entity, t, offset), nil +} + +func (ms *fsMessageStore) LastMsgID(network *Network, entity string, t time.Time) (string, error) { + p := ms.logPath(network, entity, t) + fi, err := os.Stat(p) + if os.IsNotExist(err) { + return formatFSMsgID(network.ID, entity, t, -1), nil + } else if err != nil { + return "", fmt.Errorf("failed to query last FS message ID: %v", err) + } + return formatFSMsgID(network.ID, entity, t, fi.Size()-1), nil +} + +func (ms *fsMessageStore) Append(network *Network, entity string, msg *irc.Message) (string, error) { + s := formatMessage(msg) + if s == "" { + return "", nil + } + + var t time.Time + if tag, ok := msg.Tags["time"]; ok { + var err error + t, err = time.Parse(serverTimeLayout, string(tag)) + if err != nil { + return "", fmt.Errorf("failed to parse message time tag: %v", err) + } + t = t.In(time.Local) + } else { + t = time.Now() + } + + f := ms.files[entity] + + // TODO: handle non-monotonic clock behaviour + path := ms.logPath(network, entity, t) + if f == nil || f.Name() != path { + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0750); err != nil { + return "", fmt.Errorf("failed to create message logs directory %q: %v", dir, err) + } + + ff, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0640) + if err != nil { + return "", fmt.Errorf("failed to open message log file %q: %v", path, err) + } + + if f != nil { + f.Close() + } + f = &fsMessageStoreFile{File: ff} + ms.files[entity] = f + } + + f.lastUse = time.Now() + + if len(ms.files) > fsMessageStoreMaxFiles { + entities := make([]string, 0, len(ms.files)) + for name := range ms.files { + entities = append(entities, name) + } + sort.Slice(entities, func(i, j int) bool { + a, b := entities[i], entities[j] + return ms.files[a].lastUse.Before(ms.files[b].lastUse) + }) + entities = entities[0 : len(entities)-fsMessageStoreMaxFiles] + for _, name := range entities { + ms.files[name].Close() + delete(ms.files, name) + } + } + + msgID, err := nextFSMsgID(network, entity, t, f.File) + if err != nil { + return "", fmt.Errorf("failed to generate message ID: %v", err) + } + + _, err = fmt.Fprintf(f, "[%02d:%02d:%02d] %s\n", t.Hour(), t.Minute(), t.Second(), s) + if err != nil { + return "", fmt.Errorf("failed to log message to %q: %v", f.Name(), err) + } + + return msgID, nil +} + +func (ms *fsMessageStore) Close() error { + var closeErr error + for _, f := range ms.files { + if err := f.Close(); err != nil { + closeErr = fmt.Errorf("failed to close message store: %v", err) + } + } + return closeErr +} + +// formatMessage formats a message log line. It assumes a well-formed IRC +// message. +func formatMessage(msg *irc.Message) string { + switch strings.ToUpper(msg.Command) { + case "NICK": + return fmt.Sprintf("*** %s is now known as %s", msg.Prefix.Name, msg.Params[0]) + case "JOIN": + return fmt.Sprintf("*** Joins: %s (%s@%s)", msg.Prefix.Name, msg.Prefix.User, msg.Prefix.Host) + case "PART": + var reason string + if len(msg.Params) > 1 { + reason = msg.Params[1] + } + return fmt.Sprintf("*** Parts: %s (%s@%s) (%s)", msg.Prefix.Name, msg.Prefix.User, msg.Prefix.Host, reason) + case "KICK": + nick := msg.Params[1] + var reason string + if len(msg.Params) > 2 { + reason = msg.Params[2] + } + return fmt.Sprintf("*** %s was kicked by %s (%s)", nick, msg.Prefix.Name, reason) + case "QUIT": + var reason string + if len(msg.Params) > 0 { + reason = msg.Params[0] + } + return fmt.Sprintf("*** Quits: %s (%s@%s) (%s)", msg.Prefix.Name, msg.Prefix.User, msg.Prefix.Host, reason) + case "TOPIC": + var topic string + if len(msg.Params) > 1 { + topic = msg.Params[1] + } + return fmt.Sprintf("*** %s changes topic to '%s'", msg.Prefix.Name, topic) + case "MODE": + return fmt.Sprintf("*** %s sets mode: %s", msg.Prefix.Name, strings.Join(msg.Params[1:], " ")) + case "NOTICE": + return fmt.Sprintf("-%s- %s", msg.Prefix.Name, msg.Params[1]) + case "PRIVMSG": + if cmd, params, ok := parseCTCPMessage(msg); ok && cmd == "ACTION" { + return fmt.Sprintf("* %s %s", msg.Prefix.Name, params) + } else { + return fmt.Sprintf("<%s> %s", msg.Prefix.Name, msg.Params[1]) + } + default: + return "" + } +} + +func (ms *fsMessageStore) parseMessage(line string, network *Network, entity string, ref time.Time, events bool) (*irc.Message, time.Time, error) { + var hour, minute, second int + _, err := fmt.Sscanf(line, "[%02d:%02d:%02d] ", &hour, &minute, &second) + if err != nil { + return nil, time.Time{}, fmt.Errorf("malformed timestamp prefix: %v", err) + } + line = line[11:] + + var cmd string + var prefix *irc.Prefix + var params []string + if events && strings.HasPrefix(line, "*** ") { + parts := strings.SplitN(line[4:], " ", 2) + if len(parts) != 2 { + return nil, time.Time{}, nil + } + switch parts[0] { + case "Joins:", "Parts:", "Quits:": + args := strings.SplitN(parts[1], " ", 3) + if len(args) < 2 { + return nil, time.Time{}, nil + } + nick := args[0] + mask := strings.TrimSuffix(strings.TrimPrefix(args[1], "("), ")") + maskParts := strings.SplitN(mask, "@", 2) + if len(maskParts) != 2 { + return nil, time.Time{}, nil + } + prefix = &irc.Prefix{ + Name: nick, + User: maskParts[0], + Host: maskParts[1], + } + var reason string + if len(args) > 2 { + reason = strings.TrimSuffix(strings.TrimPrefix(args[2], "("), ")") + } + switch parts[0] { + case "Joins:": + cmd = "JOIN" + params = []string{entity} + case "Parts:": + cmd = "PART" + if reason != "" { + params = []string{entity, reason} + } else { + params = []string{entity} + } + case "Quits:": + cmd = "QUIT" + if reason != "" { + params = []string{reason} + } + } + default: + nick := parts[0] + rem := parts[1] + if r := strings.TrimPrefix(rem, "is now known as "); r != rem { + cmd = "NICK" + prefix = &irc.Prefix{ + Name: nick, + } + params = []string{r} + } else if r := strings.TrimPrefix(rem, "was kicked by "); r != rem { + args := strings.SplitN(r, " ", 2) + if len(args) != 2 { + return nil, time.Time{}, nil + } + cmd = "KICK" + prefix = &irc.Prefix{ + Name: args[0], + } + reason := strings.TrimSuffix(strings.TrimPrefix(args[1], "("), ")") + params = []string{entity, nick} + if reason != "" { + params = append(params, reason) + } + } else if r := strings.TrimPrefix(rem, "changes topic to "); r != rem { + cmd = "TOPIC" + prefix = &irc.Prefix{ + Name: nick, + } + topic := strings.TrimSuffix(strings.TrimPrefix(r, "'"), "'") + params = []string{entity, topic} + } else if r := strings.TrimPrefix(rem, "sets mode: "); r != rem { + cmd = "MODE" + prefix = &irc.Prefix{ + Name: nick, + } + params = append([]string{entity}, strings.Split(r, " ")...) + } else { + return nil, time.Time{}, nil + } + } + } else { + var sender, text string + if strings.HasPrefix(line, "<") { + cmd = "PRIVMSG" + parts := strings.SplitN(line[1:], "> ", 2) + if len(parts) != 2 { + return nil, time.Time{}, nil + } + sender, text = parts[0], parts[1] + } else if strings.HasPrefix(line, "-") { + cmd = "NOTICE" + parts := strings.SplitN(line[1:], "- ", 2) + if len(parts) != 2 { + return nil, time.Time{}, nil + } + sender, text = parts[0], parts[1] + } else if strings.HasPrefix(line, "* ") { + cmd = "PRIVMSG" + parts := strings.SplitN(line[2:], " ", 2) + if len(parts) != 2 { + return nil, time.Time{}, nil + } + sender, text = parts[0], "\x01ACTION "+parts[1]+"\x01" + } else { + return nil, time.Time{}, nil + } + + prefix = &irc.Prefix{Name: sender} + if entity == sender { + // This is a direct message from a user to us. We don't store own + // our nickname in the logs, so grab it from the network settings. + // Not very accurate since this may not match our nick at the time + // the message was received, but we can't do a lot better. + entity = GetNick(ms.user, network) + } + params = []string{entity, text} + } + + year, month, day := ref.Date() + t := time.Date(year, month, day, hour, minute, second, 0, time.Local) + + msg := &irc.Message{ + Tags: map[string]irc.TagValue{ + "time": irc.TagValue(formatServerTime(t)), + }, + Prefix: prefix, + Command: cmd, + Params: params, + } + return msg, t, nil +} + +func (ms *fsMessageStore) parseMessagesBefore(network *Network, entity string, ref time.Time, end time.Time, events bool, limit int, afterOffset int64) ([]*irc.Message, error) { + path := ms.logPath(network, entity, ref) + f, err := os.Open(path) + if err != nil { + if os.IsNotExist(err) { + return nil, nil + } + return nil, fmt.Errorf("failed to parse messages before ref: %v", err) + } + defer f.Close() + + historyRing := make([]*irc.Message, limit) + cur := 0 + + sc := bufio.NewScanner(f) + + if afterOffset >= 0 { + if _, err := f.Seek(afterOffset, io.SeekStart); err != nil { + return nil, nil + } + sc.Scan() // skip till next newline + } + + for sc.Scan() { + msg, t, err := ms.parseMessage(sc.Text(), network, entity, ref, events) + if err != nil { + return nil, err + } else if msg == nil || !t.After(end) { + continue + } else if !t.Before(ref) { + break + } + + historyRing[cur%limit] = msg + cur++ + } + if sc.Err() != nil { + return nil, fmt.Errorf("failed to parse messages before ref: scanner error: %v", sc.Err()) + } + + n := limit + if cur < limit { + n = cur + } + start := (cur - n + limit) % limit + + if start+n <= limit { // ring doesnt wrap + return historyRing[start : start+n], nil + } else { // ring wraps + history := make([]*irc.Message, n) + r := copy(history, historyRing[start:]) + copy(history[r:], historyRing[:n-r]) + return history, nil + } +} + +func (ms *fsMessageStore) parseMessagesAfter(network *Network, entity string, ref time.Time, end time.Time, events bool, limit int) ([]*irc.Message, error) { + path := ms.logPath(network, entity, ref) + f, err := os.Open(path) + if err != nil { + if os.IsNotExist(err) { + return nil, nil + } + return nil, fmt.Errorf("failed to parse messages after ref: %v", err) + } + defer f.Close() + + var history []*irc.Message + sc := bufio.NewScanner(f) + for sc.Scan() && len(history) < limit { + msg, t, err := ms.parseMessage(sc.Text(), network, entity, ref, events) + if err != nil { + return nil, err + } else if msg == nil || !t.After(ref) { + continue + } else if !t.Before(end) { + break + } + + history = append(history, msg) + } + if sc.Err() != nil { + return nil, fmt.Errorf("failed to parse messages after ref: scanner error: %v", sc.Err()) + } + + return history, nil +} + +func (ms *fsMessageStore) LoadBeforeTime(ctx context.Context, network *Network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) { + start = start.In(time.Local) + end = end.In(time.Local) + history := make([]*irc.Message, limit) + remaining := limit + tries := 0 + for remaining > 0 && tries < fsMessageStoreMaxTries && end.Before(start) { + buf, err := ms.parseMessagesBefore(network, entity, start, end, events, remaining, -1) + if err != nil { + return nil, err + } + if len(buf) == 0 { + tries++ + } else { + tries = 0 + } + copy(history[remaining-len(buf):], buf) + remaining -= len(buf) + year, month, day := start.Date() + start = time.Date(year, month, day, 0, 0, 0, 0, start.Location()).Add(-1) + + if err := ctx.Err(); err != nil { + return nil, err + } + } + + return history[remaining:], nil +} + +func (ms *fsMessageStore) LoadAfterTime(ctx context.Context, network *Network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) { + start = start.In(time.Local) + end = end.In(time.Local) + var history []*irc.Message + remaining := limit + tries := 0 + for remaining > 0 && tries < fsMessageStoreMaxTries && start.Before(end) { + buf, err := ms.parseMessagesAfter(network, entity, start, end, events, remaining) + if err != nil { + return nil, err + } + if len(buf) == 0 { + tries++ + } else { + tries = 0 + } + history = append(history, buf...) + remaining -= len(buf) + year, month, day := start.Date() + start = time.Date(year, month, day+1, 0, 0, 0, 0, start.Location()) + + if err := ctx.Err(); err != nil { + return nil, err + } + } + return history, nil +} + +func (ms *fsMessageStore) LoadLatestID(ctx context.Context, network *Network, entity, id string, limit int) ([]*irc.Message, error) { + var afterTime time.Time + var afterOffset int64 + if id != "" { + var idNet int64 + var idEntity string + var err error + idNet, idEntity, afterTime, afterOffset, err = parseFSMsgID(id) + if err != nil { + return nil, err + } + if idNet != network.ID || idEntity != entity { + return nil, fmt.Errorf("cannot find message ID: message ID doesn't match network/entity") + } + } + + history := make([]*irc.Message, limit) + t := time.Now() + remaining := limit + tries := 0 + for remaining > 0 && tries < fsMessageStoreMaxTries && !truncateDay(t).Before(afterTime) { + var offset int64 = -1 + if afterOffset >= 0 && truncateDay(t).Equal(afterTime) { + offset = afterOffset + } + + buf, err := ms.parseMessagesBefore(network, entity, t, time.Time{}, false, remaining, offset) + if err != nil { + return nil, err + } + if len(buf) == 0 { + tries++ + } else { + tries = 0 + } + copy(history[remaining-len(buf):], buf) + remaining -= len(buf) + year, month, day := t.Date() + t = time.Date(year, month, day, 0, 0, 0, 0, t.Location()).Add(-1) + + if err := ctx.Err(); err != nil { + return nil, err + } + } + + return history[remaining:], nil +} + +func (ms *fsMessageStore) ListTargets(ctx context.Context, network *Network, start, end time.Time, limit int, events bool) ([]chatHistoryTarget, error) { + start = start.In(time.Local) + end = end.In(time.Local) + rootPath := filepath.Join(ms.root, escapeFilename(network.GetName())) + root, err := os.Open(rootPath) + if os.IsNotExist(err) { + return nil, nil + } else if err != nil { + return nil, err + } + + // The returned targets are escaped, and there is no way to un-escape + // TODO: switch to ReadDir (Go 1.16+) + targetNames, err := root.Readdirnames(0) + root.Close() + if err != nil { + return nil, err + } + + var targets []chatHistoryTarget + for _, target := range targetNames { + // target is already escaped here + targetPath := filepath.Join(rootPath, target) + targetDir, err := os.Open(targetPath) + if err != nil { + return nil, err + } + + entries, err := targetDir.Readdir(0) + targetDir.Close() + if err != nil { + return nil, err + } + + // We use mtime here, which may give imprecise or incorrect results + var t time.Time + for _, entry := range entries { + if entry.ModTime().After(t) { + t = entry.ModTime() + } + } + + // The timestamps we get from logs have second granularity + t = truncateSecond(t) + + // Filter out targets that don't fullfil the time bounds + if !isTimeBetween(t, start, end) { + continue + } + + targets = append(targets, chatHistoryTarget{ + Name: target, + LatestMessage: t, + }) + + if err := ctx.Err(); err != nil { + return nil, err + } + } + + // Sort targets by latest message time, backwards or forwards depending on + // the order of the time bounds + sort.Slice(targets, func(i, j int) bool { + t1, t2 := targets[i].LatestMessage, targets[j].LatestMessage + if start.Before(end) { + return t1.Before(t2) + } else { + return !t1.Before(t2) + } + }) + + // Truncate the result if necessary + if len(targets) > limit { + targets = targets[:limit] + } + + return targets, nil +} + +func (ms *fsMessageStore) RenameNetwork(oldNet, newNet *Network) error { + oldDir := filepath.Join(ms.root, escapeFilename(oldNet.GetName())) + newDir := filepath.Join(ms.root, escapeFilename(newNet.GetName())) + // Avoid loosing data by overwriting an existing directory + if _, err := os.Stat(newDir); err == nil { + return fmt.Errorf("destination %q already exists", newDir) + } + return os.Rename(oldDir, newDir) +} + +func truncateDay(t time.Time) time.Time { + year, month, day := t.Date() + return time.Date(year, month, day, 0, 0, 0, 0, t.Location()) +} + +func truncateSecond(t time.Time) time.Time { + year, month, day := t.Date() + return time.Date(year, month, day, t.Hour(), t.Minute(), t.Second(), 0, t.Location()) +} + +func isTimeBetween(t, start, end time.Time) bool { + if end.Before(start) { + end, start = start, end + } + return start.Before(t) && t.Before(end) +} diff --git a/branches/origin-master/msgstore_memory.go b/branches/origin-master/msgstore_memory.go new file mode 100644 index 0000000..25ab953 --- /dev/null +++ b/branches/origin-master/msgstore_memory.go @@ -0,0 +1,160 @@ +package suika + +import ( + "context" + "fmt" + "time" + + "git.sr.ht/~sircmpwn/go-bare" + "gopkg.in/irc.v3" +) + +const messageRingBufferCap = 4096 + +type memoryMsgID struct { + Seq bare.Uint +} + +func (memoryMsgID) msgIDType() msgIDType { + return msgIDMemory +} + +func parseMemoryMsgID(s string) (netID int64, entity string, seq uint64, err error) { + var id memoryMsgID + netID, entity, err = parseMsgID(s, &id) + if err != nil { + return 0, "", 0, err + } + return netID, entity, uint64(id.Seq), nil +} + +func formatMemoryMsgID(netID int64, entity string, seq uint64) string { + id := memoryMsgID{bare.Uint(seq)} + return formatMsgID(netID, entity, &id) +} + +type ringBufferKey struct { + networkID int64 + entity string +} + +type memoryMessageStore struct { + buffers map[ringBufferKey]*messageRingBuffer +} + +var _ messageStore = (*memoryMessageStore)(nil) + +func newMemoryMessageStore() *memoryMessageStore { + return &memoryMessageStore{ + buffers: make(map[ringBufferKey]*messageRingBuffer), + } +} + +func (ms *memoryMessageStore) Close() error { + ms.buffers = nil + return nil +} + +func (ms *memoryMessageStore) get(network *Network, entity string) *messageRingBuffer { + k := ringBufferKey{networkID: network.ID, entity: entity} + if rb, ok := ms.buffers[k]; ok { + return rb + } + rb := newMessageRingBuffer(messageRingBufferCap) + ms.buffers[k] = rb + return rb +} + +func (ms *memoryMessageStore) LastMsgID(network *Network, entity string, t time.Time) (string, error) { + var seq uint64 + k := ringBufferKey{networkID: network.ID, entity: entity} + if rb, ok := ms.buffers[k]; ok { + seq = rb.cur + } + return formatMemoryMsgID(network.ID, entity, seq), nil +} + +func (ms *memoryMessageStore) Append(network *Network, entity string, msg *irc.Message) (string, error) { + switch msg.Command { + case "PRIVMSG", "NOTICE": + // Only append these messages, because LoadLatestID shouldn't return + // other kinds of message. + default: + return "", nil + } + + k := ringBufferKey{networkID: network.ID, entity: entity} + rb, ok := ms.buffers[k] + if !ok { + rb = newMessageRingBuffer(messageRingBufferCap) + ms.buffers[k] = rb + } + + seq := rb.Append(msg) + return formatMemoryMsgID(network.ID, entity, seq), nil +} + +func (ms *memoryMessageStore) LoadLatestID(ctx context.Context, network *Network, entity, id string, limit int) ([]*irc.Message, error) { + _, _, seq, err := parseMemoryMsgID(id) + if err != nil { + return nil, err + } + + k := ringBufferKey{networkID: network.ID, entity: entity} + rb, ok := ms.buffers[k] + if !ok { + return nil, nil + } + + return rb.LoadLatestSeq(seq, limit) +} + +type messageRingBuffer struct { + buf []*irc.Message + cur uint64 +} + +func newMessageRingBuffer(capacity int) *messageRingBuffer { + return &messageRingBuffer{ + buf: make([]*irc.Message, capacity), + cur: 1, + } +} + +func (rb *messageRingBuffer) cap() uint64 { + return uint64(len(rb.buf)) +} + +func (rb *messageRingBuffer) Append(msg *irc.Message) uint64 { + seq := rb.cur + i := int(seq % rb.cap()) + rb.buf[i] = msg + rb.cur++ + return seq +} + +func (rb *messageRingBuffer) LoadLatestSeq(seq uint64, limit int) ([]*irc.Message, error) { + if seq > rb.cur { + return nil, fmt.Errorf("loading messages from sequence number (%v) greater than current (%v)", seq, rb.cur) + } else if seq == rb.cur { + return nil, nil + } + + // The query excludes the message with the sequence number seq + diff := rb.cur - seq - 1 + if diff > rb.cap() { + // We dropped diff - cap entries + diff = rb.cap() + } + if int(diff) > limit { + diff = uint64(limit) + } + + l := make([]*irc.Message, int(diff)) + for i := 0; i < int(diff); i++ { + j := int((rb.cur - diff + uint64(i)) % rb.cap()) + l[i] = rb.buf[j] + } + + return l, nil +} diff --git a/branches/origin-master/net_go113.go b/branches/origin-master/net_go113.go new file mode 100644 index 0000000..24bde7f --- /dev/null +++ b/branches/origin-master/net_go113.go @@ -0,0 +1,12 @@ +//go:build !go1.16 +// +build !go1.16 + +package suika + +import ( + "strings" +) + +func isErrClosed(err error) bool { + return err != nil && strings.Contains(err.Error(), "use of closed network connection") +} diff --git a/branches/origin-master/net_go116.go b/branches/origin-master/net_go116.go new file mode 100644 index 0000000..ef1256a --- /dev/null +++ b/branches/origin-master/net_go116.go @@ -0,0 +1,13 @@ +//go:build go1.16 +// +build go1.16 + +package suika + +import ( + "errors" + "net" +) + +func isErrClosed(err error) bool { + return errors.Is(err, net.ErrClosed) +} diff --git a/branches/origin-master/rate.go b/branches/origin-master/rate.go new file mode 100644 index 0000000..a14bce2 --- /dev/null +++ b/branches/origin-master/rate.go @@ -0,0 +1,40 @@ +package suika + +import ( + "math/rand" + "time" +) + +// backoffer implements a simple exponential backoff. +type backoffer struct { + min, max, jitter time.Duration + n int64 +} + +func newBackoffer(min, max, jitter time.Duration) *backoffer { + return &backoffer{min: min, max: max, jitter: jitter} +} + +func (b *backoffer) Reset() { + b.n = 0 +} + +func (b *backoffer) Next() time.Duration { + if b.n == 0 { + b.n = 1 + return 0 + } + + d := time.Duration(b.n) * b.min + if d > b.max { + d = b.max + } else { + b.n *= 2 + } + + if b.jitter != 0 { + d += time.Duration(rand.Int63n(int64(b.jitter))) + } + + return d +} diff --git a/branches/origin-master/rc.d/freebsd-rc.d b/branches/origin-master/rc.d/freebsd-rc.d new file mode 100755 index 0000000..0fed078 --- /dev/null +++ b/branches/origin-master/rc.d/freebsd-rc.d @@ -0,0 +1,29 @@ +#!/bin/sh +# $TheSupernovaDuo$ +# vim: ft=sh + +# PROVIDE: suika +# REQUIRE: DAEMON +# BEFORE: LOGIN +# KEYWORD: shutdown + +. /etc/rc.subr + +name="suika" +desc="A drunk IRC bouncer" +rcvar="suika_enable" + +: ${suika_user="ircd"} + +command="%%PREFIX%%/bin/suika" +pidfile="/var/run/suika.pid" +required_files="%%PREFIX%%/etc/suika/config" + +start_cmd="suika_start" + +suika_start() { + /usr/sbin/daemon -f -p ${pidfile} -u ${suika_user} -l daemon ${command} --config ${required_files} +} + +load_rc_config "$name" +run_rc_command "$1" diff --git a/branches/origin-master/rc.d/immortal.yml b/branches/origin-master/rc.d/immortal.yml new file mode 100644 index 0000000..c7a8090 --- /dev/null +++ b/branches/origin-master/rc.d/immortal.yml @@ -0,0 +1,3 @@ +# $TheSupernovaDuo$ +cmd: %%PREFIX%%/bin/suika --config %%PREFIX%%/etc/suika/config +user: ircd diff --git a/branches/origin-master/rc.d/netbsd-rc.d b/branches/origin-master/rc.d/netbsd-rc.d new file mode 100644 index 0000000..3fef470 --- /dev/null +++ b/branches/origin-master/rc.d/netbsd-rc.d @@ -0,0 +1,28 @@ +#!/bin/sh +# $TheSupernovaDuo$ +# vim: ft=sh + +# PROVIDE: suika +# REQUIRE: DAEMON +# BEFORE: LOGIN +# KEYWORD: shutdown + +. /etc/rc.subr + +name="suika" +rcvar="${name}" +command="%%PREFIX/bin/${name}" +command_args="--config %%PREFIX%%/etc/suika/config" +pidfile="/var/run/${name}.pid" +start_cmd="${name}_start" + +suika_start() { + printf "Starting %s..." "${name}" + ${command} ${command_args} + pgrep -n ${name} > ${pidfile} +} + +load_rc_config ${name} +run_rc_command "$1" + + diff --git a/branches/origin-master/rc.d/openbsd-rc.d b/branches/origin-master/rc.d/openbsd-rc.d new file mode 100644 index 0000000..f0bf77c --- /dev/null +++ b/branches/origin-master/rc.d/openbsd-rc.d @@ -0,0 +1,12 @@ +#!/bin/ksh +# $TheSupernovaDuo$ +# vim: ft=sh + +daemon="%%PREFIX%%/bin/suika" +daemon_args="--config %%PREFIX%%/etc/suika/config" + +. /etc/rc.d/rc.subr + +rc_bg=YES + +rc_cmd "$1" diff --git a/branches/origin-master/rc.d/suika.service b/branches/origin-master/rc.d/suika.service new file mode 100644 index 0000000..d8d53b3 --- /dev/null +++ b/branches/origin-master/rc.d/suika.service @@ -0,0 +1,16 @@ +# $TheSupernovaDuo$ +# vim: ft=confini +[Unit] +Description=A drunk IRC bouncer +After=network.target +Wants=network.target +StartLimitBurst=5 +StartLimitIntervalSec=1 +[Service] +Type=simple +Restart=on-abnormal +RestartSec=1 +User=suika +ExecStart=%%PREFIX%%/bin/suika --config %%PREFIX%%/etc/suika/config +[Install] +WantedBy=multi-user.target diff --git a/branches/origin-master/server.go b/branches/origin-master/server.go new file mode 100644 index 0000000..05c55fd --- /dev/null +++ b/branches/origin-master/server.go @@ -0,0 +1,330 @@ +package suika + +import ( + "context" + "errors" + "fmt" + "io" + "log" + "net" + "runtime/debug" + "sync" + "sync/atomic" + "time" + + "gopkg.in/irc.v3" +) + +// TODO: make configurable +var ( + retryConnectMinDelay = time.Minute + retryConnectMaxDelay = 10 * time.Minute + retryConnectJitter = time.Minute + connectTimeout = 15 * time.Second + writeTimeout = 10 * time.Second + upstreamMessageDelay = 2 * time.Second + upstreamMessageBurst = 10 + backlogTimeout = 10 * time.Second + handleDownstreamMessageTimeout = 10 * time.Second + downstreamRegisterTimeout = 30 * time.Second + chatHistoryLimit = 1000 + backlogLimit = 4000 +) + +type Logger interface { + Printf(format string, v ...interface{}) + Debugf(format string, v ...interface{}) +} + +type logger struct { + *log.Logger + debug bool +} + +func (l logger) Debugf(format string, v ...interface{}) { + if !l.debug { + return + } + l.Logger.Printf(format, v...) +} + +func NewLogger(out io.Writer, debug bool) Logger { + return logger{ + Logger: log.New(log.Writer(), "", log.LstdFlags), + debug: debug, + } +} + +type prefixLogger struct { + logger Logger + prefix string +} + +var _ Logger = (*prefixLogger)(nil) + +func (l *prefixLogger) Printf(format string, v ...interface{}) { + v = append([]interface{}{l.prefix}, v...) + l.logger.Printf("%v"+format, v...) +} + +func (l *prefixLogger) Debugf(format string, v ...interface{}) { + v = append([]interface{}{l.prefix}, v...) + l.logger.Debugf("%v"+format, v...) +} + +type int64Gauge struct { + v int64 // atomic +} + +func (g *int64Gauge) Add(delta int64) { + atomic.AddInt64(&g.v, delta) +} + +func (g *int64Gauge) Value() int64 { + return atomic.LoadInt64(&g.v) +} + +func (g *int64Gauge) Float64() float64 { + return float64(g.Value()) +} + +type retryListener struct { + net.Listener + Logger Logger + + delay time.Duration +} + +func (ln *retryListener) Accept() (net.Conn, error) { + for { + conn, err := ln.Listener.Accept() + if ne, ok := err.(net.Error); ok && ne.Temporary() { + if ln.delay == 0 { + ln.delay = 5 * time.Millisecond + } else { + ln.delay *= 2 + } + if max := 1 * time.Second; ln.delay > max { + ln.delay = max + } + if ln.Logger != nil { + ln.Logger.Printf("accept error (retrying in %v): %v", ln.delay, err) + } + time.Sleep(ln.delay) + } else { + ln.delay = 0 + return conn, err + } + } +} + +type Config struct { + Hostname string + Title string + LogPath string + MaxUserNetworks int + MultiUpstream bool + MOTD string + UpstreamUserIPs []*net.IPNet +} + +type Server struct { + Logger Logger + + config atomic.Value // *Config + db Database + stopWG sync.WaitGroup + + lock sync.Mutex + listeners map[net.Listener]struct{} + users map[string]*user +} + +func NewServer(db Database) *Server { + srv := &Server{ + Logger: NewLogger(log.Writer(), true), + db: db, + listeners: make(map[net.Listener]struct{}), + users: make(map[string]*user), + } + srv.config.Store(&Config{ + Hostname: "localhost", + MaxUserNetworks: -1, + MultiUpstream: true, + }) + return srv +} + +func (s *Server) prefix() *irc.Prefix { + return &irc.Prefix{Name: s.Config().Hostname} +} + +func (s *Server) Config() *Config { + return s.config.Load().(*Config) +} + +func (s *Server) SetConfig(cfg *Config) { + s.config.Store(cfg) +} + +func (s *Server) Start() error { + users, err := s.db.ListUsers(context.TODO()) + if err != nil { + return err + } + + s.lock.Lock() + for i := range users { + s.addUserLocked(&users[i]) + } + s.lock.Unlock() + + return nil +} + +func (s *Server) Shutdown() { + s.lock.Lock() + for ln := range s.listeners { + if err := ln.Close(); err != nil { + s.Logger.Printf("failed to stop listener: %v", err) + } + } + for _, u := range s.users { + u.events <- eventStop{} + } + s.lock.Unlock() + + s.stopWG.Wait() + + if err := s.db.Close(); err != nil { + s.Logger.Printf("failed to close DB: %v", err) + } +} + +func (s *Server) createUser(ctx context.Context, user *User) (*user, error) { + s.lock.Lock() + defer s.lock.Unlock() + + if _, ok := s.users[user.Username]; ok { + return nil, fmt.Errorf("user %q already exists", user.Username) + } + + err := s.db.StoreUser(ctx, user) + if err != nil { + return nil, fmt.Errorf("could not create user in db: %v", err) + } + + return s.addUserLocked(user), nil +} + +func (s *Server) forEachUser(f func(*user)) { + s.lock.Lock() + for _, u := range s.users { + f(u) + } + s.lock.Unlock() +} + +func (s *Server) getUser(name string) *user { + s.lock.Lock() + u := s.users[name] + s.lock.Unlock() + return u +} + +func (s *Server) addUserLocked(user *User) *user { + s.Logger.Printf("starting bouncer for user %q", user.Username) + u := newUser(s, user) + s.users[u.Username] = u + + s.stopWG.Add(1) + + go func() { + defer func() { + if err := recover(); err != nil { + s.Logger.Printf("panic serving user %q: %v\n%v", user.Username, err, debug.Stack()) + } + + s.lock.Lock() + delete(s.users, u.Username) + s.lock.Unlock() + + s.stopWG.Done() + }() + + u.run() + }() + + return u +} + +var lastDownstreamID uint64 = 0 + +func (s *Server) handle(ic ircConn) { + defer func() { + if err := recover(); err != nil { + s.Logger.Printf("panic serving downstream %q: %v\n%v", ic.RemoteAddr(), err, debug.Stack()) + } + }() + + id := atomic.AddUint64(&lastDownstreamID, 1) + dc := newDownstreamConn(s, ic, id) + if err := dc.runUntilRegistered(); err != nil { + if !errors.Is(err, io.EOF) { + dc.logger.Printf("%v", err) + } + } else { + dc.user.events <- eventDownstreamConnected{dc} + if err := dc.readMessages(dc.user.events); err != nil { + dc.logger.Printf("%v", err) + } + dc.user.events <- eventDownstreamDisconnected{dc} + } + dc.Close() +} + +func (s *Server) Serve(ln net.Listener) error { + ln = &retryListener{ + Listener: ln, + Logger: &prefixLogger{logger: s.Logger, prefix: fmt.Sprintf("listener %v: ", ln.Addr())}, + } + + s.lock.Lock() + s.listeners[ln] = struct{}{} + s.lock.Unlock() + + s.stopWG.Add(1) + + defer func() { + s.lock.Lock() + delete(s.listeners, ln) + s.lock.Unlock() + + s.stopWG.Done() + }() + + for { + conn, err := ln.Accept() + if isErrClosed(err) { + return nil + } else if err != nil { + return fmt.Errorf("failed to accept connection: %v", err) + } + + go s.handle(newNetIRCConn(conn)) + } +} + +type ServerStats struct { + Users int + Downstreams int64 + Upstreams int64 +} + +func (s *Server) Stats() *ServerStats { + var stats ServerStats + s.lock.Lock() + stats.Users = len(s.users) + s.lock.Unlock() + return &stats +} diff --git a/branches/origin-master/server_test.go b/branches/origin-master/server_test.go new file mode 100644 index 0000000..cf0adca --- /dev/null +++ b/branches/origin-master/server_test.go @@ -0,0 +1,207 @@ +package suika + +import ( + "context" + "net" + "testing" + + "golang.org/x/crypto/bcrypt" + "gopkg.in/irc.v3" +) + +var testServerPrefix = &irc.Prefix{Name: "suika-test-server"} + +const ( + testUsername = "suika-test-user" + testPassword = testUsername +) + +func createTempSqliteDB(t *testing.T) Database { + db, err := OpenDB("sqlite3", ":memory:") + if err != nil { + t.Fatalf("failed to create temporary SQLite database: %v", err) + } + // :memory: will open a separate database for each new connection. Make + // sure the sql package only uses a single connection. An alternative + // solution is to use "file::memory:?cache=shared". + db.(*SqliteDB).db.SetMaxOpenConns(1) + return db +} + +func createTempPostgresDB(t *testing.T) Database { + db := &PostgresDB{db: openTempPostgresDB(t)} + if err := db.upgrade(); err != nil { + t.Fatalf("failed to upgrade PostgreSQL database: %v", err) + } + + return db +} + +func createTestUser(t *testing.T, db Database) *User { + hashed, err := bcrypt.GenerateFromPassword([]byte(testPassword), bcrypt.DefaultCost) + if err != nil { + t.Fatalf("failed to generate bcrypt hash: %v", err) + } + + record := &User{Username: testUsername, Password: string(hashed)} + if err := db.StoreUser(context.Background(), record); err != nil { + t.Fatalf("failed to store test user: %v", err) + } + + return record +} + +func createTestDownstream(t *testing.T, srv *Server) ircConn { + c1, c2 := net.Pipe() + go srv.handle(newNetIRCConn(c1)) + return newNetIRCConn(c2) +} + +func createTestUpstream(t *testing.T, db Database, user *User) (*Network, net.Listener) { + ln, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("failed to create TCP listener: %v", err) + } + + network := &Network{ + Name: "testnet", + Addr: "irc://" + ln.Addr().String(), + Nick: user.Username, + Enabled: true, + } + if err := db.StoreNetwork(context.Background(), user.ID, network); err != nil { + t.Fatalf("failed to store test network: %v", err) + } + + return network, ln +} + +func mustAccept(t *testing.T, ln net.Listener) ircConn { + c, err := ln.Accept() + if err != nil { + t.Fatalf("failed accepting connection: %v", err) + } + return newNetIRCConn(c) +} + +func expectMessage(t *testing.T, c ircConn, cmd string) *irc.Message { + msg, err := c.ReadMessage() + if err != nil { + t.Fatalf("failed to read IRC message (want %q): %v", cmd, err) + } + if msg.Command != cmd { + t.Fatalf("invalid message received: want %q, got: %v", cmd, msg) + } + return msg +} + +func registerDownstreamConn(t *testing.T, c ircConn, network *Network) { + c.WriteMessage(&irc.Message{ + Command: "PASS", + Params: []string{testPassword}, + }) + c.WriteMessage(&irc.Message{ + Command: "NICK", + Params: []string{testUsername}, + }) + c.WriteMessage(&irc.Message{ + Command: "USER", + Params: []string{testUsername + "/" + network.Name, "0", "*", testUsername}, + }) + + expectMessage(t, c, irc.RPL_WELCOME) +} + +func registerUpstreamConn(t *testing.T, c ircConn) { + msg := expectMessage(t, c, "CAP") + if msg.Params[0] != "LS" { + t.Fatalf("invalid CAP LS: got: %v", msg) + } + msg = expectMessage(t, c, "NICK") + nick := msg.Params[0] + if nick != testUsername { + t.Fatalf("invalid NICK: want %q, got: %v", testUsername, msg) + } + expectMessage(t, c, "USER") + + c.WriteMessage(&irc.Message{ + Prefix: testServerPrefix, + Command: irc.RPL_WELCOME, + Params: []string{nick, "Welcome!"}, + }) + c.WriteMessage(&irc.Message{ + Prefix: testServerPrefix, + Command: irc.RPL_YOURHOST, + Params: []string{nick, "Your host is suika-test-server"}, + }) + c.WriteMessage(&irc.Message{ + Prefix: testServerPrefix, + Command: irc.RPL_CREATED, + Params: []string{nick, "Who cares when the server was created?"}, + }) + c.WriteMessage(&irc.Message{ + Prefix: testServerPrefix, + Command: irc.RPL_MYINFO, + Params: []string{nick, testServerPrefix.Name, "suika", "aiwroO", "OovaimnqpsrtklbeI"}, + }) + c.WriteMessage(&irc.Message{ + Prefix: testServerPrefix, + Command: irc.ERR_NOMOTD, + Params: []string{nick, "No MOTD"}, + }) +} + +func testServer(t *testing.T, db Database) { + user := createTestUser(t, db) + network, upstream := createTestUpstream(t, db, user) + defer upstream.Close() + + srv := NewServer(db) + if err := srv.Start(); err != nil { + t.Fatalf("failed to start server: %v", err) + } + defer srv.Shutdown() + + uc := mustAccept(t, upstream) + defer uc.Close() + registerUpstreamConn(t, uc) + + dc := createTestDownstream(t, srv) + defer dc.Close() + registerDownstreamConn(t, dc, network) + + noticeText := "This is a very important server notice." + uc.WriteMessage(&irc.Message{ + Prefix: testServerPrefix, + Command: "NOTICE", + Params: []string{testUsername, noticeText}, + }) + + var msg *irc.Message + for { + var err error + msg, err = dc.ReadMessage() + if err != nil { + t.Fatalf("failed to read IRC message: %v", err) + } + if msg.Command == "NOTICE" { + break + } + } + + if msg.Params[1] != noticeText { + t.Fatalf("invalid NOTICE text: want %q, got: %v", noticeText, msg) + } +} + +func TestServer(t *testing.T) { + t.Run("sqlite", func(t *testing.T) { + db := createTempSqliteDB(t) + testServer(t, db) + }) + + t.Run("postgres", func(t *testing.T) { + db := createTempPostgresDB(t) + testServer(t, db) + }) +} diff --git a/branches/origin-master/service.go b/branches/origin-master/service.go new file mode 100644 index 0000000..07226ef --- /dev/null +++ b/branches/origin-master/service.go @@ -0,0 +1,1150 @@ +package suika + +import ( + "context" + "crypto/sha1" + "crypto/sha256" + "crypto/sha512" + "encoding/hex" + "flag" + "fmt" + "io/ioutil" + "sort" + "strconv" + "strings" + "time" + "unicode" + + "golang.org/x/crypto/bcrypt" + "gopkg.in/irc.v3" +) + +const ( + serviceNick = "BouncerServ" + serviceNickCM = "bouncerserv" + serviceRealname = "suika bouncer service" +) + +// maxRSABits is the maximum number of RSA key bits used when generating a new +// private key. +const maxRSABits = 8192 + +var servicePrefix = &irc.Prefix{ + Name: serviceNick, + User: serviceNick, + Host: serviceNick, +} + +type serviceCommandSet map[string]*serviceCommand + +type serviceCommand struct { + usage string + desc string + handle func(ctx context.Context, dc *downstreamConn, params []string) error + children serviceCommandSet + admin bool +} + +func sendServiceNOTICE(dc *downstreamConn, text string) { + dc.SendMessage(&irc.Message{ + Prefix: servicePrefix, + Command: "NOTICE", + Params: []string{dc.nick, text}, + }) +} + +func sendServicePRIVMSG(dc *downstreamConn, text string) { + dc.SendMessage(&irc.Message{ + Prefix: servicePrefix, + Command: "PRIVMSG", + Params: []string{dc.nick, text}, + }) +} + +func splitWords(s string) ([]string, error) { + var words []string + var lastWord strings.Builder + escape := false + prev := ' ' + wordDelim := ' ' + + for _, r := range s { + if escape { + // last char was a backslash, write the byte as-is. + lastWord.WriteRune(r) + escape = false + } else if r == '\\' { + escape = true + } else if wordDelim == ' ' && unicode.IsSpace(r) { + // end of last word + if !unicode.IsSpace(prev) { + words = append(words, lastWord.String()) + lastWord.Reset() + } + } else if r == wordDelim { + // wordDelim is either " or ', switch back to + // space-delimited words. + wordDelim = ' ' + } else if r == '"' || r == '\'' { + if wordDelim == ' ' { + // start of (double-)quoted word + wordDelim = r + } else { + // either wordDelim is " and r is ' or vice-versa + lastWord.WriteRune(r) + } + } else { + lastWord.WriteRune(r) + } + + prev = r + } + + if !unicode.IsSpace(prev) { + words = append(words, lastWord.String()) + } + + if wordDelim != ' ' { + return nil, fmt.Errorf("unterminated quoted string") + } + if escape { + return nil, fmt.Errorf("unterminated backslash sequence") + } + + return words, nil +} + +func handleServicePRIVMSG(ctx context.Context, dc *downstreamConn, text string) { + words, err := splitWords(text) + if err != nil { + sendServicePRIVMSG(dc, fmt.Sprintf(`error: failed to parse command: %v`, err)) + return + } + + cmd, params, err := serviceCommands.Get(words) + if err != nil { + sendServicePRIVMSG(dc, fmt.Sprintf(`error: %v (type "help" for a list of commands)`, err)) + return + } + if cmd.admin && !dc.user.Admin { + sendServicePRIVMSG(dc, "error: you must be an admin to use this command") + return + } + + if cmd.handle == nil { + if len(cmd.children) > 0 { + var l []string + appendServiceCommandSetHelp(cmd.children, words, dc.user.Admin, &l) + sendServicePRIVMSG(dc, "available commands: "+strings.Join(l, ", ")) + } else { + // Pretend the command does not exist if it has neither children nor handler. + // This is obviously a bug but it is better to not die anyway. + dc.logger.Printf("command without handler and subcommands invoked:", words[0]) + sendServicePRIVMSG(dc, fmt.Sprintf("command %q not found", words[0])) + } + return + } + + if err := cmd.handle(ctx, dc, params); err != nil { + sendServicePRIVMSG(dc, fmt.Sprintf("error: %v", err)) + } +} + +func (cmds serviceCommandSet) Get(params []string) (*serviceCommand, []string, error) { + if len(params) == 0 { + return nil, nil, fmt.Errorf("no command specified") + } + + name := params[0] + params = params[1:] + + cmd, ok := cmds[name] + if !ok { + for k := range cmds { + if !strings.HasPrefix(k, name) { + continue + } + if cmd != nil { + return nil, params, fmt.Errorf("command %q is ambiguous", name) + } + cmd = cmds[k] + } + } + if cmd == nil { + return nil, params, fmt.Errorf("command %q not found", name) + } + + if len(params) == 0 || len(cmd.children) == 0 { + return cmd, params, nil + } + return cmd.children.Get(params) +} + +func (cmds serviceCommandSet) Names() []string { + l := make([]string, 0, len(cmds)) + for name := range cmds { + l = append(l, name) + } + sort.Strings(l) + return l +} + +var serviceCommands serviceCommandSet + +func init() { + serviceCommands = serviceCommandSet{ + "help": { + usage: "[command]", + desc: "print help message", + handle: handleServiceHelp, + }, + "network": { + children: serviceCommandSet{ + "create": { + usage: "-addr [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick] [-enabled enabled] [-connect-command command]...", + desc: "add a new network", + handle: handleServiceNetworkCreate, + }, + "status": { + desc: "show a list of saved networks and their current status", + handle: handleServiceNetworkStatus, + }, + "update": { + usage: "[name] [-addr addr] [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick] [-enabled enabled] [-connect-command command]...", + desc: "update a network", + handle: handleServiceNetworkUpdate, + }, + "delete": { + usage: "[name]", + desc: "delete a network", + handle: handleServiceNetworkDelete, + }, + "quote": { + usage: "[name] ", + desc: "send a raw line to a network", + handle: handleServiceNetworkQuote, + }, + }, + }, + "certfp": { + children: serviceCommandSet{ + "generate": { + usage: "[-key-type rsa|ecdsa|ed25519] [-bits N] [-network name]", + desc: "generate a new self-signed certificate, defaults to using RSA-3072 key", + handle: handleServiceCertFPGenerate, + }, + "fingerprint": { + usage: "[-network name]", + desc: "show fingerprints of certificate", + handle: handleServiceCertFPFingerprints, + }, + }, + }, + "sasl": { + children: serviceCommandSet{ + "status": { + usage: "[-network name]", + desc: "show SASL status", + handle: handleServiceSASLStatus, + }, + "set-plain": { + usage: "[-network name] ", + desc: "set SASL PLAIN credentials", + handle: handleServiceSASLSetPlain, + }, + "reset": { + usage: "[-network name]", + desc: "disable SASL authentication and remove stored credentials", + handle: handleServiceSASLReset, + }, + }, + }, + "user": { + children: serviceCommandSet{ + "create": { + usage: "-username -password [-realname ] [-admin]", + desc: "create a new suika user", + handle: handleUserCreate, + admin: true, + }, + "update": { + usage: "[-password ] [-realname ]", + desc: "update the current user", + handle: handleUserUpdate, + }, + "delete": { + usage: "", + desc: "delete a user", + handle: handleUserDelete, + admin: true, + }, + }, + }, + "channel": { + children: serviceCommandSet{ + "status": { + usage: "[-network name]", + desc: "show a list of saved channels and their current status", + handle: handleServiceChannelStatus, + }, + "update": { + usage: " [-relay-detached ] [-reattach-on ] [-detach-after ] [-detach-on ]", + desc: "update a channel", + handle: handleServiceChannelUpdate, + }, + }, + }, + "server": { + children: serviceCommandSet{ + "status": { + desc: "show server statistics", + handle: handleServiceServerStatus, + admin: true, + }, + "notice": { + desc: "broadcast a notice to all connected bouncer users", + handle: handleServiceServerNotice, + admin: true, + }, + }, + admin: true, + }, + } +} + +func appendServiceCommandSetHelp(cmds serviceCommandSet, prefix []string, admin bool, l *[]string) { + for _, name := range cmds.Names() { + cmd := cmds[name] + if cmd.admin && !admin { + continue + } + words := append(prefix, name) + if len(cmd.children) == 0 { + s := strings.Join(words, " ") + *l = append(*l, s) + } else { + appendServiceCommandSetHelp(cmd.children, words, admin, l) + } + } +} + +func handleServiceHelp(ctx context.Context, dc *downstreamConn, params []string) error { + if len(params) > 0 { + cmd, rest, err := serviceCommands.Get(params) + if err != nil { + return err + } + words := params[:len(params)-len(rest)] + + if len(cmd.children) > 0 { + var l []string + appendServiceCommandSetHelp(cmd.children, words, dc.user.Admin, &l) + sendServicePRIVMSG(dc, "available commands: "+strings.Join(l, ", ")) + } else { + text := strings.Join(words, " ") + if cmd.usage != "" { + text += " " + cmd.usage + } + text += ": " + cmd.desc + + sendServicePRIVMSG(dc, text) + } + } else { + var l []string + appendServiceCommandSetHelp(serviceCommands, nil, dc.user.Admin, &l) + sendServicePRIVMSG(dc, "available commands: "+strings.Join(l, ", ")) + } + return nil +} + +func newFlagSet() *flag.FlagSet { + fs := flag.NewFlagSet("", flag.ContinueOnError) + fs.SetOutput(ioutil.Discard) + return fs +} + +type stringSliceFlag []string + +func (v *stringSliceFlag) String() string { + return fmt.Sprint([]string(*v)) +} + +func (v *stringSliceFlag) Set(s string) error { + *v = append(*v, s) + return nil +} + +// stringPtrFlag is a flag value populating a string pointer. This allows to +// disambiguate between a flag that hasn't been set and a flag that has been +// set to an empty string. +type stringPtrFlag struct { + ptr **string +} + +func (f stringPtrFlag) String() string { + if f.ptr == nil || *f.ptr == nil { + return "" + } + return **f.ptr +} + +func (f stringPtrFlag) Set(s string) error { + *f.ptr = &s + return nil +} + +type boolPtrFlag struct { + ptr **bool +} + +func (f boolPtrFlag) String() string { + if f.ptr == nil || *f.ptr == nil { + return "" + } + return strconv.FormatBool(**f.ptr) +} + +func (f boolPtrFlag) Set(s string) error { + v, err := strconv.ParseBool(s) + if err != nil { + return err + } + *f.ptr = &v + return nil +} + +func getNetworkFromArg(dc *downstreamConn, params []string) (*network, []string, error) { + name, params := popArg(params) + if name == "" { + if dc.network == nil { + return nil, params, fmt.Errorf("no network selected, a name argument is required") + } + return dc.network, params, nil + } else { + net := dc.user.getNetwork(name) + if net == nil { + return nil, params, fmt.Errorf("unknown network %q", name) + } + return net, params, nil + } +} + +type networkFlagSet struct { + *flag.FlagSet + Addr, Name, Nick, Username, Pass, Realname *string + Enabled *bool + ConnectCommands []string +} + +func newNetworkFlagSet() *networkFlagSet { + fs := &networkFlagSet{FlagSet: newFlagSet()} + fs.Var(stringPtrFlag{&fs.Addr}, "addr", "") + fs.Var(stringPtrFlag{&fs.Name}, "name", "") + fs.Var(stringPtrFlag{&fs.Nick}, "nick", "") + fs.Var(stringPtrFlag{&fs.Username}, "username", "") + fs.Var(stringPtrFlag{&fs.Pass}, "pass", "") + fs.Var(stringPtrFlag{&fs.Realname}, "realname", "") + fs.Var(boolPtrFlag{&fs.Enabled}, "enabled", "") + fs.Var((*stringSliceFlag)(&fs.ConnectCommands), "connect-command", "") + return fs +} + +func (fs *networkFlagSet) update(network *Network) error { + if fs.Addr != nil { + if addrParts := strings.SplitN(*fs.Addr, "://", 2); len(addrParts) == 2 { + scheme := addrParts[0] + switch scheme { + case "ircs", "irc", "unix": + default: + return fmt.Errorf("unknown scheme %q (supported schemes: ircs, irc, unix)", scheme) + } + } + network.Addr = *fs.Addr + } + if fs.Name != nil { + network.Name = *fs.Name + } + if fs.Nick != nil { + network.Nick = *fs.Nick + } + if fs.Username != nil { + network.Username = *fs.Username + } + if fs.Pass != nil { + network.Pass = *fs.Pass + } + if fs.Realname != nil { + network.Realname = *fs.Realname + } + if fs.Enabled != nil { + network.Enabled = *fs.Enabled + } + if fs.ConnectCommands != nil { + if len(fs.ConnectCommands) == 1 && fs.ConnectCommands[0] == "" { + network.ConnectCommands = nil + } else { + for _, command := range fs.ConnectCommands { + _, err := irc.ParseMessage(command) + if err != nil { + return fmt.Errorf("flag -connect-command must be a valid raw irc command string: %q: %v", command, err) + } + } + network.ConnectCommands = fs.ConnectCommands + } + } + return nil +} + +func handleServiceNetworkCreate(ctx context.Context, dc *downstreamConn, params []string) error { + fs := newNetworkFlagSet() + if err := fs.Parse(params); err != nil { + return err + } + if fs.Addr == nil { + return fmt.Errorf("flag -addr is required") + } + + record := &Network{ + Addr: *fs.Addr, + Enabled: true, + } + if err := fs.update(record); err != nil { + return err + } + + network, err := dc.user.createNetwork(ctx, record) + if err != nil { + return fmt.Errorf("could not create network: %v", err) + } + + sendServicePRIVMSG(dc, fmt.Sprintf("created network %q", network.GetName())) + return nil +} + +func handleServiceNetworkStatus(ctx context.Context, dc *downstreamConn, params []string) error { + n := 0 + for _, net := range dc.user.networks { + var statuses []string + var details string + if uc := net.conn; uc != nil { + if dc.nick != uc.nick { + statuses = append(statuses, "connected as "+uc.nick) + } else { + statuses = append(statuses, "connected") + } + details = fmt.Sprintf("%v channels", uc.channels.Len()) + } else if !net.Enabled { + statuses = append(statuses, "disabled") + } else { + statuses = append(statuses, "disconnected") + if net.lastError != nil { + details = net.lastError.Error() + } + } + + if net == dc.network { + statuses = append(statuses, "current") + } + + name := net.GetName() + if name != net.Addr { + name = fmt.Sprintf("%v (%v)", name, net.Addr) + } + + s := fmt.Sprintf("%v [%v]", name, strings.Join(statuses, ", ")) + if details != "" { + s += ": " + details + } + sendServicePRIVMSG(dc, s) + + n++ + } + + if n == 0 { + sendServicePRIVMSG(dc, `No network configured, add one with "network create".`) + } + + return nil +} + +func handleServiceNetworkUpdate(ctx context.Context, dc *downstreamConn, params []string) error { + net, params, err := getNetworkFromArg(dc, params) + if err != nil { + return err + } + + fs := newNetworkFlagSet() + if err := fs.Parse(params); err != nil { + return err + } + + record := net.Network // copy network record because we'll mutate it + if err := fs.update(&record); err != nil { + return err + } + + network, err := dc.user.updateNetwork(ctx, &record) + if err != nil { + return fmt.Errorf("could not update network: %v", err) + } + + sendServicePRIVMSG(dc, fmt.Sprintf("updated network %q", network.GetName())) + return nil +} + +func handleServiceNetworkDelete(ctx context.Context, dc *downstreamConn, params []string) error { + net, params, err := getNetworkFromArg(dc, params) + if err != nil { + return err + } + + if err := dc.user.deleteNetwork(ctx, net.ID); err != nil { + return err + } + + sendServicePRIVMSG(dc, fmt.Sprintf("deleted network %q", net.GetName())) + return nil +} + +func handleServiceNetworkQuote(ctx context.Context, dc *downstreamConn, params []string) error { + if len(params) != 1 && len(params) != 2 { + return fmt.Errorf("expected one or two arguments") + } + + raw := params[len(params)-1] + params = params[:len(params)-1] + + net, params, err := getNetworkFromArg(dc, params) + if err != nil { + return err + } + + uc := net.conn + if uc == nil { + return fmt.Errorf("network %q is not currently connected", net.GetName()) + } + + m, err := irc.ParseMessage(raw) + if err != nil { + return fmt.Errorf("failed to parse command %q: %v", raw, err) + } + uc.SendMessage(ctx, m) + + sendServicePRIVMSG(dc, fmt.Sprintf("sent command to %q", net.GetName())) + return nil +} + +func sendCertfpFingerprints(dc *downstreamConn, cert []byte) { + sha1Sum := sha1.Sum(cert) + sendServicePRIVMSG(dc, "SHA-1 fingerprint: "+hex.EncodeToString(sha1Sum[:])) + sha256Sum := sha256.Sum256(cert) + sendServicePRIVMSG(dc, "SHA-256 fingerprint: "+hex.EncodeToString(sha256Sum[:])) + sha512Sum := sha512.Sum512(cert) + sendServicePRIVMSG(dc, "SHA-512 fingerprint: "+hex.EncodeToString(sha512Sum[:])) +} + +func getNetworkFromFlag(dc *downstreamConn, name string) (*network, error) { + if name == "" { + if dc.network == nil { + return nil, fmt.Errorf("no network selected, -network is required") + } + return dc.network, nil + } else { + net := dc.user.getNetwork(name) + if net == nil { + return nil, fmt.Errorf("unknown network %q", name) + } + return net, nil + } +} + +func handleServiceCertFPGenerate(ctx context.Context, dc *downstreamConn, params []string) error { + fs := newFlagSet() + netName := fs.String("network", "", "select a network") + keyType := fs.String("key-type", "rsa", "key type to generate (rsa, ecdsa, ed25519)") + bits := fs.Int("bits", 3072, "size of key to generate, meaningful only for RSA") + + if err := fs.Parse(params); err != nil { + return err + } + + if *bits <= 0 || *bits > maxRSABits { + return fmt.Errorf("invalid value for -bits") + } + + net, err := getNetworkFromFlag(dc, *netName) + if err != nil { + return err + } + + privKey, cert, err := generateCertFP(*keyType, *bits) + if err != nil { + return err + } + + net.SASL.External.CertBlob = cert + net.SASL.External.PrivKeyBlob = privKey + net.SASL.Mechanism = "EXTERNAL" + + if err := dc.srv.db.StoreNetwork(ctx, dc.user.ID, &net.Network); err != nil { + return err + } + + sendServicePRIVMSG(dc, "certificate generated") + sendCertfpFingerprints(dc, cert) + return nil +} + +func handleServiceCertFPFingerprints(ctx context.Context, dc *downstreamConn, params []string) error { + fs := newFlagSet() + netName := fs.String("network", "", "select a network") + + if err := fs.Parse(params); err != nil { + return err + } + + net, err := getNetworkFromFlag(dc, *netName) + if err != nil { + return err + } + + if net.SASL.Mechanism != "EXTERNAL" { + return fmt.Errorf("CertFP not set up") + } + + sendCertfpFingerprints(dc, net.SASL.External.CertBlob) + return nil +} + +func handleServiceSASLStatus(ctx context.Context, dc *downstreamConn, params []string) error { + fs := newFlagSet() + netName := fs.String("network", "", "select a network") + + if err := fs.Parse(params); err != nil { + return err + } + + net, err := getNetworkFromFlag(dc, *netName) + if err != nil { + return err + } + + switch net.SASL.Mechanism { + case "PLAIN": + sendServicePRIVMSG(dc, fmt.Sprintf("SASL PLAIN enabled with username %q", net.SASL.Plain.Username)) + case "EXTERNAL": + sendServicePRIVMSG(dc, "SASL EXTERNAL (CertFP) enabled") + case "": + sendServicePRIVMSG(dc, "SASL is disabled") + } + + if uc := net.conn; uc != nil { + if uc.account != "" { + sendServicePRIVMSG(dc, fmt.Sprintf("Authenticated on upstream network with account %q", uc.account)) + } else { + sendServicePRIVMSG(dc, "Unauthenticated on upstream network") + } + } else { + sendServicePRIVMSG(dc, "Disconnected from upstream network") + } + + return nil +} + +func handleServiceSASLSetPlain(ctx context.Context, dc *downstreamConn, params []string) error { + fs := newFlagSet() + netName := fs.String("network", "", "select a network") + + if err := fs.Parse(params); err != nil { + return err + } + + if len(fs.Args()) != 2 { + return fmt.Errorf("expected exactly 2 arguments") + } + + net, err := getNetworkFromFlag(dc, *netName) + if err != nil { + return err + } + + net.SASL.Plain.Username = fs.Arg(0) + net.SASL.Plain.Password = fs.Arg(1) + net.SASL.Mechanism = "PLAIN" + + if err := dc.srv.db.StoreNetwork(ctx, dc.user.ID, &net.Network); err != nil { + return err + } + + sendServicePRIVMSG(dc, "credentials saved") + return nil +} + +func handleServiceSASLReset(ctx context.Context, dc *downstreamConn, params []string) error { + fs := newFlagSet() + netName := fs.String("network", "", "select a network") + + if err := fs.Parse(params); err != nil { + return err + } + + net, err := getNetworkFromFlag(dc, *netName) + if err != nil { + return err + } + + net.SASL.Plain.Username = "" + net.SASL.Plain.Password = "" + net.SASL.External.CertBlob = nil + net.SASL.External.PrivKeyBlob = nil + net.SASL.Mechanism = "" + + if err := dc.srv.db.StoreNetwork(ctx, dc.user.ID, &net.Network); err != nil { + return err + } + + sendServicePRIVMSG(dc, "credentials reset") + return nil +} + +func handleUserCreate(ctx context.Context, dc *downstreamConn, params []string) error { + fs := newFlagSet() + username := fs.String("username", "", "") + password := fs.String("password", "", "") + realname := fs.String("realname", "", "") + admin := fs.Bool("admin", false, "") + + if err := fs.Parse(params); err != nil { + return err + } + if *username == "" { + return fmt.Errorf("flag -username is required") + } + if *password == "" { + return fmt.Errorf("flag -password is required") + } + + hashed, err := bcrypt.GenerateFromPassword([]byte(*password), bcrypt.DefaultCost) + if err != nil { + return fmt.Errorf("failed to hash password: %v", err) + } + + user := &User{ + Username: *username, + Password: string(hashed), + Realname: *realname, + Admin: *admin, + } + if _, err := dc.srv.createUser(ctx, user); err != nil { + return fmt.Errorf("could not create user: %v", err) + } + + sendServicePRIVMSG(dc, fmt.Sprintf("created user %q", *username)) + return nil +} + +func popArg(params []string) (string, []string) { + if len(params) > 0 && !strings.HasPrefix(params[0], "-") { + return params[0], params[1:] + } + return "", params +} + +func handleUserUpdate(ctx context.Context, dc *downstreamConn, params []string) error { + var password, realname *string + var admin *bool + fs := newFlagSet() + fs.Var(stringPtrFlag{&password}, "password", "") + fs.Var(stringPtrFlag{&realname}, "realname", "") + fs.Var(boolPtrFlag{&admin}, "admin", "") + + username, params := popArg(params) + if err := fs.Parse(params); err != nil { + return err + } + if len(fs.Args()) > 0 { + return fmt.Errorf("unexpected argument") + } + + var hashed *string + if password != nil { + hashedBytes, err := bcrypt.GenerateFromPassword([]byte(*password), bcrypt.DefaultCost) + if err != nil { + return fmt.Errorf("failed to hash password: %v", err) + } + hashedStr := string(hashedBytes) + hashed = &hashedStr + } + + if username != "" && username != dc.user.Username { + if !dc.user.Admin { + return fmt.Errorf("you must be an admin to update other users") + } + if realname != nil { + return fmt.Errorf("cannot update -realname of other user") + } + + u := dc.srv.getUser(username) + if u == nil { + return fmt.Errorf("unknown username %q", username) + } + + done := make(chan error, 1) + event := eventUserUpdate{ + password: hashed, + admin: admin, + done: done, + } + select { + case <-ctx.Done(): + return ctx.Err() + case u.events <- event: + } + // TODO: send context to the other side + if err := <-done; err != nil { + return err + } + + sendServicePRIVMSG(dc, fmt.Sprintf("updated user %q", username)) + } else { + // copy the user record because we'll mutate it + record := dc.user.User + + if hashed != nil { + record.Password = *hashed + } + if realname != nil { + record.Realname = *realname + } + if admin != nil { + return fmt.Errorf("cannot update -admin of own user") + } + + if err := dc.user.updateUser(ctx, &record); err != nil { + return err + } + + sendServicePRIVMSG(dc, fmt.Sprintf("updated user %q", dc.user.Username)) + } + + return nil +} + +func handleUserDelete(ctx context.Context, dc *downstreamConn, params []string) error { + if len(params) != 1 { + return fmt.Errorf("expected exactly one argument") + } + username := params[0] + + u := dc.srv.getUser(username) + if u == nil { + return fmt.Errorf("unknown username %q", username) + } + + u.stop() + + if err := dc.srv.db.DeleteUser(ctx, u.ID); err != nil { + return fmt.Errorf("failed to delete user: %v", err) + } + + sendServicePRIVMSG(dc, fmt.Sprintf("deleted user %q", username)) + return nil +} + +func handleServiceChannelStatus(ctx context.Context, dc *downstreamConn, params []string) error { + var defaultNetworkName string + if dc.network != nil { + defaultNetworkName = dc.network.GetName() + } + + fs := newFlagSet() + networkName := fs.String("network", defaultNetworkName, "") + + if err := fs.Parse(params); err != nil { + return err + } + + n := 0 + + sendNetwork := func(net *network) { + var channels []*Channel + for _, entry := range net.channels.innerMap { + channels = append(channels, entry.value.(*Channel)) + } + + sort.Slice(channels, func(i, j int) bool { + return strings.ReplaceAll(channels[i].Name, "#", "") < + strings.ReplaceAll(channels[j].Name, "#", "") + }) + + for _, ch := range channels { + var uch *upstreamChannel + if net.conn != nil { + uch = net.conn.channels.Value(ch.Name) + } + + name := ch.Name + if *networkName == "" { + name += "/" + net.GetName() + } + + var status string + if uch != nil { + status = "joined" + } else if net.conn != nil { + status = "parted" + } else { + status = "disconnected" + } + + if ch.Detached { + status += ", detached" + } + + s := fmt.Sprintf("%v [%v]", name, status) + sendServicePRIVMSG(dc, s) + + n++ + } + } + + if *networkName == "" { + for _, net := range dc.user.networks { + sendNetwork(net) + } + } else { + net := dc.user.getNetwork(*networkName) + if net == nil { + return fmt.Errorf("unknown network %q", *networkName) + } + sendNetwork(net) + } + + if n == 0 { + sendServicePRIVMSG(dc, "No channel configured.") + } + + return nil +} + +type channelFlagSet struct { + *flag.FlagSet + RelayDetached, ReattachOn, DetachAfter, DetachOn *string +} + +func newChannelFlagSet() *channelFlagSet { + fs := &channelFlagSet{FlagSet: newFlagSet()} + fs.Var(stringPtrFlag{&fs.RelayDetached}, "relay-detached", "") + fs.Var(stringPtrFlag{&fs.ReattachOn}, "reattach-on", "") + fs.Var(stringPtrFlag{&fs.DetachAfter}, "detach-after", "") + fs.Var(stringPtrFlag{&fs.DetachOn}, "detach-on", "") + return fs +} + +func (fs *channelFlagSet) update(channel *Channel) error { + if fs.RelayDetached != nil { + filter, err := parseFilter(*fs.RelayDetached) + if err != nil { + return err + } + channel.RelayDetached = filter + } + if fs.ReattachOn != nil { + filter, err := parseFilter(*fs.ReattachOn) + if err != nil { + return err + } + channel.ReattachOn = filter + } + if fs.DetachAfter != nil { + dur, err := time.ParseDuration(*fs.DetachAfter) + if err != nil || dur < 0 { + return fmt.Errorf("unknown duration for -detach-after %q (duration format: 0, 300s, 22h30m, ...)", *fs.DetachAfter) + } + channel.DetachAfter = dur + } + if fs.DetachOn != nil { + filter, err := parseFilter(*fs.DetachOn) + if err != nil { + return err + } + channel.DetachOn = filter + } + return nil +} + +func handleServiceChannelUpdate(ctx context.Context, dc *downstreamConn, params []string) error { + if len(params) < 1 { + return fmt.Errorf("expected at least one argument") + } + name := params[0] + + fs := newChannelFlagSet() + if err := fs.Parse(params[1:]); err != nil { + return err + } + + uc, upstreamName, err := dc.unmarshalEntity(name) + if err != nil { + return fmt.Errorf("unknown channel %q", name) + } + + ch := uc.network.channels.Value(upstreamName) + if ch == nil { + return fmt.Errorf("unknown channel %q", name) + } + + if err := fs.update(ch); err != nil { + return err + } + + uc.updateChannelAutoDetach(upstreamName) + + if err := dc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil { + return fmt.Errorf("failed to update channel: %v", err) + } + + sendServicePRIVMSG(dc, fmt.Sprintf("updated channel %q", name)) + return nil +} +func handleServiceServerStatus(ctx context.Context, dc *downstreamConn, params []string) error { + dbStats, err := dc.user.srv.db.Stats(ctx) + if err != nil { + return err + } + serverStats := dc.user.srv.Stats() + sendServicePRIVMSG(dc, fmt.Sprintf("%v/%v users, %v downstreams, %v upstreams, %v networks, %v channels", serverStats.Users, dbStats.Users, serverStats.Downstreams, serverStats.Upstreams, dbStats.Networks, dbStats.Channels)) + return nil +} + +func handleServiceServerNotice(ctx context.Context, dc *downstreamConn, params []string) error { + if len(params) != 1 { + return fmt.Errorf("expected exactly one argument") + } + text := params[0] + + dc.logger.Printf("broadcasting bouncer-wide NOTICE: %v", text) + + broadcastMsg := &irc.Message{ + Prefix: servicePrefix, + Command: "NOTICE", + Params: []string{"$" + dc.srv.Config().Hostname, text}, + } + var err error + sent := 0 + total := 0 + dc.srv.forEachUser(func(u *user) { + total++ + select { + case <-ctx.Done(): + err = ctx.Err() + case u.events <- eventBroadcast{broadcastMsg}: + sent++ + } + }) + + dc.logger.Printf("broadcast bouncer-wide NOTICE to %v/%v downstreams", sent, total) + sendServicePRIVMSG(dc, fmt.Sprintf("sent to %v/%v downstream connections", sent, total)) + + return err +} diff --git a/branches/origin-master/service_test.go b/branches/origin-master/service_test.go new file mode 100644 index 0000000..1f3fe6b --- /dev/null +++ b/branches/origin-master/service_test.go @@ -0,0 +1,54 @@ +package suika + +import ( + "testing" +) + +func assertSplit(t *testing.T, input string, expected []string) { + actual, err := splitWords(input) + if err != nil { + t.Errorf("%q: %v", input, err) + return + } + if len(actual) != len(expected) { + t.Errorf("%q: expected %d words, got %d\nexpected: %v\ngot: %v", input, len(expected), len(actual), expected, actual) + return + } + for i := 0; i < len(actual); i++ { + if actual[i] != expected[i] { + t.Errorf("%q: expected word #%d to be %q, got %q\nexpected: %v\ngot: %v", input, i, expected[i], actual[i], expected, actual) + } + } +} + +func TestSplit(t *testing.T) { + assertSplit(t, " ch 'up' #suika 'relay'-det\"ache\"d message ", []string{ + "ch", + "up", + "#suika", + "relay-detached", + "message", + }) + assertSplit(t, "net update \\\"free\\\"node -pass 'political \"stance\" desu!' -realname '' -nick lee", []string{ + "net", + "update", + "\"free\"node", + "-pass", + "political \"stance\" desu!", + "-realname", + "", + "-nick", + "lee", + }) + assertSplit(t, "Omedeto,\\ Yui! ''", []string{ + "Omedeto, Yui!", + "", + }) + + if _, err := splitWords("end of 'file"); err == nil { + t.Errorf("expected error on unterminated single quote") + } + if _, err := splitWords("end of backquote \\"); err == nil { + t.Errorf("expected error on unterminated backquote sequence") + } +} diff --git a/branches/origin-master/suika_psql_schema.sql b/branches/origin-master/suika_psql_schema.sql new file mode 100644 index 0000000..d148b4d --- /dev/null +++ b/branches/origin-master/suika_psql_schema.sql @@ -0,0 +1,58 @@ +CREATE TABLE IF NOT EXISTS "User" ( + id SERIAL PRIMARY KEY, + username VARCHAR(255) NOT NULL UNIQUE, + password VARCHAR(255), + admin BOOLEAN NOT NULL DEFAULT FALSE, + realname VARCHAR(255) +); + +CREATE TYPE sasl_mechanism AS ENUM ('PLAIN', 'EXTERNAL'); + +CREATE TABLE IF NOT EXISTS "Network" ( + id SERIAL PRIMARY KEY, + name VARCHAR(255), + "user" INTEGER NOT NULL REFERENCES "User"(id) ON DELETE CASCADE, + addr VARCHAR(255) NOT NULL, + nick VARCHAR(255), + username VARCHAR(255), + realname VARCHAR(255), + pass VARCHAR(255), + connect_commands VARCHAR(1023), + sasl_mechanism sasl_mechanism, + sasl_plain_username VARCHAR(255), + sasl_plain_password VARCHAR(255), + sasl_external_cert BYTEA, + sasl_external_key BYTEA, + enabled BOOLEAN NOT NULL DEFAULT TRUE, + UNIQUE("user", addr, nick), + UNIQUE("user", name) +); +CREATE TABLE IF NOT EXISTS "Channel" ( + id SERIAL PRIMARY KEY, + network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE, + name VARCHAR(255) NOT NULL, + key VARCHAR(255), + detached BOOLEAN NOT NULL DEFAULT FALSE, + detached_internal_msgid VARCHAR(255), + relay_detached INTEGER NOT NULL DEFAULT 0, + reattach_on INTEGER NOT NULL DEFAULT 0, + detach_after INTEGER NOT NULL DEFAULT 0, + detach_on INTEGER NOT NULL DEFAULT 0, + UNIQUE(network, name) +); +CREATE TABLE IF NOT EXISTS "DeliveryReceipt" ( + id SERIAL PRIMARY KEY, + network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE, + target VARCHAR(255) NOT NULL, + client VARCHAR(255) NOT NULL DEFAULT '', + internal_msgid VARCHAR(255) NOT NULL, + UNIQUE(network, target, client) +); +CREATE TABLE IF NOT EXISTS "ReadReceipt" ( + id SERIAL PRIMARY KEY, + network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE, + target VARCHAR(255) NOT NULL, + timestamp TIMESTAMP WITH TIME ZONE NOT NULL, + UNIQUE(network, target) +); + diff --git a/branches/origin-master/suika_sqlite_schema.sql b/branches/origin-master/suika_sqlite_schema.sql new file mode 100644 index 0000000..517d06f --- /dev/null +++ b/branches/origin-master/suika_sqlite_schema.sql @@ -0,0 +1,61 @@ +CREATE TABLE IF NOT EXISTS User ( + id INTEGER PRIMARY KEY, + username TEXT NOT NULL UNIQUE, + password TEXT, + admin INTEGER NOT NULL DEFAULT 0, + realname TEXT +); +CREATE TABLE IF NOT EXISTS Network ( + id INTEGER PRIMARY KEY, + name TEXT, + user INTEGER NOT NULL, + addr TEXT NOT NULL, + nick TEXT, + username TEXT, + realname TEXT, + pass TEXT, + connect_commands TEXT, + sasl_mechanism TEXT, + sasl_plain_username TEXT, + sasl_plain_password TEXT, + sasl_external_cert BLOB, + sasl_external_key BLOB, + enabled INTEGER NOT NULL DEFAULT 1, + FOREIGN KEY(user) REFERENCES User(id), + UNIQUE(user, addr, nick), + UNIQUE(user, name) +); +CREATE TABLE IF NOT EXISTS Channel ( + id INTEGER PRIMARY KEY, + network INTEGER NOT NULL, + name TEXT NOT NULL, + key TEXT, + detached INTEGER NOT NULL DEFAULT 0, + detached_internal_msgid TEXT, + relay_detached INTEGER NOT NULL DEFAULT 0, + reattach_on INTEGER NOT NULL DEFAULT 0, + detach_after INTEGER NOT NULL DEFAULT 0, + detach_on INTEGER NOT NULL DEFAULT 0, + FOREIGN KEY(network) REFERENCES Network(id), + UNIQUE(network, name) +); + +CREATE TABLE IF NOT EXISTS DeliveryReceipt ( + id INTEGER PRIMARY KEY, + network INTEGER NOT NULL, + target TEXT NOT NULL, + client TEXT, + internal_msgid TEXT NOT NULL, + FOREIGN KEY(network) REFERENCES Network(id), + UNIQUE(network, target, client) +); + +CREATE TABLE IF NOT EXISTS ReadReceipt ( + id INTEGER PRIMARY KEY, + network INTEGER NOT NULL, + target TEXT NOT NULL, + timestamp TEXT NOT NULL, + FOREIGN KEY(network) REFERENCES Network(id), + UNIQUE(network, target) +); + diff --git a/branches/origin-master/upstream.go b/branches/origin-master/upstream.go new file mode 100644 index 0000000..285dca4 --- /dev/null +++ b/branches/origin-master/upstream.go @@ -0,0 +1,2182 @@ +package suika + +import ( + "context" + "crypto" + "crypto/sha256" + "crypto/tls" + "crypto/x509" + "encoding/base64" + "errors" + "fmt" + "io" + "net" + "strconv" + "strings" + "time" + + "github.com/emersion/go-sasl" + "gopkg.in/irc.v3" +) + +// permanentUpstreamCaps is the static list of upstream capabilities always +// requested when supported. +var permanentUpstreamCaps = map[string]bool{ + "account-notify": true, + "account-tag": true, + "away-notify": true, + "batch": true, + "extended-join": true, + "invite-notify": true, + "labeled-response": true, + "message-tags": true, + "multi-prefix": true, + "sasl": true, + "server-time": true, + "setname": true, + + "draft/account-registration": true, + "draft/extended-monitor": true, +} + +type registrationError struct { + *irc.Message +} + +func (err registrationError) Error() string { + return fmt.Sprintf("registration error (%v): %v", err.Command, err.Reason()) +} + +func (err registrationError) Reason() string { + if len(err.Params) > 0 { + return err.Params[len(err.Params)-1] + } + return err.Command +} + +func (err registrationError) Temporary() bool { + // Only return false if we're 100% sure that fixing the error requires a + // network configuration change + switch err.Command { + case irc.ERR_PASSWDMISMATCH, irc.ERR_ERRONEUSNICKNAME: + return false + case "FAIL": + return err.Params[1] != "ACCOUNT_REQUIRED" + default: + return true + } +} + +type upstreamChannel struct { + Name string + conn *upstreamConn + Topic string + TopicWho *irc.Prefix + TopicTime time.Time + Status channelStatus + modes channelModes + creationTime string + Members membershipsCasemapMap + complete bool + detachTimer *time.Timer +} + +func (uc *upstreamChannel) updateAutoDetach(dur time.Duration) { + if uc.detachTimer != nil { + uc.detachTimer.Stop() + uc.detachTimer = nil + } + + if dur == 0 { + return + } + + uc.detachTimer = time.AfterFunc(dur, func() { + uc.conn.network.user.events <- eventChannelDetach{ + uc: uc.conn, + name: uc.Name, + } + }) +} + +type pendingUpstreamCommand struct { + downstreamID uint64 + msg *irc.Message +} + +type upstreamConn struct { + conn + + network *network + user *user + + serverName string + availableUserModes string + availableChannelModes map[byte]channelModeType + availableChannelTypes string + availableMemberships []membership + isupport map[string]*string + + registered bool + nick string + nickCM string + username string + realname string + modes userModes + channels upstreamChannelCasemapMap + supportedCaps map[string]string + caps map[string]bool + batches map[string]batch + away bool + account string + nextLabelID uint64 + monitored monitorCasemapMap + + saslClient sasl.Client + saslStarted bool + + casemapIsSet bool + + // Queue of commands in progress, indexed by type. The first entry has been + // sent to the server and is awaiting reply. The following entries have not + // been sent yet. + pendingCmds map[string][]pendingUpstreamCommand + + gotMotd bool +} + +func connectToUpstream(ctx context.Context, network *network) (*upstreamConn, error) { + logger := &prefixLogger{network.user.logger, fmt.Sprintf("upstream %q: ", network.GetName())} + + dialer := net.Dialer{Timeout: connectTimeout} + + u, err := network.URL() + if err != nil { + return nil, err + } + + var netConn net.Conn + switch u.Scheme { + case "ircs": + addr := u.Host + host, _, err := net.SplitHostPort(u.Host) + if err != nil { + host = u.Host + addr = u.Host + ":6697" + } + + dialer.LocalAddr, err = network.user.localTCPAddrForHost(ctx, host) + if err != nil { + return nil, fmt.Errorf("failed to pick local IP for remote host %q: %v", host, err) + } + + logger.Printf("connecting to TLS server at address %q", addr) + + tlsConfig := &tls.Config{ServerName: host, NextProtos: []string{"irc"}} + if network.SASL.Mechanism == "EXTERNAL" { + if network.SASL.External.CertBlob == nil { + return nil, fmt.Errorf("missing certificate for authentication") + } + if network.SASL.External.PrivKeyBlob == nil { + return nil, fmt.Errorf("missing private key for authentication") + } + key, err := x509.ParsePKCS8PrivateKey(network.SASL.External.PrivKeyBlob) + if err != nil { + return nil, fmt.Errorf("failed to parse private key: %v", err) + } + tlsConfig.Certificates = []tls.Certificate{ + { + Certificate: [][]byte{network.SASL.External.CertBlob}, + PrivateKey: key.(crypto.PrivateKey), + }, + } + logger.Printf("using TLS client certificate %x", sha256.Sum256(network.SASL.External.CertBlob)) + } + + netConn, err = dialer.DialContext(ctx, "tcp", addr) + if err != nil { + return nil, fmt.Errorf("failed to dial %q: %v", addr, err) + } + + // Don't do the TLS handshake immediately, because we need to register + // the new connection with identd ASAP. + netConn = tls.Client(netConn, tlsConfig) + case "irc": + addr := u.Host + host, _, err := net.SplitHostPort(addr) + if err != nil { + host = u.Host + addr = u.Host + ":6667" + } + + dialer.LocalAddr, err = network.user.localTCPAddrForHost(ctx, host) + if err != nil { + return nil, fmt.Errorf("failed to pick local IP for remote host %q: %v", host, err) + } + + logger.Printf("connecting to plain-text server at address %q", addr) + netConn, err = dialer.DialContext(ctx, "tcp", addr) + if err != nil { + return nil, fmt.Errorf("failed to dial %q: %v", addr, err) + } + case "irc+unix", "unix": + logger.Printf("connecting to Unix socket at path %q", u.Path) + netConn, err = dialer.DialContext(ctx, "unix", u.Path) + if err != nil { + return nil, fmt.Errorf("failed to connect to Unix socket %q: %v", u.Path, err) + } + default: + return nil, fmt.Errorf("failed to dial %q: unknown scheme: %v", network.Addr, u.Scheme) + } + + options := connOptions{ + Logger: logger, + RateLimitDelay: upstreamMessageDelay, + RateLimitBurst: upstreamMessageBurst, + } + + uc := &upstreamConn{ + conn: *newConn(network.user.srv, newNetIRCConn(netConn), &options), + network: network, + user: network.user, + channels: upstreamChannelCasemapMap{newCasemapMap(0)}, + supportedCaps: make(map[string]string), + caps: make(map[string]bool), + batches: make(map[string]batch), + availableChannelTypes: stdChannelTypes, + availableChannelModes: stdChannelModes, + availableMemberships: stdMemberships, + isupport: make(map[string]*string), + pendingCmds: make(map[string][]pendingUpstreamCommand), + monitored: monitorCasemapMap{newCasemapMap(0)}, + } + return uc, nil +} + +func (uc *upstreamConn) forEachDownstream(f func(*downstreamConn)) { + uc.network.forEachDownstream(f) +} + +func (uc *upstreamConn) forEachDownstreamByID(id uint64, f func(*downstreamConn)) { + uc.forEachDownstream(func(dc *downstreamConn) { + if id != 0 && id != dc.id { + return + } + f(dc) + }) +} + +func (uc *upstreamConn) downstreamByID(id uint64) *downstreamConn { + for _, dc := range uc.user.downstreamConns { + if dc.id == id { + return dc + } + } + return nil +} + +func (uc *upstreamConn) getChannel(name string) (*upstreamChannel, error) { + ch := uc.channels.Value(name) + if ch == nil { + return nil, fmt.Errorf("unknown channel %q", name) + } + return ch, nil +} + +func (uc *upstreamConn) isChannel(entity string) bool { + return strings.ContainsRune(uc.availableChannelTypes, rune(entity[0])) +} + +func (uc *upstreamConn) isOurNick(nick string) bool { + return uc.nickCM == uc.network.casemap(nick) +} + +func (uc *upstreamConn) abortPendingCommands() { + for _, l := range uc.pendingCmds { + for _, pendingCmd := range l { + dc := uc.downstreamByID(pendingCmd.downstreamID) + if dc == nil { + continue + } + + switch pendingCmd.msg.Command { + case "LIST": + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_LISTEND, + Params: []string{dc.nick, "Command aborted"}, + }) + case "WHO": + mask := "*" + if len(pendingCmd.msg.Params) > 0 { + mask = pendingCmd.msg.Params[0] + } + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_ENDOFWHO, + Params: []string{dc.nick, mask, "Command aborted"}, + }) + case "AUTHENTICATE": + dc.endSASL(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_SASLABORTED, + Params: []string{dc.nick, "SASL authentication aborted"}, + }) + case "REGISTER", "VERIFY": + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "FAIL", + Params: []string{pendingCmd.msg.Command, "TEMPORARILY_UNAVAILABLE", pendingCmd.msg.Params[0], "Command aborted"}, + }) + default: + panic(fmt.Errorf("Unsupported pending command %q", pendingCmd.msg.Command)) + } + } + } + + uc.pendingCmds = make(map[string][]pendingUpstreamCommand) +} + +func (uc *upstreamConn) sendNextPendingCommand(cmd string) { + if len(uc.pendingCmds[cmd]) == 0 { + return + } + uc.SendMessage(context.TODO(), uc.pendingCmds[cmd][0].msg) +} + +func (uc *upstreamConn) enqueueCommand(dc *downstreamConn, msg *irc.Message) { + switch msg.Command { + case "LIST", "WHO", "AUTHENTICATE", "REGISTER", "VERIFY": + // Supported + default: + panic(fmt.Errorf("Unsupported pending command %q", msg.Command)) + } + + uc.pendingCmds[msg.Command] = append(uc.pendingCmds[msg.Command], pendingUpstreamCommand{ + downstreamID: dc.id, + msg: msg, + }) + + if len(uc.pendingCmds[msg.Command]) == 1 { + uc.sendNextPendingCommand(msg.Command) + } +} + +func (uc *upstreamConn) currentPendingCommand(cmd string) (*downstreamConn, *irc.Message) { + if len(uc.pendingCmds[cmd]) == 0 { + return nil, nil + } + + pendingCmd := uc.pendingCmds[cmd][0] + return uc.downstreamByID(pendingCmd.downstreamID), pendingCmd.msg +} + +func (uc *upstreamConn) dequeueCommand(cmd string) (*downstreamConn, *irc.Message) { + dc, msg := uc.currentPendingCommand(cmd) + + if len(uc.pendingCmds[cmd]) > 0 { + copy(uc.pendingCmds[cmd], uc.pendingCmds[cmd][1:]) + uc.pendingCmds[cmd] = uc.pendingCmds[cmd][:len(uc.pendingCmds[cmd])-1] + } + + uc.sendNextPendingCommand(cmd) + + return dc, msg +} + +func (uc *upstreamConn) cancelPendingCommandsByDownstreamID(downstreamID uint64) { + for cmd := range uc.pendingCmds { + // We can't cancel the currently running command stored in + // uc.pendingCmds[cmd][0] + for i := len(uc.pendingCmds[cmd]) - 1; i >= 1; i-- { + if uc.pendingCmds[cmd][i].downstreamID == downstreamID { + uc.pendingCmds[cmd] = append(uc.pendingCmds[cmd][:i], uc.pendingCmds[cmd][i+1:]...) + } + } + } +} + +func (uc *upstreamConn) parseMembershipPrefix(s string) (ms *memberships, nick string) { + memberships := make(memberships, 0, 4) + i := 0 + for _, m := range uc.availableMemberships { + if i >= len(s) { + break + } + if s[i] == m.Prefix { + memberships = append(memberships, m) + i++ + } + } + return &memberships, s[i:] +} + +func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) error { + var label string + if l, ok := msg.GetTag("label"); ok { + label = l + delete(msg.Tags, "label") + } + + var msgBatch *batch + if batchName, ok := msg.GetTag("batch"); ok { + b, ok := uc.batches[batchName] + if !ok { + return fmt.Errorf("unexpected batch reference: batch was not defined: %q", batchName) + } + msgBatch = &b + if label == "" { + label = msgBatch.Label + } + delete(msg.Tags, "batch") + } + + var downstreamID uint64 = 0 + if label != "" { + var labelOffset uint64 + n, err := fmt.Sscanf(label, "sd-%d-%d", &downstreamID, &labelOffset) + if err == nil && n < 2 { + err = errors.New("not enough arguments") + } + if err != nil { + return fmt.Errorf("unexpected message label: invalid downstream reference for label %q: %v", label, err) + } + } + + if _, ok := msg.Tags["time"]; !ok { + msg.Tags["time"] = irc.TagValue(formatServerTime(time.Now())) + } + + switch msg.Command { + case "PING": + uc.SendMessage(ctx, &irc.Message{ + Command: "PONG", + Params: msg.Params, + }) + return nil + case "NOTICE", "PRIVMSG", "TAGMSG": + if msg.Prefix == nil { + return fmt.Errorf("expected a prefix") + } + + var entity, text string + if msg.Command != "TAGMSG" { + if err := parseMessageParams(msg, &entity, &text); err != nil { + return err + } + } else { + if err := parseMessageParams(msg, &entity); err != nil { + return err + } + } + + if msg.Prefix.Name == serviceNick { + uc.logger.Printf("skipping %v from suika's service: %v", msg.Command, msg) + break + } + if entity == serviceNick { + uc.logger.Printf("skipping %v to suika's service: %v", msg.Command, msg) + break + } + + if msg.Prefix.User == "" && msg.Prefix.Host == "" { // server message + uc.produce("", msg, nil) + } else { // regular user message + target := entity + if uc.isOurNick(target) { + target = msg.Prefix.Name + } + + ch := uc.network.channels.Value(target) + if ch != nil && msg.Command != "TAGMSG" { + if ch.Detached { + uc.handleDetachedMessage(ctx, ch, msg) + } + + highlight := uc.network.isHighlight(msg) + if ch.DetachOn == FilterMessage || ch.DetachOn == FilterDefault || (ch.DetachOn == FilterHighlight && highlight) { + uc.updateChannelAutoDetach(target) + } + } + + uc.produce(target, msg, nil) + } + case "CAP": + var subCmd string + if err := parseMessageParams(msg, nil, &subCmd); err != nil { + return err + } + subCmd = strings.ToUpper(subCmd) + subParams := msg.Params[2:] + switch subCmd { + case "LS": + if len(subParams) < 1 { + return newNeedMoreParamsError(msg.Command) + } + caps := subParams[len(subParams)-1] + more := len(subParams) >= 2 && msg.Params[len(subParams)-2] == "*" + + uc.handleSupportedCaps(caps) + + if more { + break // wait to receive all capabilities + } + + uc.requestCaps() + + if uc.requestSASL() { + break // we'll send CAP END after authentication is completed + } + + uc.SendMessage(ctx, &irc.Message{ + Command: "CAP", + Params: []string{"END"}, + }) + case "ACK", "NAK": + if len(subParams) < 1 { + return newNeedMoreParamsError(msg.Command) + } + caps := strings.Fields(subParams[0]) + + for _, name := range caps { + if err := uc.handleCapAck(ctx, strings.ToLower(name), subCmd == "ACK"); err != nil { + return err + } + } + + if uc.registered { + uc.forEachDownstream(func(dc *downstreamConn) { + dc.updateSupportedCaps() + }) + } + case "NEW": + if len(subParams) < 1 { + return newNeedMoreParamsError(msg.Command) + } + uc.handleSupportedCaps(subParams[0]) + uc.requestCaps() + case "DEL": + if len(subParams) < 1 { + return newNeedMoreParamsError(msg.Command) + } + caps := strings.Fields(subParams[0]) + + for _, c := range caps { + delete(uc.supportedCaps, c) + delete(uc.caps, c) + } + + if uc.registered { + uc.forEachDownstream(func(dc *downstreamConn) { + dc.updateSupportedCaps() + }) + } + default: + uc.logger.Printf("unhandled message: %v", msg) + } + case "AUTHENTICATE": + if uc.saslClient == nil { + return fmt.Errorf("received unexpected AUTHENTICATE message") + } + + // TODO: if a challenge is 400 bytes long, buffer it + var challengeStr string + if err := parseMessageParams(msg, &challengeStr); err != nil { + uc.SendMessage(ctx, &irc.Message{ + Command: "AUTHENTICATE", + Params: []string{"*"}, + }) + return err + } + + var challenge []byte + if challengeStr != "+" { + var err error + challenge, err = base64.StdEncoding.DecodeString(challengeStr) + if err != nil { + uc.SendMessage(ctx, &irc.Message{ + Command: "AUTHENTICATE", + Params: []string{"*"}, + }) + return err + } + } + + var resp []byte + var err error + if !uc.saslStarted { + _, resp, err = uc.saslClient.Start() + uc.saslStarted = true + } else { + resp, err = uc.saslClient.Next(challenge) + } + if err != nil { + uc.SendMessage(ctx, &irc.Message{ + Command: "AUTHENTICATE", + Params: []string{"*"}, + }) + return err + } + + // <= instead of < because we need to send a final empty response if + // the last chunk is exactly 400 bytes long + for i := 0; i <= len(resp); i += maxSASLLength { + j := i + maxSASLLength + if j > len(resp) { + j = len(resp) + } + + chunk := resp[i:j] + + var respStr = "+" + if len(chunk) != 0 { + respStr = base64.StdEncoding.EncodeToString(chunk) + } + + uc.SendMessage(ctx, &irc.Message{ + Command: "AUTHENTICATE", + Params: []string{respStr}, + }) + } + case irc.RPL_LOGGEDIN: + if err := parseMessageParams(msg, nil, nil, &uc.account); err != nil { + return err + } + uc.logger.Printf("logged in with account %q", uc.account) + uc.forEachDownstream(func(dc *downstreamConn) { + dc.updateAccount() + }) + case irc.RPL_LOGGEDOUT: + uc.account = "" + uc.logger.Printf("logged out") + uc.forEachDownstream(func(dc *downstreamConn) { + dc.updateAccount() + }) + case irc.ERR_NICKLOCKED, irc.RPL_SASLSUCCESS, irc.ERR_SASLFAIL, irc.ERR_SASLTOOLONG, irc.ERR_SASLABORTED: + var info string + if err := parseMessageParams(msg, nil, &info); err != nil { + return err + } + switch msg.Command { + case irc.ERR_NICKLOCKED: + uc.logger.Printf("invalid nick used with SASL authentication: %v", info) + case irc.ERR_SASLFAIL: + uc.logger.Printf("SASL authentication failed: %v", info) + case irc.ERR_SASLTOOLONG: + uc.logger.Printf("SASL message too long: %v", info) + } + + uc.saslClient = nil + uc.saslStarted = false + + if dc, _ := uc.dequeueCommand("AUTHENTICATE"); dc != nil && dc.sasl != nil { + if msg.Command == irc.RPL_SASLSUCCESS { + uc.network.autoSaveSASLPlain(ctx, dc.sasl.plainUsername, dc.sasl.plainPassword) + } + + dc.endSASL(msg) + } + + if !uc.registered { + uc.SendMessage(ctx, &irc.Message{ + Command: "CAP", + Params: []string{"END"}, + }) + } + case "REGISTER", "VERIFY": + if dc, cmd := uc.dequeueCommand(msg.Command); dc != nil { + if msg.Command == "REGISTER" { + var account, password string + if err := parseMessageParams(msg, nil, &account); err != nil { + return err + } + if err := parseMessageParams(cmd, nil, nil, &password); err != nil { + return err + } + uc.network.autoSaveSASLPlain(ctx, account, password) + } + + dc.SendMessage(msg) + } + case irc.RPL_WELCOME: + if err := parseMessageParams(msg, &uc.nick); err != nil { + return err + } + + uc.registered = true + uc.nickCM = uc.network.casemap(uc.nick) + uc.logger.Printf("connection registered with nick %q", uc.nick) + + if uc.network.channels.Len() > 0 { + var channels, keys []string + for _, entry := range uc.network.channels.innerMap { + ch := entry.value.(*Channel) + channels = append(channels, ch.Name) + keys = append(keys, ch.Key) + } + + for _, msg := range join(channels, keys) { + uc.SendMessage(ctx, msg) + } + } + case irc.RPL_MYINFO: + if err := parseMessageParams(msg, nil, &uc.serverName, nil, &uc.availableUserModes, nil); err != nil { + return err + } + case irc.RPL_ISUPPORT: + if err := parseMessageParams(msg, nil, nil); err != nil { + return err + } + + var downstreamIsupport []string + for _, token := range msg.Params[1 : len(msg.Params)-1] { + parameter := token + var negate, hasValue bool + var value string + if strings.HasPrefix(token, "-") { + negate = true + token = token[1:] + } else if i := strings.IndexByte(token, '='); i >= 0 { + parameter = token[:i] + value = token[i+1:] + hasValue = true + } + + if hasValue { + uc.isupport[parameter] = &value + } else if !negate { + uc.isupport[parameter] = nil + } else { + delete(uc.isupport, parameter) + } + + var err error + switch parameter { + case "CASEMAPPING": + casemap, ok := parseCasemappingToken(value) + if !ok { + casemap = casemapRFC1459 + } + uc.network.updateCasemapping(casemap) + uc.nickCM = uc.network.casemap(uc.nick) + uc.casemapIsSet = true + case "CHANMODES": + if !negate { + err = uc.handleChanModes(value) + } else { + uc.availableChannelModes = stdChannelModes + } + case "CHANTYPES": + if !negate { + uc.availableChannelTypes = value + } else { + uc.availableChannelTypes = stdChannelTypes + } + case "PREFIX": + if !negate { + err = uc.handleMemberships(value) + } else { + uc.availableMemberships = stdMemberships + } + } + if err != nil { + return err + } + + if passthroughIsupport[parameter] { + downstreamIsupport = append(downstreamIsupport, token) + } + } + + uc.updateMonitor() + + uc.forEachDownstream(func(dc *downstreamConn) { + if dc.network == nil { + return + } + msgs := generateIsupport(dc.srv.prefix(), dc.nick, downstreamIsupport) + for _, msg := range msgs { + dc.SendMessage(msg) + } + }) + case irc.ERR_NOMOTD, irc.RPL_ENDOFMOTD: + if !uc.casemapIsSet { + // upstream did not send any CASEMAPPING token, thus + // we assume it implements the old RFCs with rfc1459. + uc.casemapIsSet = true + uc.network.updateCasemapping(casemapRFC1459) + uc.nickCM = uc.network.casemap(uc.nick) + } + + if !uc.gotMotd { + // Ignore the initial MOTD upon connection, but forward + // subsequent MOTD messages downstream + uc.gotMotd = true + return nil + } + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + dc.SendMessage(&irc.Message{ + Prefix: uc.srv.prefix(), + Command: msg.Command, + Params: msg.Params, + }) + }) + case "BATCH": + var tag string + if err := parseMessageParams(msg, &tag); err != nil { + return err + } + + if strings.HasPrefix(tag, "+") { + tag = tag[1:] + if _, ok := uc.batches[tag]; ok { + return fmt.Errorf("unexpected BATCH reference tag: batch was already defined: %q", tag) + } + var batchType string + if err := parseMessageParams(msg, nil, &batchType); err != nil { + return err + } + label := label + if label == "" && msgBatch != nil { + label = msgBatch.Label + } + uc.batches[tag] = batch{ + Type: batchType, + Params: msg.Params[2:], + Outer: msgBatch, + Label: label, + } + } else if strings.HasPrefix(tag, "-") { + tag = tag[1:] + if _, ok := uc.batches[tag]; !ok { + return fmt.Errorf("unknown BATCH reference tag: %q", tag) + } + delete(uc.batches, tag) + } else { + return fmt.Errorf("unexpected BATCH reference tag: missing +/- prefix: %q", tag) + } + case "NICK": + if msg.Prefix == nil { + return fmt.Errorf("expected a prefix") + } + + var newNick string + if err := parseMessageParams(msg, &newNick); err != nil { + return err + } + + me := false + if uc.isOurNick(msg.Prefix.Name) { + uc.logger.Printf("changed nick from %q to %q", uc.nick, newNick) + me = true + uc.nick = newNick + uc.nickCM = uc.network.casemap(uc.nick) + } + + for _, entry := range uc.channels.innerMap { + ch := entry.value.(*upstreamChannel) + memberships := ch.Members.Value(msg.Prefix.Name) + if memberships != nil { + ch.Members.Delete(msg.Prefix.Name) + ch.Members.SetValue(newNick, memberships) + uc.appendLog(ch.Name, msg) + } + } + + if !me { + uc.forEachDownstream(func(dc *downstreamConn) { + dc.SendMessage(dc.marshalMessage(msg, uc.network)) + }) + } else { + uc.forEachDownstream(func(dc *downstreamConn) { + dc.updateNick() + }) + uc.updateMonitor() + } + case "SETNAME": + if msg.Prefix == nil { + return fmt.Errorf("expected a prefix") + } + + var newRealname string + if err := parseMessageParams(msg, &newRealname); err != nil { + return err + } + + // TODO: consider appending this message to logs + + if uc.isOurNick(msg.Prefix.Name) { + uc.logger.Printf("changed realname from %q to %q", uc.realname, newRealname) + uc.realname = newRealname + + uc.forEachDownstream(func(dc *downstreamConn) { + dc.updateRealname() + }) + } else { + uc.forEachDownstream(func(dc *downstreamConn) { + dc.SendMessage(dc.marshalMessage(msg, uc.network)) + }) + } + case "JOIN": + if msg.Prefix == nil { + return fmt.Errorf("expected a prefix") + } + + var channels string + if err := parseMessageParams(msg, &channels); err != nil { + return err + } + + for _, ch := range strings.Split(channels, ",") { + if uc.isOurNick(msg.Prefix.Name) { + uc.logger.Printf("joined channel %q", ch) + members := membershipsCasemapMap{newCasemapMap(0)} + members.casemap = uc.network.casemap + uc.channels.SetValue(ch, &upstreamChannel{ + Name: ch, + conn: uc, + Members: members, + }) + uc.updateChannelAutoDetach(ch) + + uc.SendMessage(ctx, &irc.Message{ + Command: "MODE", + Params: []string{ch}, + }) + } else { + ch, err := uc.getChannel(ch) + if err != nil { + return err + } + ch.Members.SetValue(msg.Prefix.Name, &memberships{}) + } + + chMsg := msg.Copy() + chMsg.Params[0] = ch + uc.produce(ch, chMsg, nil) + } + case "PART": + if msg.Prefix == nil { + return fmt.Errorf("expected a prefix") + } + + var channels string + if err := parseMessageParams(msg, &channels); err != nil { + return err + } + + for _, ch := range strings.Split(channels, ",") { + if uc.isOurNick(msg.Prefix.Name) { + uc.logger.Printf("parted channel %q", ch) + uch := uc.channels.Value(ch) + if uch != nil { + uc.channels.Delete(ch) + uch.updateAutoDetach(0) + } + } else { + ch, err := uc.getChannel(ch) + if err != nil { + return err + } + ch.Members.Delete(msg.Prefix.Name) + } + + chMsg := msg.Copy() + chMsg.Params[0] = ch + uc.produce(ch, chMsg, nil) + } + case "KICK": + if msg.Prefix == nil { + return fmt.Errorf("expected a prefix") + } + + var channel, user string + if err := parseMessageParams(msg, &channel, &user); err != nil { + return err + } + + if uc.isOurNick(user) { + uc.logger.Printf("kicked from channel %q by %s", channel, msg.Prefix.Name) + uc.channels.Delete(channel) + } else { + ch, err := uc.getChannel(channel) + if err != nil { + return err + } + ch.Members.Delete(user) + } + + uc.produce(channel, msg, nil) + case "QUIT": + if msg.Prefix == nil { + return fmt.Errorf("expected a prefix") + } + + if uc.isOurNick(msg.Prefix.Name) { + uc.logger.Printf("quit") + } + + for _, entry := range uc.channels.innerMap { + ch := entry.value.(*upstreamChannel) + if ch.Members.Has(msg.Prefix.Name) { + ch.Members.Delete(msg.Prefix.Name) + + uc.appendLog(ch.Name, msg) + } + } + + if msg.Prefix.Name != uc.nick { + uc.forEachDownstream(func(dc *downstreamConn) { + dc.SendMessage(dc.marshalMessage(msg, uc.network)) + }) + } + case irc.RPL_TOPIC, irc.RPL_NOTOPIC: + var name, topic string + if err := parseMessageParams(msg, nil, &name, &topic); err != nil { + return err + } + ch, err := uc.getChannel(name) + if err != nil { + return err + } + if msg.Command == irc.RPL_TOPIC { + ch.Topic = topic + } else { + ch.Topic = "" + } + case "TOPIC": + if msg.Prefix == nil { + return fmt.Errorf("expected a prefix") + } + + var name string + if err := parseMessageParams(msg, &name); err != nil { + return err + } + ch, err := uc.getChannel(name) + if err != nil { + return err + } + if len(msg.Params) > 1 { + ch.Topic = msg.Params[1] + ch.TopicWho = msg.Prefix.Copy() + ch.TopicTime = time.Now() // TODO use msg.Tags["time"] + } else { + ch.Topic = "" + } + uc.produce(ch.Name, msg, nil) + case "MODE": + var name, modeStr string + if err := parseMessageParams(msg, &name, &modeStr); err != nil { + return err + } + + if !uc.isChannel(name) { // user mode change + if name != uc.nick { + return fmt.Errorf("received MODE message for unknown nick %q", name) + } + + if err := uc.modes.Apply(modeStr); err != nil { + return err + } + + uc.forEachDownstream(func(dc *downstreamConn) { + if dc.upstream() == nil { + return + } + + dc.SendMessage(msg) + }) + } else { // channel mode change + ch, err := uc.getChannel(name) + if err != nil { + return err + } + + needMarshaling, err := applyChannelModes(ch, modeStr, msg.Params[2:]) + if err != nil { + return err + } + + uc.appendLog(ch.Name, msg) + + c := uc.network.channels.Value(name) + if c == nil || !c.Detached { + uc.forEachDownstream(func(dc *downstreamConn) { + params := make([]string, len(msg.Params)) + params[0] = dc.marshalEntity(uc.network, name) + params[1] = modeStr + + copy(params[2:], msg.Params[2:]) + for i, modeParam := range params[2:] { + if _, ok := needMarshaling[i]; ok { + params[2+i] = dc.marshalEntity(uc.network, modeParam) + } + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.marshalUserPrefix(uc.network, msg.Prefix), + Command: "MODE", + Params: params, + }) + }) + } + } + case irc.RPL_UMODEIS: + if err := parseMessageParams(msg, nil); err != nil { + return err + } + modeStr := "" + if len(msg.Params) > 1 { + modeStr = msg.Params[1] + } + + uc.modes = "" + if err := uc.modes.Apply(modeStr); err != nil { + return err + } + + uc.forEachDownstream(func(dc *downstreamConn) { + if dc.upstream() == nil { + return + } + + dc.SendMessage(msg) + }) + case irc.RPL_CHANNELMODEIS: + var channel string + if err := parseMessageParams(msg, nil, &channel); err != nil { + return err + } + modeStr := "" + if len(msg.Params) > 2 { + modeStr = msg.Params[2] + } + + ch, err := uc.getChannel(channel) + if err != nil { + return err + } + + firstMode := ch.modes == nil + ch.modes = make(map[byte]string) + if _, err := applyChannelModes(ch, modeStr, msg.Params[3:]); err != nil { + return err + } + + c := uc.network.channels.Value(channel) + if firstMode && (c == nil || !c.Detached) { + modeStr, modeParams := ch.modes.Format() + + uc.forEachDownstream(func(dc *downstreamConn) { + params := []string{dc.nick, dc.marshalEntity(uc.network, channel), modeStr} + params = append(params, modeParams...) + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_CHANNELMODEIS, + Params: params, + }) + }) + } + case rpl_creationtime: + var channel, creationTime string + if err := parseMessageParams(msg, nil, &channel, &creationTime); err != nil { + return err + } + + ch, err := uc.getChannel(channel) + if err != nil { + return err + } + + firstCreationTime := ch.creationTime == "" + ch.creationTime = creationTime + + c := uc.network.channels.Value(channel) + if firstCreationTime && (c == nil || !c.Detached) { + uc.forEachDownstream(func(dc *downstreamConn) { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: rpl_creationtime, + Params: []string{dc.nick, dc.marshalEntity(uc.network, ch.Name), creationTime}, + }) + }) + } + case rpl_topicwhotime: + var channel, who, timeStr string + if err := parseMessageParams(msg, nil, &channel, &who, &timeStr); err != nil { + return err + } + + ch, err := uc.getChannel(channel) + if err != nil { + return err + } + + firstTopicWhoTime := ch.TopicWho == nil + ch.TopicWho = irc.ParsePrefix(who) + sec, err := strconv.ParseInt(timeStr, 10, 64) + if err != nil { + return fmt.Errorf("failed to parse topic time: %v", err) + } + ch.TopicTime = time.Unix(sec, 0) + + c := uc.network.channels.Value(channel) + if firstTopicWhoTime && (c == nil || !c.Detached) { + uc.forEachDownstream(func(dc *downstreamConn) { + topicWho := dc.marshalUserPrefix(uc.network, ch.TopicWho) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: rpl_topicwhotime, + Params: []string{ + dc.nick, + dc.marshalEntity(uc.network, ch.Name), + topicWho.String(), + timeStr, + }, + }) + }) + } + case irc.RPL_LIST: + var channel, clients, topic string + if err := parseMessageParams(msg, nil, &channel, &clients, &topic); err != nil { + return err + } + + dc, cmd := uc.currentPendingCommand("LIST") + if cmd == nil { + return fmt.Errorf("unexpected RPL_LIST: no matching pending LIST") + } else if dc == nil { + return nil + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_LIST, + Params: []string{dc.nick, dc.marshalEntity(uc.network, channel), clients, topic}, + }) + case irc.RPL_LISTEND: + dc, cmd := uc.dequeueCommand("LIST") + if cmd == nil { + return fmt.Errorf("unexpected RPL_LISTEND: no matching pending LIST") + } else if dc == nil { + return nil + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_LISTEND, + Params: []string{dc.nick, "End of /LIST"}, + }) + case irc.RPL_NAMREPLY: + var name, statusStr, members string + if err := parseMessageParams(msg, nil, &statusStr, &name, &members); err != nil { + return err + } + + ch := uc.channels.Value(name) + if ch == nil { + // NAMES on a channel we have not joined, forward to downstream + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + channel := dc.marshalEntity(uc.network, name) + members := splitSpace(members) + for i, member := range members { + memberships, nick := uc.parseMembershipPrefix(member) + members[i] = memberships.Format(dc) + dc.marshalEntity(uc.network, nick) + } + memberStr := strings.Join(members, " ") + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_NAMREPLY, + Params: []string{dc.nick, statusStr, channel, memberStr}, + }) + }) + return nil + } + + status, err := parseChannelStatus(statusStr) + if err != nil { + return err + } + ch.Status = status + + for _, s := range splitSpace(members) { + memberships, nick := uc.parseMembershipPrefix(s) + ch.Members.SetValue(nick, memberships) + } + case irc.RPL_ENDOFNAMES: + var name string + if err := parseMessageParams(msg, nil, &name); err != nil { + return err + } + + ch := uc.channels.Value(name) + if ch == nil { + // NAMES on a channel we have not joined, forward to downstream + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + channel := dc.marshalEntity(uc.network, name) + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_ENDOFNAMES, + Params: []string{dc.nick, channel, "End of /NAMES list"}, + }) + }) + return nil + } + + if ch.complete { + return fmt.Errorf("received unexpected RPL_ENDOFNAMES") + } + ch.complete = true + + c := uc.network.channels.Value(name) + if c == nil || !c.Detached { + uc.forEachDownstream(func(dc *downstreamConn) { + forwardChannel(ctx, dc, ch) + }) + } + case irc.RPL_WHOREPLY: + var channel, username, host, server, nick, flags, trailing string + if err := parseMessageParams(msg, nil, &channel, &username, &host, &server, &nick, &flags, &trailing); err != nil { + return err + } + + dc, cmd := uc.currentPendingCommand("WHO") + if cmd == nil { + return fmt.Errorf("unexpected RPL_WHOREPLY: no matching pending WHO") + } else if dc == nil { + return nil + } + + if channel != "*" { + channel = dc.marshalEntity(uc.network, channel) + } + nick = dc.marshalEntity(uc.network, nick) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_WHOREPLY, + Params: []string{dc.nick, channel, username, host, server, nick, flags, trailing}, + }) + case rpl_whospcrpl: + dc, cmd := uc.currentPendingCommand("WHO") + if cmd == nil { + return fmt.Errorf("unexpected RPL_WHOSPCRPL: no matching pending WHO") + } else if dc == nil { + return nil + } + + // Only supported in single-upstream mode, so forward as-is + dc.SendMessage(msg) + case irc.RPL_ENDOFWHO: + var name string + if err := parseMessageParams(msg, nil, &name); err != nil { + return err + } + + dc, cmd := uc.dequeueCommand("WHO") + if cmd == nil { + return fmt.Errorf("unexpected RPL_ENDOFWHO: no matching pending WHO") + } else if dc == nil { + return nil + } + + mask := "*" + if len(cmd.Params) > 0 { + mask = cmd.Params[0] + } + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_ENDOFWHO, + Params: []string{dc.nick, mask, "End of /WHO list"}, + }) + case irc.RPL_WHOISUSER: + var nick, username, host, realname string + if err := parseMessageParams(msg, nil, &nick, &username, &host, nil, &realname); err != nil { + return err + } + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + nick := dc.marshalEntity(uc.network, nick) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_WHOISUSER, + Params: []string{dc.nick, nick, username, host, "*", realname}, + }) + }) + case irc.RPL_WHOISSERVER: + var nick, server, serverInfo string + if err := parseMessageParams(msg, nil, &nick, &server, &serverInfo); err != nil { + return err + } + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + nick := dc.marshalEntity(uc.network, nick) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_WHOISSERVER, + Params: []string{dc.nick, nick, server, serverInfo}, + }) + }) + case irc.RPL_WHOISOPERATOR: + var nick string + if err := parseMessageParams(msg, nil, &nick); err != nil { + return err + } + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + nick := dc.marshalEntity(uc.network, nick) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_WHOISOPERATOR, + Params: []string{dc.nick, nick, "is an IRC operator"}, + }) + }) + case irc.RPL_WHOISIDLE: + var nick string + if err := parseMessageParams(msg, nil, &nick, nil); err != nil { + return err + } + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + nick := dc.marshalEntity(uc.network, nick) + params := []string{dc.nick, nick} + params = append(params, msg.Params[2:]...) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_WHOISIDLE, + Params: params, + }) + }) + case irc.RPL_WHOISCHANNELS: + var nick, channelList string + if err := parseMessageParams(msg, nil, &nick, &channelList); err != nil { + return err + } + channels := splitSpace(channelList) + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + nick := dc.marshalEntity(uc.network, nick) + channelList := make([]string, len(channels)) + for i, channel := range channels { + prefix, channel := uc.parseMembershipPrefix(channel) + channel = dc.marshalEntity(uc.network, channel) + channelList[i] = prefix.Format(dc) + channel + } + channels := strings.Join(channelList, " ") + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_WHOISCHANNELS, + Params: []string{dc.nick, nick, channels}, + }) + }) + case irc.RPL_ENDOFWHOIS: + var nick string + if err := parseMessageParams(msg, nil, &nick); err != nil { + return err + } + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + nick := dc.marshalEntity(uc.network, nick) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_ENDOFWHOIS, + Params: []string{dc.nick, nick, "End of /WHOIS list"}, + }) + }) + case "INVITE": + var nick, channel string + if err := parseMessageParams(msg, &nick, &channel); err != nil { + return err + } + + weAreInvited := uc.isOurNick(nick) + + uc.forEachDownstream(func(dc *downstreamConn) { + if !weAreInvited && !dc.caps["invite-notify"] { + return + } + dc.SendMessage(&irc.Message{ + Prefix: dc.marshalUserPrefix(uc.network, msg.Prefix), + Command: "INVITE", + Params: []string{dc.marshalEntity(uc.network, nick), dc.marshalEntity(uc.network, channel)}, + }) + }) + case irc.RPL_INVITING: + var nick, channel string + if err := parseMessageParams(msg, nil, &nick, &channel); err != nil { + return err + } + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_INVITING, + Params: []string{dc.nick, dc.marshalEntity(uc.network, nick), dc.marshalEntity(uc.network, channel)}, + }) + }) + case irc.RPL_MONONLINE, irc.RPL_MONOFFLINE: + var targetsStr string + if err := parseMessageParams(msg, nil, &targetsStr); err != nil { + return err + } + targets := strings.Split(targetsStr, ",") + + online := msg.Command == irc.RPL_MONONLINE + for _, target := range targets { + prefix := irc.ParsePrefix(target) + uc.monitored.SetValue(prefix.Name, online) + } + + // Check if the nick we want is now free + wantNick := GetNick(&uc.user.User, &uc.network.Network) + wantNickCM := uc.network.casemap(wantNick) + if !online && uc.nickCM != wantNickCM { + found := false + for _, target := range targets { + prefix := irc.ParsePrefix(target) + if uc.network.casemap(prefix.Name) == wantNickCM { + found = true + break + } + } + if found { + uc.logger.Printf("desired nick %q is now available", wantNick) + uc.SendMessage(ctx, &irc.Message{ + Command: "NICK", + Params: []string{wantNick}, + }) + } + } + + uc.forEachDownstream(func(dc *downstreamConn) { + for _, target := range targets { + prefix := irc.ParsePrefix(target) + if dc.monitored.Has(prefix.Name) { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: msg.Command, + Params: []string{dc.nick, target}, + }) + } + } + }) + case irc.ERR_MONLISTFULL: + var limit, targetsStr string + if err := parseMessageParams(msg, nil, &limit, &targetsStr); err != nil { + return err + } + + targets := strings.Split(targetsStr, ",") + uc.forEachDownstream(func(dc *downstreamConn) { + for _, target := range targets { + if dc.monitored.Has(target) { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: msg.Command, + Params: []string{dc.nick, limit, target}, + }) + } + } + }) + case irc.RPL_AWAY: + var nick, reason string + if err := parseMessageParams(msg, nil, &nick, &reason); err != nil { + return err + } + + uc.forEachDownstream(func(dc *downstreamConn) { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_AWAY, + Params: []string{dc.nick, dc.marshalEntity(uc.network, nick), reason}, + }) + }) + case "AWAY", "ACCOUNT": + if msg.Prefix == nil { + return fmt.Errorf("expected a prefix") + } + + uc.forEachDownstream(func(dc *downstreamConn) { + dc.SendMessage(&irc.Message{ + Prefix: dc.marshalUserPrefix(uc.network, msg.Prefix), + Command: msg.Command, + Params: msg.Params, + }) + }) + case irc.RPL_BANLIST, irc.RPL_INVITELIST, irc.RPL_EXCEPTLIST: + var channel, mask string + if err := parseMessageParams(msg, nil, &channel, &mask); err != nil { + return err + } + var addNick, addTime string + if len(msg.Params) >= 5 { + addNick = msg.Params[3] + addTime = msg.Params[4] + } + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + channel := dc.marshalEntity(uc.network, channel) + + var params []string + if addNick != "" && addTime != "" { + addNick := dc.marshalEntity(uc.network, addNick) + params = []string{dc.nick, channel, mask, addNick, addTime} + } else { + params = []string{dc.nick, channel, mask} + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: msg.Command, + Params: params, + }) + }) + case irc.RPL_ENDOFBANLIST, irc.RPL_ENDOFINVITELIST, irc.RPL_ENDOFEXCEPTLIST: + var channel, trailing string + if err := parseMessageParams(msg, nil, &channel, &trailing); err != nil { + return err + } + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + upstreamChannel := dc.marshalEntity(uc.network, channel) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: msg.Command, + Params: []string{dc.nick, upstreamChannel, trailing}, + }) + }) + case irc.ERR_UNKNOWNCOMMAND, irc.RPL_TRYAGAIN: + var command, reason string + if err := parseMessageParams(msg, nil, &command, &reason); err != nil { + return err + } + + if dc, _ := uc.dequeueCommand(command); dc != nil && downstreamID == 0 { + downstreamID = dc.id + } + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + dc.SendMessage(&irc.Message{ + Prefix: uc.srv.prefix(), + Command: msg.Command, + Params: []string{dc.nick, command, reason}, + }) + }) + case "FAIL": + var command, code string + if err := parseMessageParams(msg, &command, &code); err != nil { + return err + } + + if !uc.registered && command == "*" && code == "ACCOUNT_REQUIRED" { + return registrationError{msg} + } + + if dc, _ := uc.dequeueCommand(command); dc != nil && downstreamID == 0 { + downstreamID = dc.id + } + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + dc.SendMessage(msg) + }) + case "ACK": + // Ignore + case irc.RPL_NOWAWAY, irc.RPL_UNAWAY: + // Ignore + case irc.RPL_YOURHOST, irc.RPL_CREATED: + // Ignore + case irc.RPL_LUSERCLIENT, irc.RPL_LUSEROP, irc.RPL_LUSERUNKNOWN, irc.RPL_LUSERCHANNELS, irc.RPL_LUSERME: + fallthrough + case irc.RPL_STATSVLINE, rpl_statsping, irc.RPL_STATSBLINE, irc.RPL_STATSDLINE: + fallthrough + case rpl_localusers, rpl_globalusers: + fallthrough + case irc.RPL_MOTDSTART, irc.RPL_MOTD: + // Ignore these messages if they're part of the initial registration + // message burst. Forward them if the user explicitly asked for them. + if !uc.gotMotd { + return nil + } + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + dc.SendMessage(&irc.Message{ + Prefix: uc.srv.prefix(), + Command: msg.Command, + Params: msg.Params, + }) + }) + case irc.RPL_LISTSTART: + // Ignore + case "ERROR": + var text string + if err := parseMessageParams(msg, &text); err != nil { + return err + } + return fmt.Errorf("fatal server error: %v", text) + case irc.ERR_NICKNAMEINUSE: + // At this point, we haven't received ISUPPORT so we don't know the + // maximum nickname length or whether the server supports MONITOR. Many + // servers have NICKLEN=30 so let's just use that. + if !uc.registered && len(uc.nick)+1 < 30 { + uc.nick = uc.nick + "_" + uc.nickCM = uc.network.casemap(uc.nick) + uc.logger.Printf("desired nick is not available, falling back to %q", uc.nick) + uc.SendMessage(ctx, &irc.Message{ + Command: "NICK", + Params: []string{uc.nick}, + }) + return nil + } + fallthrough + case irc.ERR_PASSWDMISMATCH, irc.ERR_ERRONEUSNICKNAME, irc.ERR_NICKCOLLISION, irc.ERR_UNAVAILRESOURCE, irc.ERR_NOPERMFORHOST, irc.ERR_YOUREBANNEDCREEP: + if !uc.registered { + return registrationError{msg} + } + fallthrough + default: + uc.logger.Printf("unhandled message: %v", msg) + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + // best effort marshaling for unknown messages, replies and errors: + // most numerics start with the user nick, marshal it if that's the case + // otherwise, conservately keep the params without marshaling + params := msg.Params + if _, err := strconv.Atoi(msg.Command); err == nil { // numeric + if len(msg.Params) > 0 && isOurNick(uc.network, msg.Params[0]) { + params[0] = dc.nick + } + } + dc.SendMessage(&irc.Message{ + Prefix: uc.srv.prefix(), + Command: msg.Command, + Params: params, + }) + }) + } + return nil +} + +func (uc *upstreamConn) handleDetachedMessage(ctx context.Context, ch *Channel, msg *irc.Message) { + if uc.network.detachedMessageNeedsRelay(ch, msg) { + uc.forEachDownstream(func(dc *downstreamConn) { + dc.relayDetachedMessage(uc.network, msg) + }) + } + if ch.ReattachOn == FilterMessage || (ch.ReattachOn == FilterHighlight && uc.network.isHighlight(msg)) { + uc.network.attach(ctx, ch) + if err := uc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil { + uc.logger.Printf("failed to update channel %q: %v", ch.Name, err) + } + } +} + +func (uc *upstreamConn) handleChanModes(s string) error { + parts := strings.SplitN(s, ",", 5) + if len(parts) < 4 { + return fmt.Errorf("malformed ISUPPORT CHANMODES value: %v", s) + } + modes := make(map[byte]channelModeType) + for i, mt := range []channelModeType{modeTypeA, modeTypeB, modeTypeC, modeTypeD} { + for j := 0; j < len(parts[i]); j++ { + mode := parts[i][j] + modes[mode] = mt + } + } + uc.availableChannelModes = modes + return nil +} + +func (uc *upstreamConn) handleMemberships(s string) error { + if s == "" { + uc.availableMemberships = nil + return nil + } + + if s[0] != '(' { + return fmt.Errorf("malformed ISUPPORT PREFIX value: %v", s) + } + sep := strings.IndexByte(s, ')') + if sep < 0 || len(s) != sep*2 { + return fmt.Errorf("malformed ISUPPORT PREFIX value: %v", s) + } + memberships := make([]membership, len(s)/2-1) + for i := range memberships { + memberships[i] = membership{ + Mode: s[i+1], + Prefix: s[sep+i+1], + } + } + uc.availableMemberships = memberships + return nil +} + +func (uc *upstreamConn) handleSupportedCaps(capsStr string) { + caps := strings.Fields(capsStr) + for _, s := range caps { + kv := strings.SplitN(s, "=", 2) + k := strings.ToLower(kv[0]) + var v string + if len(kv) == 2 { + v = kv[1] + } + uc.supportedCaps[k] = v + } +} + +func (uc *upstreamConn) requestCaps() { + var requestCaps []string + for c := range permanentUpstreamCaps { + if _, ok := uc.supportedCaps[c]; ok && !uc.caps[c] { + requestCaps = append(requestCaps, c) + } + } + + if len(requestCaps) == 0 { + return + } + + uc.SendMessage(context.TODO(), &irc.Message{ + Command: "CAP", + Params: []string{"REQ", strings.Join(requestCaps, " ")}, + }) +} + +func (uc *upstreamConn) supportsSASL(mech string) bool { + v, ok := uc.supportedCaps["sasl"] + if !ok { + return false + } + + if v == "" { + return true + } + + mechanisms := strings.Split(v, ",") + for _, mech := range mechanisms { + if strings.EqualFold(mech, mech) { + return true + } + } + return false +} + +func (uc *upstreamConn) requestSASL() bool { + if uc.network.SASL.Mechanism == "" { + return false + } + return uc.supportsSASL(uc.network.SASL.Mechanism) +} + +func (uc *upstreamConn) handleCapAck(ctx context.Context, name string, ok bool) error { + uc.caps[name] = ok + + switch name { + case "sasl": + if !uc.requestSASL() { + return nil + } + if !ok { + uc.logger.Printf("server refused to acknowledge the SASL capability") + return nil + } + + auth := &uc.network.SASL + switch auth.Mechanism { + case "PLAIN": + uc.logger.Printf("starting SASL PLAIN authentication with username %q", auth.Plain.Username) + uc.saslClient = sasl.NewPlainClient("", auth.Plain.Username, auth.Plain.Password) + case "EXTERNAL": + uc.logger.Printf("starting SASL EXTERNAL authentication") + uc.saslClient = sasl.NewExternalClient("") + default: + return fmt.Errorf("unsupported SASL mechanism %q", name) + } + + uc.SendMessage(ctx, &irc.Message{ + Command: "AUTHENTICATE", + Params: []string{auth.Mechanism}, + }) + default: + if permanentUpstreamCaps[name] { + break + } + uc.logger.Printf("received CAP ACK/NAK for a cap we don't support: %v", name) + } + return nil +} + +func splitSpace(s string) []string { + return strings.FieldsFunc(s, func(r rune) bool { + return r == ' ' + }) +} + +func (uc *upstreamConn) register(ctx context.Context) { + uc.nick = GetNick(&uc.user.User, &uc.network.Network) + uc.nickCM = uc.network.casemap(uc.nick) + uc.username = GetUsername(&uc.user.User, &uc.network.Network) + uc.realname = GetRealname(&uc.user.User, &uc.network.Network) + + uc.SendMessage(ctx, &irc.Message{ + Command: "CAP", + Params: []string{"LS", "302"}, + }) + + if uc.network.Pass != "" { + uc.SendMessage(ctx, &irc.Message{ + Command: "PASS", + Params: []string{uc.network.Pass}, + }) + } + + uc.SendMessage(ctx, &irc.Message{ + Command: "NICK", + Params: []string{uc.nick}, + }) + uc.SendMessage(ctx, &irc.Message{ + Command: "USER", + Params: []string{uc.username, "0", "*", uc.realname}, + }) +} + +func (uc *upstreamConn) ReadMessage() (*irc.Message, error) { + msg, err := uc.conn.ReadMessage() + if err != nil { + return nil, err + } + return msg, nil +} + +func (uc *upstreamConn) runUntilRegistered(ctx context.Context) error { + for !uc.registered { + msg, err := uc.ReadMessage() + if err != nil { + return fmt.Errorf("failed to read message: %v", err) + } + + if err := uc.handleMessage(ctx, msg); err != nil { + if _, ok := err.(registrationError); ok { + return err + } else { + msg.Tags = nil // prevent message tags from cluttering logs + return fmt.Errorf("failed to handle message %q: %v", msg, err) + } + } + } + + for _, command := range uc.network.ConnectCommands { + m, err := irc.ParseMessage(command) + if err != nil { + uc.logger.Printf("failed to parse connect command %q: %v", command, err) + } else { + uc.SendMessage(ctx, m) + } + } + + return nil +} + +func (uc *upstreamConn) readMessages(ch chan<- event) error { + for { + msg, err := uc.ReadMessage() + if errors.Is(err, io.EOF) { + break + } else if err != nil { + return fmt.Errorf("failed to read IRC command: %v", err) + } + + ch <- eventUpstreamMessage{msg, uc} + } + + return nil +} + +func (uc *upstreamConn) SendMessage(ctx context.Context, msg *irc.Message) { + if !uc.caps["message-tags"] { + msg = msg.Copy() + msg.Tags = nil + } + + uc.conn.SendMessage(ctx, msg) +} + +func (uc *upstreamConn) SendMessageLabeled(ctx context.Context, downstreamID uint64, msg *irc.Message) { + if uc.caps["labeled-response"] { + if msg.Tags == nil { + msg.Tags = make(map[string]irc.TagValue) + } + msg.Tags["label"] = irc.TagValue(fmt.Sprintf("sd-%d-%d", downstreamID, uc.nextLabelID)) + uc.nextLabelID++ + } + uc.SendMessage(ctx, msg) +} + +// appendLog appends a message to the log file. +// +// The internal message ID is returned. If the message isn't recorded in the +// log file, an empty string is returned. +func (uc *upstreamConn) appendLog(entity string, msg *irc.Message) (msgID string) { + if uc.user.msgStore == nil { + return "" + } + + // Don't store messages with a server mask target + if strings.HasPrefix(entity, "$") { + return "" + } + + entityCM := uc.network.casemap(entity) + if entityCM == "nickserv" { + // The messages sent/received from NickServ may contain + // security-related information (like passwords). Don't store these. + return "" + } + + if !uc.network.delivered.HasTarget(entity) { + // This is the first message we receive from this target. Save the last + // message ID in delivery receipts, so that we can send the new message + // in the backlog if an offline client reconnects. + lastID, err := uc.user.msgStore.LastMsgID(&uc.network.Network, entityCM, time.Now()) + if err != nil { + uc.logger.Printf("failed to log message: failed to get last message ID: %v", err) + return "" + } + + uc.network.delivered.ForEachClient(func(clientName string) { + uc.network.delivered.StoreID(entity, clientName, lastID) + }) + } + + msgID, err := uc.user.msgStore.Append(&uc.network.Network, entityCM, msg) + if err != nil { + uc.logger.Printf("failed to append message to store: %v", err) + return "" + } + + return msgID +} + +// produce appends a message to the logs and forwards it to connected downstream +// connections. +// +// If origin is not nil and origin doesn't support echo-message, the message is +// forwarded to all connections except origin. +func (uc *upstreamConn) produce(target string, msg *irc.Message, origin *downstreamConn) { + var msgID string + if target != "" { + msgID = uc.appendLog(target, msg) + } + + // Don't forward messages if it's a detached channel + ch := uc.network.channels.Value(target) + detached := ch != nil && ch.Detached + + uc.forEachDownstream(func(dc *downstreamConn) { + if !detached && (dc != origin || dc.caps["echo-message"]) { + dc.sendMessageWithID(dc.marshalMessage(msg, uc.network), msgID) + } else { + dc.advanceMessageWithID(msg, msgID) + } + }) +} + +func (uc *upstreamConn) updateAway() { + ctx := context.TODO() + + away := true + uc.forEachDownstream(func(*downstreamConn) { + away = false + }) + if away == uc.away { + return + } + if away { + uc.SendMessage(ctx, &irc.Message{ + Command: "AWAY", + Params: []string{"Auto away"}, + }) + } else { + uc.SendMessage(ctx, &irc.Message{ + Command: "AWAY", + }) + } + uc.away = away +} + +func (uc *upstreamConn) updateChannelAutoDetach(name string) { + uch := uc.channels.Value(name) + if uch == nil { + return + } + ch := uc.network.channels.Value(name) + if ch == nil || ch.Detached { + return + } + uch.updateAutoDetach(ch.DetachAfter) +} + +func (uc *upstreamConn) updateMonitor() { + if _, ok := uc.isupport["MONITOR"]; !ok { + return + } + + ctx := context.TODO() + + add := make(map[string]struct{}) + var addList []string + seen := make(map[string]struct{}) + uc.forEachDownstream(func(dc *downstreamConn) { + for targetCM := range dc.monitored.innerMap { + if !uc.monitored.Has(targetCM) { + if _, ok := add[targetCM]; !ok { + addList = append(addList, targetCM) + add[targetCM] = struct{}{} + } + } else { + seen[targetCM] = struct{}{} + } + } + }) + + wantNick := GetNick(&uc.user.User, &uc.network.Network) + wantNickCM := uc.network.casemap(wantNick) + if _, ok := add[wantNickCM]; !ok && !uc.monitored.Has(wantNick) && !uc.isOurNick(wantNick) { + addList = append(addList, wantNickCM) + add[wantNickCM] = struct{}{} + } + + removeAll := true + var removeList []string + for targetCM, entry := range uc.monitored.innerMap { + if _, ok := seen[targetCM]; ok { + removeAll = false + } else { + removeList = append(removeList, entry.originalKey) + } + } + + // TODO: better handle the case where len(uc.monitored) + len(addList) + // exceeds the limit, probably by immediately sending ERR_MONLISTFULL? + + if removeAll && len(addList) == 0 && len(removeList) > 0 { + // Optimization when the last MONITOR-aware downstream disconnects + uc.SendMessage(ctx, &irc.Message{ + Command: "MONITOR", + Params: []string{"C"}, + }) + } else { + msgs := generateMonitor("-", removeList) + msgs = append(msgs, generateMonitor("+", addList)...) + for _, msg := range msgs { + uc.SendMessage(ctx, msg) + } + } + + for _, target := range removeList { + uc.monitored.Delete(target) + } +} diff --git a/branches/origin-master/user.go b/branches/origin-master/user.go new file mode 100644 index 0000000..0b523ce --- /dev/null +++ b/branches/origin-master/user.go @@ -0,0 +1,1068 @@ +package suika + +import ( + "context" + "crypto/sha256" + "encoding/binary" + "encoding/hex" + "fmt" + "math/big" + "net" + "sort" + "strings" + "time" + + "gopkg.in/irc.v3" +) + +type event interface{} + +type eventUpstreamMessage struct { + msg *irc.Message + uc *upstreamConn +} + +type eventUpstreamConnectionError struct { + net *network + err error +} + +type eventUpstreamConnected struct { + uc *upstreamConn +} + +type eventUpstreamDisconnected struct { + uc *upstreamConn +} + +type eventUpstreamError struct { + uc *upstreamConn + err error +} + +type eventDownstreamMessage struct { + msg *irc.Message + dc *downstreamConn +} + +type eventDownstreamConnected struct { + dc *downstreamConn +} + +type eventDownstreamDisconnected struct { + dc *downstreamConn +} + +type eventChannelDetach struct { + uc *upstreamConn + name string +} + +type eventBroadcast struct { + msg *irc.Message +} + +type eventStop struct{} + +type eventUserUpdate struct { + password *string + admin *bool + done chan error +} + +type deliveredClientMap map[string]string // client name -> msg ID + +type deliveredStore struct { + m deliveredCasemapMap +} + +func newDeliveredStore() deliveredStore { + return deliveredStore{deliveredCasemapMap{newCasemapMap(0)}} +} + +func (ds deliveredStore) HasTarget(target string) bool { + return ds.m.Value(target) != nil +} + +func (ds deliveredStore) LoadID(target, clientName string) string { + clients := ds.m.Value(target) + if clients == nil { + return "" + } + return clients[clientName] +} + +func (ds deliveredStore) StoreID(target, clientName, msgID string) { + clients := ds.m.Value(target) + if clients == nil { + clients = make(deliveredClientMap) + ds.m.SetValue(target, clients) + } + clients[clientName] = msgID +} + +func (ds deliveredStore) ForEachTarget(f func(target string)) { + for _, entry := range ds.m.innerMap { + f(entry.originalKey) + } +} + +func (ds deliveredStore) ForEachClient(f func(clientName string)) { + clients := make(map[string]struct{}) + for _, entry := range ds.m.innerMap { + delivered := entry.value.(deliveredClientMap) + for clientName := range delivered { + clients[clientName] = struct{}{} + } + } + + for clientName := range clients { + f(clientName) + } +} + +type network struct { + Network + user *user + logger Logger + stopped chan struct{} + + conn *upstreamConn + channels channelCasemapMap + delivered deliveredStore + lastError error + casemap casemapping +} + +func newNetwork(user *user, record *Network, channels []Channel) *network { + logger := &prefixLogger{user.logger, fmt.Sprintf("network %q: ", record.GetName())} + + m := channelCasemapMap{newCasemapMap(0)} + for _, ch := range channels { + ch := ch + m.SetValue(ch.Name, &ch) + } + + return &network{ + Network: *record, + user: user, + logger: logger, + stopped: make(chan struct{}), + channels: m, + delivered: newDeliveredStore(), + casemap: casemapRFC1459, + } +} + +func (net *network) forEachDownstream(f func(*downstreamConn)) { + net.user.forEachDownstream(func(dc *downstreamConn) { + if dc.network == nil && !dc.isMultiUpstream { + return + } + if dc.network != nil && dc.network != net { + return + } + f(dc) + }) +} + +func (net *network) isStopped() bool { + select { + case <-net.stopped: + return true + default: + return false + } +} + +func userIdent(u *User) string { + // The ident is a string we will send to upstream servers in clear-text. + // For privacy reasons, make sure it doesn't expose any meaningful user + // metadata. We just use the base64-encoded hashed ID, so that people don't + // start relying on the string being an integer or following a pattern. + var b [64]byte + binary.LittleEndian.PutUint64(b[:], uint64(u.ID)) + h := sha256.Sum256(b[:]) + return hex.EncodeToString(h[:16]) +} + +func (net *network) run() { + if !net.Enabled { + return + } + + var lastTry time.Time + backoff := newBackoffer(retryConnectMinDelay, retryConnectMaxDelay, retryConnectJitter) + for { + if net.isStopped() { + return + } + + delay := backoff.Next() - time.Now().Sub(lastTry) + if delay > 0 { + net.logger.Printf("waiting %v before trying to reconnect to %q", delay.Truncate(time.Second), net.Addr) + time.Sleep(delay) + } + lastTry = time.Now() + + + uc, err := connectToUpstream(context.TODO(), net) + if err != nil { + net.logger.Printf("failed to connect to upstream server %q: %v", net.Addr, err) + net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("failed to connect: %v", err)} + continue + } + + uc.register(context.TODO()) + if err := uc.runUntilRegistered(context.TODO()); err != nil { + text := err.Error() + temp := true + if regErr, ok := err.(registrationError); ok { + text = regErr.Reason() + temp = regErr.Temporary() + } + uc.logger.Printf("failed to register: %v", text) + net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("failed to register: %v", text)} + uc.Close() + if !temp { + return + } + continue + } + + // TODO: this is racy with net.stopped. If the network is stopped + // before the user goroutine receives eventUpstreamConnected, the + // connection won't be closed. + net.user.events <- eventUpstreamConnected{uc} + if err := uc.readMessages(net.user.events); err != nil { + uc.logger.Printf("failed to handle messages: %v", err) + net.user.events <- eventUpstreamError{uc, fmt.Errorf("failed to handle messages: %v", err)} + } + uc.Close() + net.user.events <- eventUpstreamDisconnected{uc} + + backoff.Reset() + } +} + +func (net *network) stop() { + if !net.isStopped() { + close(net.stopped) + } + + if net.conn != nil { + net.conn.Close() + } +} + +func (net *network) detach(ch *Channel) { + if ch.Detached { + return + } + + net.logger.Printf("detaching channel %q", ch.Name) + + ch.Detached = true + + if net.user.msgStore != nil { + nameCM := net.casemap(ch.Name) + lastID, err := net.user.msgStore.LastMsgID(&net.Network, nameCM, time.Now()) + if err != nil { + net.logger.Printf("failed to get last message ID for channel %q: %v", ch.Name, err) + } + ch.DetachedInternalMsgID = lastID + } + + if net.conn != nil { + uch := net.conn.channels.Value(ch.Name) + if uch != nil { + uch.updateAutoDetach(0) + } + } + + net.forEachDownstream(func(dc *downstreamConn) { + dc.SendMessage(&irc.Message{ + Prefix: dc.prefix(), + Command: "PART", + Params: []string{dc.marshalEntity(net, ch.Name), "Detach"}, + }) + }) +} + +func (net *network) attach(ctx context.Context, ch *Channel) { + if !ch.Detached { + return + } + + net.logger.Printf("attaching channel %q", ch.Name) + + detachedMsgID := ch.DetachedInternalMsgID + ch.Detached = false + ch.DetachedInternalMsgID = "" + + var uch *upstreamChannel + if net.conn != nil { + uch = net.conn.channels.Value(ch.Name) + + net.conn.updateChannelAutoDetach(ch.Name) + } + + net.forEachDownstream(func(dc *downstreamConn) { + dc.SendMessage(&irc.Message{ + Prefix: dc.prefix(), + Command: "JOIN", + Params: []string{dc.marshalEntity(net, ch.Name)}, + }) + + if uch != nil { + forwardChannel(ctx, dc, uch) + } + + if detachedMsgID != "" { + dc.sendTargetBacklog(ctx, net, ch.Name, detachedMsgID) + } + }) +} + +func (net *network) deleteChannel(ctx context.Context, name string) error { + ch := net.channels.Value(name) + if ch == nil { + return fmt.Errorf("unknown channel %q", name) + } + if net.conn != nil { + uch := net.conn.channels.Value(ch.Name) + if uch != nil { + uch.updateAutoDetach(0) + } + } + + if err := net.user.srv.db.DeleteChannel(ctx, ch.ID); err != nil { + return err + } + net.channels.Delete(name) + return nil +} + +func (net *network) updateCasemapping(newCasemap casemapping) { + net.casemap = newCasemap + net.channels.SetCasemapping(newCasemap) + net.delivered.m.SetCasemapping(newCasemap) + if uc := net.conn; uc != nil { + uc.channels.SetCasemapping(newCasemap) + for _, entry := range uc.channels.innerMap { + uch := entry.value.(*upstreamChannel) + uch.Members.SetCasemapping(newCasemap) + } + uc.monitored.SetCasemapping(newCasemap) + } + net.forEachDownstream(func(dc *downstreamConn) { + dc.monitored.SetCasemapping(newCasemap) + }) +} + +func (net *network) storeClientDeliveryReceipts(ctx context.Context, clientName string) { + if !net.user.hasPersistentMsgStore() { + return + } + + var receipts []DeliveryReceipt + net.delivered.ForEachTarget(func(target string) { + msgID := net.delivered.LoadID(target, clientName) + if msgID == "" { + return + } + receipts = append(receipts, DeliveryReceipt{ + Target: target, + InternalMsgID: msgID, + }) + }) + + if err := net.user.srv.db.StoreClientDeliveryReceipts(ctx, net.ID, clientName, receipts); err != nil { + net.logger.Printf("failed to store delivery receipts for client %q: %v", clientName, err) + } +} + +func (net *network) isHighlight(msg *irc.Message) bool { + if msg.Command != "PRIVMSG" && msg.Command != "NOTICE" { + return false + } + + text := msg.Params[1] + + nick := net.Nick + if net.conn != nil { + nick = net.conn.nick + } + + // TODO: use case-mapping aware comparison here + return msg.Prefix.Name != nick && isHighlight(text, nick) +} + +func (net *network) detachedMessageNeedsRelay(ch *Channel, msg *irc.Message) bool { + highlight := net.isHighlight(msg) + return ch.RelayDetached == FilterMessage || ((ch.RelayDetached == FilterHighlight || ch.RelayDetached == FilterDefault) && highlight) +} + +func (net *network) autoSaveSASLPlain(ctx context.Context, username, password string) { + // User may have e.g. EXTERNAL mechanism configured. We do not want to + // automatically erase the key pair or any other credentials. + if net.SASL.Mechanism != "" && net.SASL.Mechanism != "PLAIN" { + return + } + + net.logger.Printf("auto-saving SASL PLAIN credentials with username %q", username) + net.SASL.Mechanism = "PLAIN" + net.SASL.Plain.Username = username + net.SASL.Plain.Password = password + if err := net.user.srv.db.StoreNetwork(ctx, net.user.ID, &net.Network); err != nil { + net.logger.Printf("failed to save SASL PLAIN credentials: %v", err) + } +} + +type user struct { + User + srv *Server + logger Logger + + events chan event + done chan struct{} + + networks []*network + downstreamConns []*downstreamConn + msgStore messageStore +} + +func newUser(srv *Server, record *User) *user { + logger := &prefixLogger{srv.Logger, fmt.Sprintf("user %q: ", record.Username)} + + var msgStore messageStore + if logPath := srv.Config().LogPath; logPath != "" { + msgStore = newFSMessageStore(logPath, record) + } else { + msgStore = newMemoryMessageStore() + } + + return &user{ + User: *record, + srv: srv, + logger: logger, + events: make(chan event, 64), + done: make(chan struct{}), + msgStore: msgStore, + } +} + +func (u *user) forEachUpstream(f func(uc *upstreamConn)) { + for _, network := range u.networks { + if network.conn == nil { + continue + } + f(network.conn) + } +} + +func (u *user) forEachDownstream(f func(dc *downstreamConn)) { + for _, dc := range u.downstreamConns { + f(dc) + } +} + +func (u *user) getNetwork(name string) *network { + for _, network := range u.networks { + if network.Addr == name { + return network + } + if network.Name != "" && network.Name == name { + return network + } + } + return nil +} + +func (u *user) getNetworkByID(id int64) *network { + for _, net := range u.networks { + if net.ID == id { + return net + } + } + return nil +} + +func (u *user) run() { + defer func() { + if u.msgStore != nil { + if err := u.msgStore.Close(); err != nil { + u.logger.Printf("failed to close message store for user %q: %v", u.Username, err) + } + } + close(u.done) + }() + + networks, err := u.srv.db.ListNetworks(context.TODO(), u.ID) + if err != nil { + u.logger.Printf("failed to list networks for user %q: %v", u.Username, err) + return + } + + sort.Slice(networks, func(i, j int) bool { + return networks[i].ID < networks[j].ID + }) + + for _, record := range networks { + record := record + channels, err := u.srv.db.ListChannels(context.TODO(), record.ID) + if err != nil { + u.logger.Printf("failed to list channels for user %q, network %q: %v", u.Username, record.GetName(), err) + continue + } + + network := newNetwork(u, &record, channels) + u.networks = append(u.networks, network) + + if u.hasPersistentMsgStore() { + receipts, err := u.srv.db.ListDeliveryReceipts(context.TODO(), record.ID) + if err != nil { + u.logger.Printf("failed to load delivery receipts for user %q, network %q: %v", u.Username, network.GetName(), err) + return + } + + for _, rcpt := range receipts { + network.delivered.StoreID(rcpt.Target, rcpt.Client, rcpt.InternalMsgID) + } + } + + go network.run() + } + + for e := range u.events { + switch e := e.(type) { + case eventUpstreamConnected: + uc := e.uc + + uc.network.conn = uc + + uc.updateAway() + uc.updateMonitor() + + netIDStr := fmt.Sprintf("%v", uc.network.ID) + uc.forEachDownstream(func(dc *downstreamConn) { + dc.updateSupportedCaps() + + if !dc.caps["soju.im/bouncer-networks"] { + sendServiceNOTICE(dc, fmt.Sprintf("connected to %s", uc.network.GetName())) + } + + dc.updateNick() + dc.updateRealname() + dc.updateAccount() + }) + u.forEachDownstream(func(dc *downstreamConn) { + if dc.caps["soju.im/bouncer-networks-notify"] { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "BOUNCER", + Params: []string{"NETWORK", netIDStr, "state=connected"}, + }) + } + }) + uc.network.lastError = nil + case eventUpstreamDisconnected: + u.handleUpstreamDisconnected(e.uc) + case eventUpstreamConnectionError: + net := e.net + + stopped := false + select { + case <-net.stopped: + stopped = true + default: + } + + if !stopped && (net.lastError == nil || net.lastError.Error() != e.err.Error()) { + net.forEachDownstream(func(dc *downstreamConn) { + sendServiceNOTICE(dc, fmt.Sprintf("failed connecting/registering to %s: %v", net.GetName(), e.err)) + }) + } + net.lastError = e.err + case eventUpstreamError: + uc := e.uc + + uc.forEachDownstream(func(dc *downstreamConn) { + sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s: %v", uc.network.GetName(), e.err)) + }) + uc.network.lastError = e.err + case eventUpstreamMessage: + msg, uc := e.msg, e.uc + if uc.isClosed() { + uc.logger.Printf("ignoring message on closed connection: %v", msg) + break + } + if err := uc.handleMessage(context.TODO(), msg); err != nil { + uc.logger.Printf("failed to handle message %q: %v", msg, err) + } + case eventChannelDetach: + uc, name := e.uc, e.name + c := uc.network.channels.Value(name) + if c == nil || c.Detached { + continue + } + uc.network.detach(c) + if err := uc.srv.db.StoreChannel(context.TODO(), uc.network.ID, c); err != nil { + u.logger.Printf("failed to store updated detached channel %q: %v", c.Name, err) + } + case eventDownstreamConnected: + dc := e.dc + + if dc.network != nil { + dc.monitored.SetCasemapping(dc.network.casemap) + } + + if err := dc.welcome(context.TODO()); err != nil { + dc.logger.Printf("failed to handle new registered connection: %v", err) + break + } + + u.downstreamConns = append(u.downstreamConns, dc) + + dc.forEachNetwork(func(network *network) { + if network.lastError != nil { + sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s: %v", network.GetName(), network.lastError)) + } + }) + + u.forEachUpstream(func(uc *upstreamConn) { + uc.updateAway() + }) + case eventDownstreamDisconnected: + dc := e.dc + + for i := range u.downstreamConns { + if u.downstreamConns[i] == dc { + u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...) + break + } + } + + dc.forEachNetwork(func(net *network) { + net.storeClientDeliveryReceipts(context.TODO(), dc.clientName) + }) + + u.forEachUpstream(func(uc *upstreamConn) { + uc.cancelPendingCommandsByDownstreamID(dc.id) + uc.updateAway() + uc.updateMonitor() + }) + case eventDownstreamMessage: + msg, dc := e.msg, e.dc + if dc.isClosed() { + dc.logger.Printf("ignoring message on closed connection: %v", msg) + break + } + err := dc.handleMessage(context.TODO(), msg) + if ircErr, ok := err.(ircError); ok { + ircErr.Message.Prefix = dc.srv.prefix() + dc.SendMessage(ircErr.Message) + } else if err != nil { + dc.logger.Printf("failed to handle message %q: %v", msg, err) + dc.Close() + } + case eventBroadcast: + msg := e.msg + u.forEachDownstream(func(dc *downstreamConn) { + dc.SendMessage(msg) + }) + case eventUserUpdate: + // copy the user record because we'll mutate it + record := u.User + + if e.password != nil { + record.Password = *e.password + } + if e.admin != nil { + record.Admin = *e.admin + } + + e.done <- u.updateUser(context.TODO(), &record) + + // If the password was updated, kill all downstream connections to + // force them to re-authenticate with the new credentials. + if e.password != nil { + u.forEachDownstream(func(dc *downstreamConn) { + dc.Close() + }) + } + case eventStop: + u.forEachDownstream(func(dc *downstreamConn) { + dc.Close() + }) + for _, n := range u.networks { + n.stop() + + n.delivered.ForEachClient(func(clientName string) { + n.storeClientDeliveryReceipts(context.TODO(), clientName) + }) + } + return + default: + panic(fmt.Sprintf("received unknown event type: %T", e)) + } + } +} + +func (u *user) handleUpstreamDisconnected(uc *upstreamConn) { + uc.network.conn = nil + + uc.abortPendingCommands() + + for _, entry := range uc.channels.innerMap { + uch := entry.value.(*upstreamChannel) + uch.updateAutoDetach(0) + } + + netIDStr := fmt.Sprintf("%v", uc.network.ID) + uc.forEachDownstream(func(dc *downstreamConn) { + dc.updateSupportedCaps() + }) + + // If the network has been removed, don't send a state change notification + found := false + for _, net := range u.networks { + if net == uc.network { + found = true + break + } + } + if !found { + return + } + + u.forEachDownstream(func(dc *downstreamConn) { + if dc.caps["soju.im/bouncer-networks-notify"] { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "BOUNCER", + Params: []string{"NETWORK", netIDStr, "state=disconnected"}, + }) + } + }) + + if uc.network.lastError == nil { + uc.forEachDownstream(func(dc *downstreamConn) { + if !dc.caps["soju.im/bouncer-networks"] { + sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s", uc.network.GetName())) + } + }) + } +} + +func (u *user) addNetwork(network *network) { + u.networks = append(u.networks, network) + + sort.Slice(u.networks, func(i, j int) bool { + return u.networks[i].ID < u.networks[j].ID + }) + + go network.run() +} + +func (u *user) removeNetwork(network *network) { + network.stop() + + u.forEachDownstream(func(dc *downstreamConn) { + if dc.network != nil && dc.network == network { + dc.Close() + } + }) + + for i, net := range u.networks { + if net == network { + u.networks = append(u.networks[:i], u.networks[i+1:]...) + return + } + } + + panic("tried to remove a non-existing network") +} + +func (u *user) checkNetwork(record *Network) error { + url, err := record.URL() + if err != nil { + return err + } + if url.User != nil { + return fmt.Errorf("%v:// URL must not have username and password information", url.Scheme) + } + if url.RawQuery != "" { + return fmt.Errorf("%v:// URL must not have query values", url.Scheme) + } + if url.Fragment != "" { + return fmt.Errorf("%v:// URL must not have a fragment", url.Scheme) + } + switch url.Scheme { + case "ircs", "irc": + if url.Host == "" { + return fmt.Errorf("%v:// URL must have a host", url.Scheme) + } + if url.Path != "" { + return fmt.Errorf("%v:// URL must not have a path", url.Scheme) + } + case "irc+unix", "unix": + if url.Host != "" { + return fmt.Errorf("%v:// URL must not have a host", url.Scheme) + } + if url.Path == "" { + return fmt.Errorf("%v:// URL must have a path", url.Scheme) + } + default: + return fmt.Errorf("unknown URL scheme %q", url.Scheme) + } + + if record.GetName() == "" { + return fmt.Errorf("network name cannot be empty") + } + if strings.HasPrefix(record.GetName(), "-") { + // Can be mixed up with flags when sending commands to the service + return fmt.Errorf("network name cannot start with a dash character") + } + + for _, net := range u.networks { + if net.GetName() == record.GetName() && net.ID != record.ID { + return fmt.Errorf("a network with the name %q already exists", record.GetName()) + } + } + + return nil +} + +func (u *user) createNetwork(ctx context.Context, record *Network) (*network, error) { + if record.ID != 0 { + panic("tried creating an already-existing network") + } + + if err := u.checkNetwork(record); err != nil { + return nil, err + } + + if max := u.srv.Config().MaxUserNetworks; max >= 0 && len(u.networks) >= max { + return nil, fmt.Errorf("maximum number of networks reached") + } + + network := newNetwork(u, record, nil) + err := u.srv.db.StoreNetwork(ctx, u.ID, &network.Network) + if err != nil { + return nil, err + } + + u.addNetwork(network) + + idStr := fmt.Sprintf("%v", network.ID) + attrs := getNetworkAttrs(network) + u.forEachDownstream(func(dc *downstreamConn) { + if dc.caps["soju.im/bouncer-networks-notify"] { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "BOUNCER", + Params: []string{"NETWORK", idStr, attrs.String()}, + }) + } + }) + + return network, nil +} + +func (u *user) updateNetwork(ctx context.Context, record *Network) (*network, error) { + if record.ID == 0 { + panic("tried updating a new network") + } + + // If the realname is reset to the default, just wipe the per-network + // setting + if record.Realname == u.Realname { + record.Realname = "" + } + + if err := u.checkNetwork(record); err != nil { + return nil, err + } + + network := u.getNetworkByID(record.ID) + if network == nil { + panic("tried updating a non-existing network") + } + + if err := u.srv.db.StoreNetwork(ctx, u.ID, record); err != nil { + return nil, err + } + + // Most network changes require us to re-connect to the upstream server + + channels := make([]Channel, 0, network.channels.Len()) + for _, entry := range network.channels.innerMap { + ch := entry.value.(*Channel) + channels = append(channels, *ch) + } + + updatedNetwork := newNetwork(u, record, channels) + + // If we're currently connected, disconnect and perform the necessary + // bookkeeping + if network.conn != nil { + network.stop() + // Note: this will set network.conn to nil + u.handleUpstreamDisconnected(network.conn) + } + + // Patch downstream connections to use our fresh updated network + u.forEachDownstream(func(dc *downstreamConn) { + if dc.network != nil && dc.network == network { + dc.network = updatedNetwork + } + }) + + // We need to remove the network after patching downstream connections, + // otherwise they'll get closed + u.removeNetwork(network) + + // The filesystem message store needs to be notified whenever the network + // is renamed + fsMsgStore, isFS := u.msgStore.(*fsMessageStore) + if isFS && updatedNetwork.GetName() != network.GetName() { + if err := fsMsgStore.RenameNetwork(&network.Network, &updatedNetwork.Network); err != nil { + network.logger.Printf("failed to update FS message store network name to %q: %v", updatedNetwork.GetName(), err) + } + } + + // This will re-connect to the upstream server + u.addNetwork(updatedNetwork) + + // TODO: only broadcast attributes that have changed + idStr := fmt.Sprintf("%v", updatedNetwork.ID) + attrs := getNetworkAttrs(updatedNetwork) + u.forEachDownstream(func(dc *downstreamConn) { + if dc.caps["soju.im/bouncer-networks-notify"] { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "BOUNCER", + Params: []string{"NETWORK", idStr, attrs.String()}, + }) + } + }) + + return updatedNetwork, nil +} + +func (u *user) deleteNetwork(ctx context.Context, id int64) error { + network := u.getNetworkByID(id) + if network == nil { + panic("tried deleting a non-existing network") + } + + if err := u.srv.db.DeleteNetwork(ctx, network.ID); err != nil { + return err + } + + u.removeNetwork(network) + + idStr := fmt.Sprintf("%v", network.ID) + u.forEachDownstream(func(dc *downstreamConn) { + if dc.caps["soju.im/bouncer-networks-notify"] { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "BOUNCER", + Params: []string{"NETWORK", idStr, "*"}, + }) + } + }) + + return nil +} + +func (u *user) updateUser(ctx context.Context, record *User) error { + if u.ID != record.ID { + panic("ID mismatch when updating user") + } + + realnameUpdated := u.Realname != record.Realname + if err := u.srv.db.StoreUser(ctx, record); err != nil { + return fmt.Errorf("failed to update user %q: %v", u.Username, err) + } + u.User = *record + + if realnameUpdated { + // Re-connect to networks which use the default realname + var needUpdate []Network + for _, net := range u.networks { + if net.Realname == "" { + needUpdate = append(needUpdate, net.Network) + } + } + + var netErr error + for _, net := range needUpdate { + if _, err := u.updateNetwork(ctx, &net); err != nil { + netErr = err + } + } + if netErr != nil { + return netErr + } + } + + return nil +} + +func (u *user) stop() { + u.events <- eventStop{} + <-u.done +} + +func (u *user) hasPersistentMsgStore() bool { + if u.msgStore == nil { + return false + } + _, isMem := u.msgStore.(*memoryMessageStore) + return !isMem +} + +// localAddrForHost returns the local address to use when connecting to host. +// A nil address is returned when the OS should automatically pick one. +func (u *user) localTCPAddrForHost(ctx context.Context, host string) (*net.TCPAddr, error) { + upstreamUserIPs := u.srv.Config().UpstreamUserIPs + if len(upstreamUserIPs) == 0 { + return nil, nil + } + + ips, err := net.DefaultResolver.LookupIP(ctx, "ip", host) + if err != nil { + return nil, err + } + + wantIPv6 := false + for _, ip := range ips { + if ip.To4() == nil { + wantIPv6 = true + break + } + } + + var ipNet *net.IPNet + for _, in := range upstreamUserIPs { + if wantIPv6 == (in.IP.To4() == nil) { + ipNet = in + break + } + } + if ipNet == nil { + return nil, nil + } + + var ipInt big.Int + ipInt.SetBytes(ipNet.IP) + ipInt.Add(&ipInt, big.NewInt(u.ID+1)) + ip := net.IP(ipInt.Bytes()) + if !ipNet.Contains(ip) { + return nil, fmt.Errorf("IP network %v too small", ipNet) + } + + return &net.TCPAddr{IP: ip}, nil +} diff --git a/branches/origin-master/version.go b/branches/origin-master/version.go new file mode 100644 index 0000000..0e5d7a6 --- /dev/null +++ b/branches/origin-master/version.go @@ -0,0 +1,50 @@ +package suika + +import ( + "fmt" + "runtime/debug" + "strings" +) + +const ( + defaultVersion = "0.0.0" + defaultCommit = "HEAD" + defaultBuild = "0000-01-01:00:00+00:00" +) + +var ( + // Version is the tagged release version in the form .. + // following semantic versioning and is overwritten by the build system. + Version = defaultVersion + + // Commit is the commit sha of the build (normally from Git) and is overwritten + // by the build system. + Commit = defaultCommit + + // Build is the date and time of the build as an RFC3339 formatted string + // and is overwritten by the build system. + Build = defaultBuild +) + +// FullVersion display the full version and build +func FullVersion() string { + var sb strings.Builder + + isDefault := Version == defaultVersion && Commit == defaultCommit && Build == defaultBuild + + if !isDefault { + sb.WriteString(fmt.Sprintf("%s@%s %s", Version, Commit, Build)) + } + + if info, ok := debug.ReadBuildInfo(); ok { + if isDefault { + sb.WriteString(fmt.Sprintf(" %s", info.Main.Version)) + } + sb.WriteString(fmt.Sprintf(" %s", info.GoVersion)) + if info.Main.Sum != "" { + sb.WriteString(fmt.Sprintf(" %s", info.Main.Sum)) + } + } + + return sb.String() +} diff --git a/branches/origin/.gitignore b/branches/origin/.gitignore new file mode 100644 index 0000000..4aa22f5 --- /dev/null +++ b/branches/origin/.gitignore @@ -0,0 +1,5 @@ +vendor +/suika +/suikadb +/suika-znc-import +/suika.db diff --git a/branches/origin/LICENSE b/branches/origin/LICENSE new file mode 100644 index 0000000..0ad25db --- /dev/null +++ b/branches/origin/LICENSE @@ -0,0 +1,661 @@ + GNU AFFERO GENERAL PUBLIC LICENSE + Version 3, 19 November 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU Affero General Public License is a free, copyleft license for +software and other kinds of works, specifically designed to ensure +cooperation with the community in the case of network server software. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +our General Public Licenses are intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + Developers that use our General Public Licenses protect your rights +with two steps: (1) assert copyright on the software, and (2) offer +you this License which gives you legal permission to copy, distribute +and/or modify the software. + + A secondary benefit of defending all users' freedom is that +improvements made in alternate versions of the program, if they +receive widespread use, become available for other developers to +incorporate. Many developers of free software are heartened and +encouraged by the resulting cooperation. However, in the case of +software used on network servers, this result may fail to come about. +The GNU General Public License permits making a modified version and +letting the public access it on a server without ever releasing its +source code to the public. + + The GNU Affero General Public License is designed specifically to +ensure that, in such cases, the modified source code becomes available +to the community. It requires the operator of a network server to +provide the source code of the modified version running there to the +users of that server. Therefore, public use of a modified version, on +a publicly accessible server, gives the public access to the source +code of the modified version. + + An older license, called the Affero General Public License and +published by Affero, was designed to accomplish similar goals. This is +a different license, not a version of the Affero GPL, but Affero has +released a new version of the Affero GPL which permits relicensing under +this license. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU Affero General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. + + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Remote Network Interaction; Use with the GNU General Public License. + + Notwithstanding any other provision of this License, if you modify the +Program, your modified version must prominently offer all users +interacting with it remotely through a computer network (if your version +supports such interaction) an opportunity to receive the Corresponding +Source of your version by providing access to the Corresponding Source +from a network server at no charge, through some standard or customary +means of facilitating copying of software. This Corresponding Source +shall include the Corresponding Source for any work covered by version 3 +of the GNU General Public License that is incorporated pursuant to the +following paragraph. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the work with which it is combined will remain governed by version +3 of the GNU General Public License. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU Affero General Public License from time to time. Such new versions +will be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU Affero General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU Affero General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU Affero General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + + Copyright (C) + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published + by the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . + +Also add information on how to contact you by electronic and paper mail. + + If your software can interact with users remotely through a computer +network, you should also make sure that it provides a way for users to +get its source. For example, if your program is a web application, its +interface could display a "Source" link that leads users to an archive +of the code. There are many ways you could offer source, and different +solutions will be better for different programs; see section 13 for the +specific requirements. + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU AGPL, see +. diff --git a/branches/origin/Makefile b/branches/origin/Makefile new file mode 100644 index 0000000..a77bc97 --- /dev/null +++ b/branches/origin/Makefile @@ -0,0 +1,45 @@ +GO ?= go +RM ?= rm +GOFLAGS ?= -v -ldflags "-w -X `go list`.Version=${VERSION} -X `go list`.Commit=${COMMIT} -X `go list`.Build=${BUILD}" -mod=vendor +PREFIX ?= /usr/local +BINDIR ?= bin +MANDIR ?= share/man +MKDIR ?= mkdir +CP ?= cp +SYSCONFDIR ?= /etc +ASCIIDOCTOR ?= asciidoctor + +VERSION = `git describe --abbrev=0 --tags 2>/dev/null || echo "$VERSION"` +COMMIT = `git rev-parse --short HEAD || echo "$COMMIT"` +BRANCH = `git rev-parse --abbrev-ref HEAD` +BUILD = `git show -s --pretty=format:%cI` + +GOARCH ?= amd64 +GOOS ?= linux + +all: build + +build: vendor + ${GO} build ${GOFLAGS} ./cmd/suika + ${GO} build ${GOFLAGS} ./cmd/suikadb + ${GO} build ${GOFLAGS} ./cmd/suika-znc-import +clean: + ${RM} -f suika suikadb suika-znc-import +install: + ${MKDIR} -p ${DESTDIR}${PREFIX}/${BINDIR} + ${MKDIR} -p ${DESTDIR}${PREFIX}/${MANDIR}/man1 + ${MKDIR} -p ${DESTDIR}${PREFIX}/${MANDIR}/man5 + ${MKDIR} -p ${DESTDIR}${PREFIX}/${MANDIR}/man7 + ${MKDIR} -p ${DESTDIR}${SYSCONFDIR}/suika + ${MKDIR} -p ${DESTDIR}/var/lib/suika + ${CP} -f suika suikadb suika-znc-import ${DESTDIR}${PREFIX}/${BINDIR} + ${CP} -f doc/suika.1 ${DESTDIR}${PREFIX}/${MANDIR}/man1 + ${CP} -f doc/suikadb.1 ${DESTDIR}${PREFIX}/${MANDIR}/man1 + ${CP} -f doc/suika-znc-import.1 ${DESTDIR}/${MANDIR}/man1 + ${CP} -f doc/suika-config.5 ${DESTDIR}${PREFIX}/${MANDIR}/man5 + [ -f ${DESTDIR}${SYSCONFDIR}/suika/config ] || ${CP} -f config.in ${DESTDIR}${SYSCONFDIR}/suika/config +test: + go test +vendor: + go mod vendor +.PHONY: build clean install diff --git a/branches/origin/README.md b/branches/origin/README.md new file mode 100644 index 0000000..2171ba5 --- /dev/null +++ b/branches/origin/README.md @@ -0,0 +1,33 @@ +# suika + +[![Go Documentation](https://godocs.io/marisa.chaotic.ninja/suika?status.svg)](https://godocs.io/marisa.chaotic.ninja/suika) + +A user-friendly IRC bouncer. Hard-fork of the 0.3 series of [soju](https://soju.im), named after [Suika Ibuki](https://en.touhouwiki.net/wiki/Suika_Ibuki) from [Touhou 7.5: Immaterial and Missing Power](https://en.touhouwiki.net/wiki/Immaterial_and_Missing_Power) + +- Multi-user +- Support multiple clients for a single user, with proper backlog + synchronization +- Support connecting to multiple upstream servers via a single IRC connection + to the bouncer + +## Building and installing + +Dependencies: + +- Go +- BSD or GNU make + +For end users, a `Makefile` is provided: + + make + doas make install + +For development, you can use `go run ./cmd/suika` as usual. + +## License +AGPLv3, see [LICENSE](LICENSE). + +* Copyright (C) 2020 The soju Contributors +* Copyright (C) 2023-present Izuru Yakumo + +The code for `version.go` is stolen verbatim from one of [@prologic](https://git.mills.io/prologic)'s projects. It's probably under MIT diff --git a/branches/origin/bridge.go b/branches/origin/bridge.go new file mode 100644 index 0000000..7bb7aa8 --- /dev/null +++ b/branches/origin/bridge.go @@ -0,0 +1,116 @@ +package suika + +import ( + "context" + "fmt" + "strconv" + "strings" + + "gopkg.in/irc.v3" +) + +func forwardChannel(ctx context.Context, dc *downstreamConn, ch *upstreamChannel) { + if !ch.complete { + panic("Tried to forward a partial channel") + } + + // RPL_NOTOPIC shouldn't be sent on JOIN + if ch.Topic != "" { + sendTopic(dc, ch) + } + + if dc.caps["soju.im/read"] { + channelCM := ch.conn.network.casemap(ch.Name) + r, err := dc.srv.db.GetReadReceipt(ctx, ch.conn.network.ID, channelCM) + if err != nil { + dc.logger.Printf("failed to get the read receipt for %q: %v", ch.Name, err) + } else { + timestampStr := "*" + if r != nil { + timestampStr = fmt.Sprintf("timestamp=%s", formatServerTime(r.Timestamp)) + } + dc.SendMessage(&irc.Message{ + Prefix: dc.prefix(), + Command: "READ", + Params: []string{dc.marshalEntity(ch.conn.network, ch.Name), timestampStr}, + }) + } + } + + sendNames(dc, ch) +} + +func sendTopic(dc *downstreamConn, ch *upstreamChannel) { + downstreamName := dc.marshalEntity(ch.conn.network, ch.Name) + + if ch.Topic != "" { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_TOPIC, + Params: []string{dc.nick, downstreamName, ch.Topic}, + }) + if ch.TopicWho != nil { + topicWho := dc.marshalUserPrefix(ch.conn.network, ch.TopicWho) + topicTime := strconv.FormatInt(ch.TopicTime.Unix(), 10) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: rpl_topicwhotime, + Params: []string{dc.nick, downstreamName, topicWho.String(), topicTime}, + }) + } + } else { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_NOTOPIC, + Params: []string{dc.nick, downstreamName, "No topic is set"}, + }) + } +} + +func sendNames(dc *downstreamConn, ch *upstreamChannel) { + downstreamName := dc.marshalEntity(ch.conn.network, ch.Name) + + emptyNameReply := &irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_NAMREPLY, + Params: []string{dc.nick, string(ch.Status), downstreamName, ""}, + } + maxLength := maxMessageLength - len(emptyNameReply.String()) + + var buf strings.Builder + for _, entry := range ch.Members.innerMap { + nick := entry.originalKey + memberships := entry.value.(*memberships) + s := memberships.Format(dc) + dc.marshalEntity(ch.conn.network, nick) + + n := buf.Len() + 1 + len(s) + if buf.Len() != 0 && n > maxLength { + // There's not enough space for the next space + nick. + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_NAMREPLY, + Params: []string{dc.nick, string(ch.Status), downstreamName, buf.String()}, + }) + buf.Reset() + } + + if buf.Len() != 0 { + buf.WriteByte(' ') + } + buf.WriteString(s) + } + + if buf.Len() != 0 { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_NAMREPLY, + Params: []string{dc.nick, string(ch.Status), downstreamName, buf.String()}, + }) + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_ENDOFNAMES, + Params: []string{dc.nick, downstreamName, "End of /NAMES list"}, + }) +} diff --git a/branches/origin/certfp.go b/branches/origin/certfp.go new file mode 100644 index 0000000..9180f7c --- /dev/null +++ b/branches/origin/certfp.go @@ -0,0 +1,72 @@ +package suika + +import ( + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "math/big" + "time" +) + +func generateCertFP(keyType string, bits int) (privKeyBytes, certBytes []byte, err error) { + var ( + privKey crypto.PrivateKey + pubKey crypto.PublicKey + ) + switch keyType { + case "rsa": + key, err := rsa.GenerateKey(rand.Reader, bits) + if err != nil { + return nil, nil, err + } + privKey = key + pubKey = key.Public() + case "ecdsa": + key, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + if err != nil { + return nil, nil, err + } + privKey = key + pubKey = key.Public() + case "ed25519": + var err error + pubKey, privKey, err = ed25519.GenerateKey(rand.Reader) + if err != nil { + return nil, nil, err + } + } + + // Using PKCS#8 allows easier extension for new key types. + privKeyBytes, err = x509.MarshalPKCS8PrivateKey(privKey) + if err != nil { + return nil, nil, err + } + + notBefore := time.Now() + // Lets make a fair assumption nobody will use the same cert for more than 20 years... + notAfter := notBefore.Add(24 * time.Hour * 365 * 20) + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + return nil, nil, err + } + cert := &x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{CommonName: "suika auto-generated certificate"}, + NotBefore: notBefore, + NotAfter: notAfter, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + } + certBytes, err = x509.CreateCertificate(rand.Reader, cert, cert, pubKey, privKey) + if err != nil { + return nil, nil, err + } + + return privKeyBytes, certBytes, nil +} diff --git a/branches/origin/cmd/suika-znc-import/main.go b/branches/origin/cmd/suika-znc-import/main.go new file mode 100644 index 0000000..78fefc6 --- /dev/null +++ b/branches/origin/cmd/suika-znc-import/main.go @@ -0,0 +1,474 @@ +package main + +import ( + "bufio" + "context" + "flag" + "fmt" + "io" + "log" + "net/url" + "os" + "strings" + "unicode" + + "marisa.chaotic.ninja/suika" + "marisa.chaotic.ninja/suika/config" +) + +const usage = `usage: suika-znc-import [options...] + +Imports configuration from a ZNC file. Users and networks are merged if they +already exist in the suika database. ZNC settings overwrite existing suika +settings. + +Options: + + -help Show this help message + -config Path to suika config file + -user Limit import to username (may be specified multiple times) + -network Limit import to network (may be specified multiple times) +` + +func init() { + flag.Usage = func() { + fmt.Fprintf(flag.CommandLine.Output(), usage) + } +} + +func main() { + var configPath string + users := make(map[string]bool) + networks := make(map[string]bool) + flag.StringVar(&configPath, "config", "", "path to configuration file") + flag.Var((*stringSetFlag)(&users), "user", "") + flag.Var((*stringSetFlag)(&networks), "network", "") + flag.Parse() + + zncPath := flag.Arg(0) + if zncPath == "" { + flag.Usage() + os.Exit(1) + } + + var cfg *config.Server + if configPath != "" { + var err error + cfg, err = config.Load(configPath) + if err != nil { + log.Fatalf("failed to load config file: %v", err) + } + } else { + cfg = config.Defaults() + } + + ctx := context.Background() + + db, err := suika.OpenDB(cfg.SQLDriver, cfg.SQLSource) + if err != nil { + log.Fatalf("failed to open database: %v", err) + } + defer db.Close() + + f, err := os.Open(zncPath) + if err != nil { + log.Fatalf("failed to open ZNC configuration file: %v", err) + } + defer f.Close() + + zp := zncParser{bufio.NewReader(f), 1} + root, err := zp.sectionBody("", "") + if err != nil { + log.Fatalf("failed to parse %q: line %v: %v", zncPath, zp.line, err) + } + + l, err := db.ListUsers(ctx) + if err != nil { + log.Fatalf("failed to list users in DB: %v", err) + } + existingUsers := make(map[string]*suika.User, len(l)) + for i, u := range l { + existingUsers[u.Username] = &l[i] + } + + usersCreated := 0 + usersImported := 0 + networksImported := 0 + channelsImported := 0 + root.ForEach("User", func(section *zncSection) { + username := section.Name + if len(users) > 0 && !users[username] { + return + } + usersImported++ + + u, ok := existingUsers[username] + if ok { + log.Printf("user %q: updating existing user", username) + } else { + // "!!" is an invalid crypt format, thus disables password auth + u = &suika.User{Username: username, Password: "!!"} + usersCreated++ + log.Printf("user %q: creating new user", username) + } + + u.Admin = section.Values.Get("Admin") == "true" + + if err := db.StoreUser(ctx, u); err != nil { + log.Fatalf("failed to store user %q: %v", username, err) + } + userID := u.ID + + l, err := db.ListNetworks(ctx, userID) + if err != nil { + log.Fatalf("failed to list networks for user %q: %v", username, err) + } + existingNetworks := make(map[string]*suika.Network, len(l)) + for i, n := range l { + existingNetworks[n.GetName()] = &l[i] + } + + nick := section.Values.Get("Nick") + realname := section.Values.Get("RealName") + ident := section.Values.Get("Ident") + + section.ForEach("Network", func(section *zncSection) { + netName := section.Name + if len(networks) > 0 && !networks[netName] { + return + } + networksImported++ + + logPrefix := fmt.Sprintf("user %q: network %q: ", username, netName) + logger := log.New(os.Stderr, logPrefix, log.LstdFlags|log.Lmsgprefix) + + netNick := section.Values.Get("Nick") + if netNick == "" { + netNick = nick + } + netRealname := section.Values.Get("RealName") + if netRealname == "" { + netRealname = realname + } + netIdent := section.Values.Get("Ident") + if netIdent == "" { + netIdent = ident + } + + for _, name := range section.Values["LoadModule"] { + switch name { + case "sasl": + logger.Printf("warning: SASL credentials not imported") + case "nickserv": + logger.Printf("warning: NickServ credentials not imported") + case "perform": + logger.Printf("warning: \"perform\" plugin commands not imported") + } + } + + u, pass, err := importNetworkServer(section.Values.Get("Server")) + if err != nil { + logger.Fatalf("failed to import server %q: %v", section.Values.Get("Server"), err) + } + + n, ok := existingNetworks[netName] + if ok { + logger.Printf("updating existing network") + } else { + n = &suika.Network{Name: netName} + logger.Printf("creating new network") + } + + n.Addr = u.String() + n.Nick = netNick + n.Username = netIdent + n.Realname = netRealname + n.Pass = pass + n.Enabled = section.Values.Get("IRCConnectEnabled") != "false" + + if err := db.StoreNetwork(ctx, userID, n); err != nil { + logger.Fatalf("failed to store network: %v", err) + } + + l, err := db.ListChannels(ctx, n.ID) + if err != nil { + logger.Fatalf("failed to list channels: %v", err) + } + existingChannels := make(map[string]*suika.Channel, len(l)) + for i, ch := range l { + existingChannels[ch.Name] = &l[i] + } + + section.ForEach("Chan", func(section *zncSection) { + chName := section.Name + + if section.Values.Get("Disabled") == "true" { + logger.Printf("skipping import of disabled channel %q", chName) + return + } + + channelsImported++ + + ch, ok := existingChannels[chName] + if ok { + logger.Printf("channel %q: updating existing channel", chName) + } else { + ch = &suika.Channel{Name: chName} + logger.Printf("channel %q: creating new channel", chName) + } + + ch.Key = section.Values.Get("Key") + ch.Detached = section.Values.Get("Detached") == "true" + + if err := db.StoreChannel(ctx, n.ID, ch); err != nil { + logger.Printf("channel %q: failed to store channel: %v", chName, err) + } + }) + }) + }) + + if err := db.Close(); err != nil { + log.Printf("failed to close database: %v", err) + } + + if usersCreated > 0 { + log.Printf("warning: user passwords haven't been imported, please set them with `suikactl change-password `") + } + + log.Printf("imported %v users, %v networks and %v channels", usersImported, networksImported, channelsImported) +} + +func importNetworkServer(s string) (u *url.URL, pass string, err error) { + parts := strings.Fields(s) + if len(parts) < 2 { + return nil, "", fmt.Errorf("expected space-separated host and port") + } + + scheme := "irc" + host := parts[0] + port := parts[1] + if strings.HasPrefix(port, "+") { + port = port[1:] + scheme = "ircs" + } + + if len(parts) > 2 { + pass = parts[2] + } + + u = &url.URL{ + Scheme: scheme, + Host: host + ":" + port, + } + return u, pass, nil +} + +type zncSection struct { + Type string + Name string + Values zncValues + Children []zncSection +} + +func (s *zncSection) ForEach(typ string, f func(*zncSection)) { + for _, section := range s.Children { + if section.Type == typ { + f(§ion) + } + } +} + +type zncValues map[string][]string + +func (zv zncValues) Get(k string) string { + if len(zv[k]) == 0 { + return "" + } + return zv[k][0] +} + +type zncParser struct { + br *bufio.Reader + line int +} + +func (zp *zncParser) readByte() (byte, error) { + b, err := zp.br.ReadByte() + if b == '\n' { + zp.line++ + } + return b, err +} + +func (zp *zncParser) readRune() (rune, int, error) { + r, n, err := zp.br.ReadRune() + if r == '\n' { + zp.line++ + } + return r, n, err +} + +func (zp *zncParser) sectionBody(typ, name string) (*zncSection, error) { + section := &zncSection{Type: typ, Name: name, Values: make(zncValues)} + +Loop: + for { + if err := zp.skipSpace(); err != nil { + return nil, err + } + + b, err := zp.br.Peek(2) + if err == io.EOF { + break + } else if err != nil { + return nil, err + } + + switch b[0] { + case '<': + if b[1] == '/' { + break Loop + } else { + childType, childName, err := zp.sectionHeader() + if err != nil { + return nil, err + } + child, err := zp.sectionBody(childType, childName) + if err != nil { + return nil, err + } + if footerType, err := zp.sectionFooter(); err != nil { + return nil, err + } else if footerType != childType { + return nil, fmt.Errorf("invalid section footer: expected type %q, got %q", childType, footerType) + } + section.Children = append(section.Children, *child) + } + case '/': + if b[1] == '/' { + if err := zp.skipComment(); err != nil { + return nil, err + } + break + } + fallthrough + default: + k, v, err := zp.keyValuePair() + if err != nil { + return nil, err + } + section.Values[k] = append(section.Values[k], v) + } + } + + return section, nil +} + +func (zp *zncParser) skipSpace() error { + for { + r, _, err := zp.readRune() + if err == io.EOF { + return nil + } else if err != nil { + return err + } + + if !unicode.IsSpace(r) { + zp.br.UnreadRune() + return nil + } + } +} + +func (zp *zncParser) skipComment() error { + if err := zp.expectRune('/'); err != nil { + return err + } + if err := zp.expectRune('/'); err != nil { + return err + } + + for { + b, err := zp.readByte() + if err == io.EOF { + return nil + } else if err != nil { + return err + } + + if b == '\n' { + return nil + } + } +} + +func (zp *zncParser) sectionHeader() (string, string, error) { + if err := zp.expectRune('<'); err != nil { + return "", "", err + } + typ, err := zp.readWord(' ') + if err != nil { + return "", "", err + } + name, err := zp.readWord('>') + return typ, name, err +} + +func (zp *zncParser) sectionFooter() (string, error) { + if err := zp.expectRune('<'); err != nil { + return "", err + } + if err := zp.expectRune('/'); err != nil { + return "", err + } + return zp.readWord('>') +} + +func (zp *zncParser) keyValuePair() (string, string, error) { + k, err := zp.readWord('=') + if err != nil { + return "", "", err + } + v, err := zp.readWord('\n') + return strings.TrimSpace(k), strings.TrimSpace(v), err +} + +func (zp *zncParser) expectRune(expected rune) error { + r, _, err := zp.readRune() + if err != nil { + return err + } else if r != expected { + return fmt.Errorf("expected %q, got %q", expected, r) + } + return nil +} + +func (zp *zncParser) readWord(delim byte) (string, error) { + var sb strings.Builder + for { + b, err := zp.readByte() + if err != nil { + return "", err + } + + if b == delim { + return sb.String(), nil + } + if b == '\n' { + return "", fmt.Errorf("expected %q before newline", delim) + } + + sb.WriteByte(b) + } +} + +type stringSetFlag map[string]bool + +func (v *stringSetFlag) String() string { + return fmt.Sprint(map[string]bool(*v)) +} + +func (v *stringSetFlag) Set(s string) error { + (*v)[s] = true + return nil +} diff --git a/branches/origin/cmd/suika/main.go b/branches/origin/cmd/suika/main.go new file mode 100644 index 0000000..9661859 --- /dev/null +++ b/branches/origin/cmd/suika/main.go @@ -0,0 +1,229 @@ +package main + +import ( + "context" + "crypto/tls" + "flag" + "fmt" + "log" + "net" + "net/url" + "os" + "os/signal" + "strings" + "sync/atomic" + "syscall" + "time" + + "marisa.chaotic.ninja/suika" + "marisa.chaotic.ninja/suika/config" +) + +// TCP keep-alive interval for downstream TCP connections +const downstreamKeepAlive = 1 * time.Hour + +type stringSliceFlag []string + +func (v *stringSliceFlag) String() string { + return fmt.Sprint([]string(*v)) +} + +func (v *stringSliceFlag) Set(s string) error { + *v = append(*v, s) + return nil +} + +func bumpOpenedFileLimit() error { + var rlimit syscall.Rlimit + if err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rlimit); err != nil { + return fmt.Errorf("failed to get RLIMIT_NOFILE: %v", err) + } + rlimit.Cur = rlimit.Max + if err := syscall.Setrlimit(syscall.RLIMIT_NOFILE, &rlimit); err != nil { + return fmt.Errorf("failed to set RLIMIT_NOFILE: %v", err) + } + return nil +} + +var ( + configPath string + debug bool + + tlsCert atomic.Value // *tls.Certificate +) + +func loadConfig() (*config.Server, *suika.Config, error) { + var raw *config.Server + if configPath != "" { + var err error + raw, err = config.Load(configPath) + if err != nil { + return nil, nil, fmt.Errorf("failed to load config file: %v", err) + } + } else { + raw = config.Defaults() + } + + var motd string + if raw.MOTDPath != "" { + b, err := os.ReadFile(raw.MOTDPath) + if err != nil { + return nil, nil, fmt.Errorf("failed to load MOTD: %v", err) + } + motd = strings.TrimSuffix(string(b), "\n") + } + + if raw.TLS != nil { + cert, err := tls.LoadX509KeyPair(raw.TLS.CertPath, raw.TLS.KeyPath) + if err != nil { + return nil, nil, fmt.Errorf("failed to load TLS certificate and key: %v", err) + } + tlsCert.Store(&cert) + } + + cfg := &suika.Config{ + Hostname: raw.Hostname, + Title: raw.Title, + LogPath: raw.LogPath, + MaxUserNetworks: raw.MaxUserNetworks, + MultiUpstream: raw.MultiUpstream, + UpstreamUserIPs: raw.UpstreamUserIPs, + MOTD: motd, + } + return raw, cfg, nil +} + +func main() { + var listen []string + flag.Var((*stringSliceFlag)(&listen), "listen", "listening address") + flag.StringVar(&configPath, "config", "", "path to configuration file") + flag.BoolVar(&debug, "debug", false, "enable debug logging") + flag.Parse() + + cfg, serverCfg, err := loadConfig() + if err != nil { + log.Fatal(err) + } + + cfg.Listen = append(cfg.Listen, listen...) + if len(cfg.Listen) == 0 { + cfg.Listen = []string{":6667"} + } + + if err := bumpOpenedFileLimit(); err != nil { + log.Printf("failed to bump max number of opened files: %v", err) + } + + db, err := suika.OpenDB(cfg.SQLDriver, cfg.SQLSource) + if err != nil { + log.Fatalf("failed to open database: %v", err) + } + + var tlsCfg *tls.Config + if cfg.TLS != nil { + tlsCfg = &tls.Config{ + GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) { + return tlsCert.Load().(*tls.Certificate), nil + }, + } + } + + srv := suika.NewServer(db) + srv.SetConfig(serverCfg) + srv.Logger = suika.NewLogger(log.Writer(), debug) + + for _, listen := range cfg.Listen { + listen := listen // copy + listenURI := listen + if !strings.Contains(listenURI, ":/") { + // This is a raw domain name, make it an URL with an empty scheme + listenURI = "//" + listenURI + } + u, err := url.Parse(listenURI) + if err != nil { + log.Fatalf("failed to parse listen URI %q: %v", listen, err) + } + + switch u.Scheme { + case "ircs", "": + if tlsCfg == nil { + log.Fatalf("failed to listen on %q: missing TLS configuration", listen) + } + host := u.Host + if _, _, err := net.SplitHostPort(host); err != nil { + host = host + ":6697" + } + ircsTLSCfg := tlsCfg.Clone() + ircsTLSCfg.NextProtos = []string{"irc"} + lc := net.ListenConfig{ + KeepAlive: downstreamKeepAlive, + } + l, err := lc.Listen(context.Background(), "tcp", host) + if err != nil { + log.Fatalf("failed to start TLS listener on %q: %v", listen, err) + } + ln := tls.NewListener(l, ircsTLSCfg) + go func() { + if err := srv.Serve(ln); err != nil { + log.Printf("serving %q: %v", listen, err) + } + }() + case "irc": + host := u.Host + if _, _, err := net.SplitHostPort(host); err != nil { + host = host + ":6667" + } + lc := net.ListenConfig{ + KeepAlive: downstreamKeepAlive, + } + ln, err := lc.Listen(context.Background(), "tcp", host) + if err != nil { + log.Fatalf("failed to start listener on %q: %v", listen, err) + } + go func() { + if err := srv.Serve(ln); err != nil { + log.Printf("serving %q: %v", listen, err) + } + }() + case "unix": + ln, err := net.Listen("unix", u.Path) + if err != nil { + log.Fatalf("failed to start listener on %q: %v", listen, err) + } + go func() { + if err := srv.Serve(ln); err != nil { + log.Printf("serving %q: %v", listen, err) + } + }() + default: + log.Fatalf("failed to listen on %q: unsupported scheme", listen) + } + + log.Printf("starting suika version %v\n", suika.FullVersion()) + log.Printf("server listening on %q", listen) + } + + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP) + + if err := srv.Start(); err != nil { + log.Fatal(err) + } + + for sig := range sigCh { + switch sig { + case syscall.SIGHUP: + log.Print("reloading configuration") + _, serverCfg, err := loadConfig() + if err != nil { + log.Printf("failed to reloading configuration: %v", err) + } else { + srv.SetConfig(serverCfg) + } + case syscall.SIGINT, syscall.SIGTERM: + log.Print("shutting down server") + srv.Shutdown() + return + } + } +} diff --git a/branches/origin/cmd/suikadb/main.go b/branches/origin/cmd/suikadb/main.go new file mode 100644 index 0000000..4d54481 --- /dev/null +++ b/branches/origin/cmd/suikadb/main.go @@ -0,0 +1,150 @@ +package main + +import ( + "bufio" + "context" + "flag" + "fmt" + "io" + "log" + "os" + + "marisa.chaotic.ninja/suika" + "marisa.chaotic.ninja/suika/config" + "golang.org/x/crypto/bcrypt" + "golang.org/x/term" +) + +const usage = `usage: suikadb [-config path] [options...] + + create-user [-admin] Create a new user + change-password Change password for a user + help Show this help message +` + +func init() { + flag.Usage = func() { + fmt.Fprintf(flag.CommandLine.Output(), usage) + } +} + +func main() { + var configPath string + flag.StringVar(&configPath, "config", "", "path to configuration file") + flag.Parse() + + var cfg *config.Server + if configPath != "" { + var err error + cfg, err = config.Load(configPath) + if err != nil { + log.Fatalf("failed to load config file: %v", err) + } + } else { + cfg = config.Defaults() + } + + db, err := suika.OpenDB(cfg.SQLDriver, cfg.SQLSource) + if err != nil { + log.Fatalf("failed to open database: %v", err) + } + + ctx := context.Background() + + switch cmd := flag.Arg(0); cmd { + case "create-user": + username := flag.Arg(1) + if username == "" { + flag.Usage() + os.Exit(1) + } + + fs := flag.NewFlagSet("", flag.ExitOnError) + admin := fs.Bool("admin", false, "make the new user admin") + fs.Parse(flag.Args()[2:]) + + password, err := readPassword() + if err != nil { + log.Fatalf("failed to read password: %v", err) + } + + hashed, err := bcrypt.GenerateFromPassword(password, bcrypt.DefaultCost) + if err != nil { + log.Fatalf("failed to hash password: %v", err) + } + + user := suika.User{ + Username: username, + Password: string(hashed), + Admin: *admin, + } + if err := db.StoreUser(ctx, &user); err != nil { + log.Fatalf("failed to create user: %v", err) + } + case "change-password": + username := flag.Arg(1) + if username == "" { + flag.Usage() + os.Exit(1) + } + + user, err := db.GetUser(ctx, username) + if err != nil { + log.Fatalf("failed to get user: %v", err) + } + + password, err := readPassword() + if err != nil { + log.Fatalf("failed to read password: %v", err) + } + + hashed, err := bcrypt.GenerateFromPassword(password, bcrypt.DefaultCost) + if err != nil { + log.Fatalf("failed to hash password: %v", err) + } + + user.Password = string(hashed) + if err := db.StoreUser(ctx, user); err != nil { + log.Fatalf("failed to update password: %v", err) + } + case "version": + fmt.Printf("%v\n", suika.FullVersion()) + default: + flag.Usage() + if cmd != "help" { + os.Exit(1) + } + } +} + +func readPassword() ([]byte, error) { + var password []byte + var err error + fd := int(os.Stdin.Fd()) + + if term.IsTerminal(fd) { + fmt.Printf("Password: ") + password, err = term.ReadPassword(int(os.Stdin.Fd())) + if err != nil { + return nil, err + } + fmt.Printf("\n") + } else { + fmt.Fprintf(os.Stderr, "Warning: Reading password from stdin.\n") + // TODO: the buffering messes up repeated calls to readPassword + scanner := bufio.NewScanner(os.Stdin) + if !scanner.Scan() { + if err := scanner.Err(); err != nil { + return nil, err + } + return nil, io.ErrUnexpectedEOF + } + password = scanner.Bytes() + + if len(password) == 0 { + return nil, fmt.Errorf("zero length password") + } + } + + return password, nil +} diff --git a/branches/origin/config.in b/branches/origin/config.in new file mode 100644 index 0000000..0b1ed4d --- /dev/null +++ b/branches/origin/config.in @@ -0,0 +1,2 @@ +db sqlite3 /var/lib/suika/main.db +log fs /var/lib/suika/logs/ diff --git a/branches/origin/config/config.go b/branches/origin/config/config.go new file mode 100644 index 0000000..94d4584 --- /dev/null +++ b/branches/origin/config/config.go @@ -0,0 +1,142 @@ +package config + +import ( + "fmt" + "net" + "os" + "strconv" + + "git.sr.ht/~emersion/go-scfg" +) + +type TLS struct { + CertPath, KeyPath string +} + +type Server struct { + Listen []string + TLS *TLS + Hostname string + Title string + MOTDPath string + + SQLDriver string + SQLSource string + LogPath string + + MaxUserNetworks int + MultiUpstream bool + UpstreamUserIPs []*net.IPNet +} + +func Defaults() *Server { + hostname, err := os.Hostname() + if err != nil { + hostname = "localhost" + } + return &Server{ + Hostname: hostname, + SQLDriver: "sqlite3", + SQLSource: "suika.db", + MaxUserNetworks: -1, + MultiUpstream: true, + } +} + +func Load(path string) (*Server, error) { + cfg, err := scfg.Load(path) + if err != nil { + return nil, err + } + return parse(cfg) +} + +func parse(cfg scfg.Block) (*Server, error) { + srv := Defaults() + for _, d := range cfg { + switch d.Name { + case "listen": + var uri string + if err := d.ParseParams(&uri); err != nil { + return nil, err + } + srv.Listen = append(srv.Listen, uri) + case "hostname": + if err := d.ParseParams(&srv.Hostname); err != nil { + return nil, err + } + case "title": + if err := d.ParseParams(&srv.Title); err != nil { + return nil, err + } + case "motd": + if err := d.ParseParams(&srv.MOTDPath); err != nil { + return nil, err + } + case "tls": + tls := &TLS{} + if err := d.ParseParams(&tls.CertPath, &tls.KeyPath); err != nil { + return nil, err + } + srv.TLS = tls + case "db": + if err := d.ParseParams(&srv.SQLDriver, &srv.SQLSource); err != nil { + return nil, err + } + case "log": + var driver string + if err := d.ParseParams(&driver, &srv.LogPath); err != nil { + return nil, err + } + if driver != "fs" { + return nil, fmt.Errorf("directive %q: unknown driver %q", d.Name, driver) + } + case "max-user-networks": + var max string + if err := d.ParseParams(&max); err != nil { + return nil, err + } + var err error + if srv.MaxUserNetworks, err = strconv.Atoi(max); err != nil { + return nil, fmt.Errorf("directive %q: %v", d.Name, err) + } + case "multi-upstream-mode": + var str string + if err := d.ParseParams(&str); err != nil { + return nil, err + } + v, err := strconv.ParseBool(str) + if err != nil { + return nil, fmt.Errorf("directive %q: %v", d.Name, err) + } + srv.MultiUpstream = v + case "upstream-user-ip": + if len(srv.UpstreamUserIPs) > 0 { + return nil, fmt.Errorf("directive %q: can only be specified once", d.Name) + } + var hasIPv4, hasIPv6 bool + for _, s := range d.Params { + _, n, err := net.ParseCIDR(s) + if err != nil { + return nil, fmt.Errorf("directive %q: failed to parse CIDR: %v", d.Name, err) + } + if n.IP.To4() == nil { + if hasIPv6 { + return nil, fmt.Errorf("directive %q: found two IPv6 CIDRs", d.Name) + } + hasIPv6 = true + } else { + if hasIPv4 { + return nil, fmt.Errorf("directive %q: found two IPv4 CIDRs", d.Name) + } + hasIPv4 = true + } + srv.UpstreamUserIPs = append(srv.UpstreamUserIPs, n) + } + default: + return nil, fmt.Errorf("unknown directive %q", d.Name) + } + } + + return srv, nil +} diff --git a/branches/origin/conn.go b/branches/origin/conn.go new file mode 100644 index 0000000..bb95e25 --- /dev/null +++ b/branches/origin/conn.go @@ -0,0 +1,180 @@ +package suika + +import ( + "context" + "fmt" + "io" + "net" + "sync" + "time" + + "golang.org/x/time/rate" + "gopkg.in/irc.v3" +) + +// ircConn is a generic IRC connection. It's similar to net.Conn but focuses on +// reading and writing IRC messages. +type ircConn interface { + ReadMessage() (*irc.Message, error) + WriteMessage(*irc.Message) error + Close() error + SetReadDeadline(time.Time) error + SetWriteDeadline(time.Time) error + RemoteAddr() net.Addr + LocalAddr() net.Addr +} + +func newNetIRCConn(c net.Conn) ircConn { + type netConn net.Conn + return struct { + *irc.Conn + netConn + }{irc.NewConn(c), c} +} + +type connOptions struct { + Logger Logger + RateLimitDelay time.Duration + RateLimitBurst int +} + +type conn struct { + conn ircConn + srv *Server + logger Logger + + lock sync.Mutex + outgoing chan<- *irc.Message + closed bool + closedCh chan struct{} +} + +func newConn(srv *Server, ic ircConn, options *connOptions) *conn { + outgoing := make(chan *irc.Message, 64) + c := &conn{ + conn: ic, + srv: srv, + outgoing: outgoing, + logger: options.Logger, + closedCh: make(chan struct{}), + } + + go func() { + ctx, cancel := c.NewContext(context.Background()) + defer cancel() + + rl := rate.NewLimiter(rate.Every(options.RateLimitDelay), options.RateLimitBurst) + for msg := range outgoing { + if err := rl.Wait(ctx); err != nil { + break + } + + c.logger.Debugf("sent: %v", msg) + c.conn.SetWriteDeadline(time.Now().Add(writeTimeout)) + if err := c.conn.WriteMessage(msg); err != nil { + c.logger.Printf("failed to write message: %v", err) + break + } + } + if err := c.conn.Close(); err != nil && !isErrClosed(err) { + c.logger.Printf("failed to close connection: %v", err) + } else { + c.logger.Debugf("connection closed") + } + // Drain the outgoing channel to prevent SendMessage from blocking + for range outgoing { + // This space is intentionally left blank + } + }() + + c.logger.Debugf("new connection") + return c +} + +func (c *conn) isClosed() bool { + c.lock.Lock() + defer c.lock.Unlock() + return c.closed +} + +// Close closes the connection. It is safe to call from any goroutine. +func (c *conn) Close() error { + c.lock.Lock() + defer c.lock.Unlock() + + if c.closed { + return fmt.Errorf("connection already closed") + } + + err := c.conn.Close() + c.closed = true + close(c.outgoing) + close(c.closedCh) + return err +} + +func (c *conn) ReadMessage() (*irc.Message, error) { + msg, err := c.conn.ReadMessage() + if isErrClosed(err) { + return nil, io.EOF + } else if err != nil { + return nil, err + } + + c.logger.Debugf("received: %v", msg) + return msg, nil +} + +// SendMessage queues a new outgoing message. It is safe to call from any +// goroutine. +// +// If the connection is closed before the message is sent, SendMessage silently +// drops the message. +func (c *conn) SendMessage(ctx context.Context, msg *irc.Message) { + c.lock.Lock() + defer c.lock.Unlock() + + if c.closed { + return + } + + select { + case c.outgoing <- msg: + // Success + case <-ctx.Done(): + c.logger.Printf("failed to send message: %v", ctx.Err()) + } +} + +func (c *conn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +func (c *conn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +// NewContext returns a copy of the parent context with a new Done channel. The +// returned context's Done channel is closed when the connection is closed, +// when the returned cancel function is called, or when the parent context's +// Done channel is closed, whichever happens first. +// +// Canceling this context releases resources associated with it, so code should +// call cancel as soon as the operations running in this Context complete. +func (c *conn) NewContext(parent context.Context) (context.Context, context.CancelFunc) { + ctx, cancel := context.WithCancel(parent) + + go func() { + defer cancel() + + select { + case <-ctx.Done(): + // The parent context has been cancelled, or the caller has called + // cancel() + case <-c.closedCh: + // The connection has been closed + } + }() + + return ctx, cancel +} diff --git a/branches/origin/contrib/casemap-logs.sh b/branches/origin/contrib/casemap-logs.sh new file mode 100644 index 0000000..709d581 --- /dev/null +++ b/branches/origin/contrib/casemap-logs.sh @@ -0,0 +1,42 @@ +#!/bin/sh -eu + +# Converts a log dir to its case-mapped form. +# +# suika needs to be stopped for this script to work properly. The script may +# re-order messages that happened within the same second interval if merging +# two daily log files is necessary. +# +# usage: casemap-logs.sh + +root="$1" + +for net_dir in "$root"/*/*; do + for chan in $(ls "$net_dir"); do + cm_chan="$(echo $chan | tr '[:upper:]' '[:lower:]')" + if [ "$chan" = "$cm_chan" ]; then + continue + fi + + if ! [ -d "$net_dir/$cm_chan" ]; then + echo >&2 "Moving case-mapped channel dir: '$net_dir/$chan' -> '$cm_chan'" + mv "$net_dir/$chan" "$net_dir/$cm_chan" + continue + fi + + echo "Merging case-mapped channel dir: '$net_dir/$chan' -> '$cm_chan'" + for day in $(ls "$net_dir/$chan"); do + if ! [ -e "$net_dir/$cm_chan/$day" ]; then + echo >&2 " Moving log file: '$day'" + mv "$net_dir/$chan/$day" "$net_dir/$cm_chan/$day" + continue + fi + + echo >&2 " Merging log file: '$day'" + sort "$net_dir/$chan/$day" "$net_dir/$cm_chan/$day" >"$net_dir/$cm_chan/$day.new" + mv "$net_dir/$cm_chan/$day.new" "$net_dir/$cm_chan/$day" + rm "$net_dir/$chan/$day" + done + + rmdir "$net_dir/$chan" + done +done diff --git a/branches/origin/contrib/clients.md b/branches/origin/contrib/clients.md new file mode 100644 index 0000000..1f8881d --- /dev/null +++ b/branches/origin/contrib/clients.md @@ -0,0 +1,100 @@ +# Clients + +This page describes how to configure IRC clients to better integrate with soju. + +Also see the [IRCv3 support tables] for a more general list of clients. + +# catgirl + +catgirl doesn't properly implement cap-3.2, so many capabilities will be +disabled. catgirl developers have publicly stated that supporting bouncers such +as soju is a non-goal. + +# [Emacs] + +There are two clients provided with Emacs. They require some setup to work +properly. + +## Erc + +You need to explicitly set the username, which is the defcustom +`erc-email-userid`. + +```elisp +(setq erc-email-userid "/irc.libera.chat") ;; Example with Libera.Chat +(defun run-erc () + (interactive) + (erc-tls :server "" + :port 6697 + :nick "" + :password "")) +``` + +Then run `M-x run-erc`. + +## Rcirc + +The only thing needed here is the general config: + +```elisp +(setq rcirc-server-alist + '(("" + :port 6697 + :encryption tls + :nick "" + :user-name "/irc.libera.chat" ;; Example with Libera.Chat + :password ""))) +``` + +Then run `M-x irc`. + +# [gamja] + +gamja has been designed together with soju, so should have excellent +integration. gamja supports many IRCv3 features including chat history. +gamja also provides UI to manage soju networks via the +`soju.im/bouncer-networks` extension. + +# [goguma] + +Much like gamja, goguma has been designed together with soju, so should have +excellent integration. goguma supports many IRCv3 features including chat +history. goguma should seamlessly connect to all networks configured in soju via +the `soju.im/bouncer-networks` extension. + +# [Hexchat] + +Hexchat has support for a small set of IRCv3 capabilities. To prevent +automatically reconnecting to channels parted from soju, and prevent buffering +outgoing messages: + + /set irc_reconnect_rejoin off + /set net_throttle off + +# [senpai] + +senpai is being developed with soju in mind, so should have excellent +integration. senpai supports many IRCv3 features including chat history. + +# [Weechat] + +A [Weechat script] is available to provide better integration with soju. +The script will automatically connect to all of your networks once a +single connection to soju is set up in Weechat. + +On WeeChat 3.2-, no IRCv3 capabilities are enabled by default. To enable them: + + /set irc.server_default.capabilities account-notify,away-notify,cap-notify,chghost,extended-join,invite-notify,multi-prefix,server-time,userhost-in-names + /save + /reconnect -all + +See `/help cap` for more information. + +[IRCv3 support tables]: https://ircv3.net/software/clients +[gamja]: https://sr.ht/~emersion/gamja/ +[goguma]: https://sr.ht/~emersion/goguma/ +[senpai]: https://sr.ht/~taiite/senpai/ +[Weechat]: https://weechat.org/ +[Weechat script]: https://github.com/weechat/scripts/blob/master/python/soju.py +[Hexchat]: https://hexchat.github.io/ +[Emacs]: https://www.gnu.org/software/emacs/ diff --git a/branches/origin/db.go b/branches/origin/db.go new file mode 100644 index 0000000..95ac964 --- /dev/null +++ b/branches/origin/db.go @@ -0,0 +1,184 @@ +package suika + +import ( + "context" + "fmt" + "net/url" + "strings" + "time" +) + +type Database interface { + Close() error + Stats(ctx context.Context) (*DatabaseStats, error) + + ListUsers(ctx context.Context) ([]User, error) + GetUser(ctx context.Context, username string) (*User, error) + StoreUser(ctx context.Context, user *User) error + DeleteUser(ctx context.Context, id int64) error + + ListNetworks(ctx context.Context, userID int64) ([]Network, error) + StoreNetwork(ctx context.Context, userID int64, network *Network) error + DeleteNetwork(ctx context.Context, id int64) error + ListChannels(ctx context.Context, networkID int64) ([]Channel, error) + StoreChannel(ctx context.Context, networKID int64, ch *Channel) error + DeleteChannel(ctx context.Context, id int64) error + + ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error) + StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error + + GetReadReceipt(ctx context.Context, networkID int64, name string) (*ReadReceipt, error) + StoreReadReceipt(ctx context.Context, networkID int64, receipt *ReadReceipt) error +} + +func OpenDB(driver, source string) (Database, error) { + switch driver { + case "sqlite3": + return OpenSqliteDB(source) + case "postgres": + return OpenPostgresDB(source) + default: + return nil, fmt.Errorf("unsupported database driver: %q", driver) + } +} + +type DatabaseStats struct { + Users int64 + Networks int64 + Channels int64 +} + +type User struct { + ID int64 + Username string + Password string // hashed + Realname string + Admin bool +} + +type SASL struct { + Mechanism string + + Plain struct { + Username string + Password string + } + + // TLS client certificate authentication. + External struct { + // X.509 certificate in DER form. + CertBlob []byte + // PKCS#8 private key in DER form. + PrivKeyBlob []byte + } +} + +type Network struct { + ID int64 + Name string + Addr string + Nick string + Username string + Realname string + Pass string + ConnectCommands []string + SASL SASL + Enabled bool +} + +func (net *Network) GetName() string { + if net.Name != "" { + return net.Name + } + return net.Addr +} + +func (net *Network) URL() (*url.URL, error) { + s := net.Addr + if !strings.Contains(s, "://") { + // This is a raw domain name, make it an URL with the default scheme + s = "ircs://" + s + } + + u, err := url.Parse(s) + if err != nil { + return nil, fmt.Errorf("failed to parse upstream server URL: %v", err) + } + + return u, nil +} + +func GetNick(user *User, net *Network) string { + if net.Nick != "" { + return net.Nick + } + return user.Username +} + +func GetUsername(user *User, net *Network) string { + if net.Username != "" { + return net.Username + } + return GetNick(user, net) +} + +func GetRealname(user *User, net *Network) string { + if net.Realname != "" { + return net.Realname + } + if user.Realname != "" { + return user.Realname + } + return GetNick(user, net) +} + +type MessageFilter int + +const ( + // TODO: use customizable user defaults for FilterDefault + FilterDefault MessageFilter = iota + FilterNone + FilterHighlight + FilterMessage +) + +func parseFilter(filter string) (MessageFilter, error) { + switch filter { + case "default": + return FilterDefault, nil + case "none": + return FilterNone, nil + case "highlight": + return FilterHighlight, nil + case "message": + return FilterMessage, nil + } + return 0, fmt.Errorf("unknown filter: %q", filter) +} + +type Channel struct { + ID int64 + Name string + Key string + + Detached bool + DetachedInternalMsgID string + + RelayDetached MessageFilter + ReattachOn MessageFilter + DetachAfter time.Duration + DetachOn MessageFilter +} + +type DeliveryReceipt struct { + ID int64 + Target string // channel or nick + Client string + InternalMsgID string +} + +type ReadReceipt struct { + ID int64 + Target string // channel or nick + Timestamp time.Time +} diff --git a/branches/origin/db_postgres.go b/branches/origin/db_postgres.go new file mode 100644 index 0000000..0849338 --- /dev/null +++ b/branches/origin/db_postgres.go @@ -0,0 +1,494 @@ +package suika + +import ( + "context" + "database/sql" + _ "embed" + "errors" + "fmt" + "math" + "strings" + "time" + + _ "github.com/lib/pq" +) + +const postgresQueryTimeout = 5 * time.Second + +const postgresConfigSchema = ` +CREATE TABLE IF NOT EXISTS "Config" ( + id SMALLINT PRIMARY KEY, + version INTEGER NOT NULL, + CHECK(id = 1) +); +` +//go:embed suika_psql_schema.sql +var postgresSchema string + +var postgresMigrations = []string{ + "", // migration #0 is reserved for schema initialization + `ALTER TABLE "Network" ALTER COLUMN nick DROP NOT NULL`, + ` + CREATE TYPE sasl_mechanism AS ENUM ('PLAIN', 'EXTERNAL'); + ALTER TABLE "Network" + ALTER COLUMN sasl_mechanism + TYPE sasl_mechanism + USING sasl_mechanism::sasl_mechanism; + `, + ` + CREATE TABLE IF NOT EXISTS "ReadReceipt" ( + id SERIAL PRIMARY KEY, + network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE, + target VARCHAR(255) NOT NULL, + timestamp TIMESTAMP WITH TIME ZONE NOT NULL, + UNIQUE(network, target) + ); + `, +} + +type PostgresDB struct { + db *sql.DB +} + +func OpenPostgresDB(source string) (Database, error) { + sqlPostgresDB, err := sql.Open("postgres", source) + if err != nil { + return nil, err + } + + db := &PostgresDB{db: sqlPostgresDB} + if err := db.upgrade(); err != nil { + sqlPostgresDB.Close() + return nil, err + } + + return db, nil +} + +func (db *PostgresDB) upgrade() error { + tx, err := db.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + if _, err := tx.Exec(postgresConfigSchema); err != nil { + return fmt.Errorf("failed to create Config table: %s", err) + } + + var version int + err = tx.QueryRow(`SELECT version FROM "Config"`).Scan(&version) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return fmt.Errorf("failed to query schema version: %s", err) + } + + if version == len(postgresMigrations) { + return nil + } + if version > len(postgresMigrations) { + return fmt.Errorf("suika (version %d) older than schema (version %d)", len(postgresMigrations), version) + } + + if version == 0 { + if _, err := tx.Exec(postgresSchema); err != nil { + return fmt.Errorf("failed to initialize schema: %s", err) + } + } else { + for i := version; i < len(postgresMigrations); i++ { + if _, err := tx.Exec(postgresMigrations[i]); err != nil { + return fmt.Errorf("failed to execute migration #%v: %v", i, err) + } + } + } + + _, err = tx.Exec(`INSERT INTO "Config" (id, version) VALUES (1, $1) + ON CONFLICT (id) DO UPDATE SET version = $1`, len(postgresMigrations)) + if err != nil { + return fmt.Errorf("failed to bump schema version: %v", err) + } + + return tx.Commit() +} + +func (db *PostgresDB) Close() error { + return db.db.Close() +} + +func (db *PostgresDB) Stats(ctx context.Context) (*DatabaseStats, error) { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + var stats DatabaseStats + row := db.db.QueryRowContext(ctx, `SELECT + (SELECT COUNT(*) FROM "User") AS users, + (SELECT COUNT(*) FROM "Network") AS networks, + (SELECT COUNT(*) FROM "Channel") AS channels`) + if err := row.Scan(&stats.Users, &stats.Networks, &stats.Channels); err != nil { + return nil, err + } + + return &stats, nil +} + +func (db *PostgresDB) ListUsers(ctx context.Context) ([]User, error) { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + rows, err := db.db.QueryContext(ctx, + `SELECT id, username, password, admin, realname FROM "User"`) + if err != nil { + return nil, err + } + defer rows.Close() + + var users []User + for rows.Next() { + var user User + var password, realname sql.NullString + if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &realname); err != nil { + return nil, err + } + user.Password = password.String + user.Realname = realname.String + users = append(users, user) + } + if err := rows.Err(); err != nil { + return nil, err + } + + return users, nil +} + +func (db *PostgresDB) GetUser(ctx context.Context, username string) (*User, error) { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + user := &User{Username: username} + + var password, realname sql.NullString + row := db.db.QueryRowContext(ctx, + `SELECT id, password, admin, realname FROM "User" WHERE username = $1`, + username) + if err := row.Scan(&user.ID, &password, &user.Admin, &realname); err != nil { + return nil, err + } + user.Password = password.String + user.Realname = realname.String + return user, nil +} + +func (db *PostgresDB) StoreUser(ctx context.Context, user *User) error { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + password := toNullString(user.Password) + realname := toNullString(user.Realname) + + var err error + if user.ID == 0 { + err = db.db.QueryRowContext(ctx, ` + INSERT INTO "User" (username, password, admin, realname) + VALUES ($1, $2, $3, $4) + RETURNING id`, + user.Username, password, user.Admin, realname).Scan(&user.ID) + } else { + _, err = db.db.ExecContext(ctx, ` + UPDATE "User" + SET password = $1, admin = $2, realname = $3 + WHERE id = $4`, + password, user.Admin, realname, user.ID) + } + return err +} + +func (db *PostgresDB) DeleteUser(ctx context.Context, id int64) error { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + _, err := db.db.ExecContext(ctx, `DELETE FROM "User" WHERE id = $1`, id) + return err +} + +func (db *PostgresDB) ListNetworks(ctx context.Context, userID int64) ([]Network, error) { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + rows, err := db.db.QueryContext(ctx, ` + SELECT id, name, addr, nick, username, realname, pass, connect_commands, sasl_mechanism, + sasl_plain_username, sasl_plain_password, sasl_external_cert, sasl_external_key, enabled + FROM "Network" + WHERE "user" = $1`, userID) + if err != nil { + return nil, err + } + defer rows.Close() + + var networks []Network + for rows.Next() { + var net Network + var name, nick, username, realname, pass, connectCommands sql.NullString + var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString + err := rows.Scan(&net.ID, &name, &net.Addr, &nick, &username, &realname, + &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword, + &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob, &net.Enabled) + if err != nil { + return nil, err + } + net.Name = name.String + net.Nick = nick.String + net.Username = username.String + net.Realname = realname.String + net.Pass = pass.String + if connectCommands.Valid { + net.ConnectCommands = strings.Split(connectCommands.String, "\r\n") + } + net.SASL.Mechanism = saslMechanism.String + net.SASL.Plain.Username = saslPlainUsername.String + net.SASL.Plain.Password = saslPlainPassword.String + networks = append(networks, net) + } + if err := rows.Err(); err != nil { + return nil, err + } + + return networks, nil +} + +func (db *PostgresDB) StoreNetwork(ctx context.Context, userID int64, network *Network) error { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + netName := toNullString(network.Name) + nick := toNullString(network.Nick) + netUsername := toNullString(network.Username) + realname := toNullString(network.Realname) + pass := toNullString(network.Pass) + connectCommands := toNullString(strings.Join(network.ConnectCommands, "\r\n")) + + var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString + if network.SASL.Mechanism != "" { + saslMechanism = toNullString(network.SASL.Mechanism) + switch network.SASL.Mechanism { + case "PLAIN": + saslPlainUsername = toNullString(network.SASL.Plain.Username) + saslPlainPassword = toNullString(network.SASL.Plain.Password) + network.SASL.External.CertBlob = nil + network.SASL.External.PrivKeyBlob = nil + case "EXTERNAL": + // keep saslPlain* nil + default: + return fmt.Errorf("suika: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism) + } + } + + var err error + if network.ID == 0 { + err = db.db.QueryRowContext(ctx, ` + INSERT INTO "Network" ("user", name, addr, nick, username, realname, pass, connect_commands, + sasl_mechanism, sasl_plain_username, sasl_plain_password, sasl_external_cert, + sasl_external_key, enabled) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) + RETURNING id`, + userID, netName, network.Addr, nick, netUsername, realname, pass, connectCommands, + saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob, + network.SASL.External.PrivKeyBlob, network.Enabled).Scan(&network.ID) + } else { + _, err = db.db.ExecContext(ctx, ` + UPDATE "Network" + SET name = $2, addr = $3, nick = $4, username = $5, realname = $6, pass = $7, + connect_commands = $8, sasl_mechanism = $9, sasl_plain_username = $10, + sasl_plain_password = $11, sasl_external_cert = $12, sasl_external_key = $13, + enabled = $14 + WHERE id = $1`, + network.ID, netName, network.Addr, nick, netUsername, realname, pass, connectCommands, + saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob, + network.SASL.External.PrivKeyBlob, network.Enabled) + } + return err +} + +func (db *PostgresDB) DeleteNetwork(ctx context.Context, id int64) error { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + _, err := db.db.ExecContext(ctx, `DELETE FROM "Network" WHERE id = $1`, id) + return err +} + +func (db *PostgresDB) ListChannels(ctx context.Context, networkID int64) ([]Channel, error) { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + rows, err := db.db.QueryContext(ctx, ` + SELECT id, name, key, detached, detached_internal_msgid, relay_detached, reattach_on, detach_after, + detach_on + FROM "Channel" + WHERE network = $1`, networkID) + if err != nil { + return nil, err + } + defer rows.Close() + + var channels []Channel + for rows.Next() { + var ch Channel + var key, detachedInternalMsgID sql.NullString + var detachAfter int64 + if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached, &detachedInternalMsgID, &ch.RelayDetached, &ch.ReattachOn, &detachAfter, &ch.DetachOn); err != nil { + return nil, err + } + ch.Key = key.String + ch.DetachedInternalMsgID = detachedInternalMsgID.String + ch.DetachAfter = time.Duration(detachAfter) * time.Second + channels = append(channels, ch) + } + if err := rows.Err(); err != nil { + return nil, err + } + + return channels, nil +} + +func (db *PostgresDB) StoreChannel(ctx context.Context, networkID int64, ch *Channel) error { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + key := toNullString(ch.Key) + detachAfter := int64(math.Ceil(ch.DetachAfter.Seconds())) + + var err error + if ch.ID == 0 { + err = db.db.QueryRowContext(ctx, ` + INSERT INTO "Channel" (network, name, key, detached, detached_internal_msgid, relay_detached, reattach_on, + detach_after, detach_on) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) + RETURNING id`, + networkID, ch.Name, key, ch.Detached, toNullString(ch.DetachedInternalMsgID), + ch.RelayDetached, ch.ReattachOn, detachAfter, ch.DetachOn).Scan(&ch.ID) + } else { + _, err = db.db.ExecContext(ctx, ` + UPDATE "Channel" + SET name = $2, key = $3, detached = $4, detached_internal_msgid = $5, + relay_detached = $6, reattach_on = $7, detach_after = $8, detach_on = $9 + WHERE id = $1`, + ch.ID, ch.Name, key, ch.Detached, toNullString(ch.DetachedInternalMsgID), + ch.RelayDetached, ch.ReattachOn, detachAfter, ch.DetachOn) + } + return err +} + +func (db *PostgresDB) DeleteChannel(ctx context.Context, id int64) error { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + _, err := db.db.ExecContext(ctx, `DELETE FROM "Channel" WHERE id = $1`, id) + return err +} + +func (db *PostgresDB) ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error) { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + rows, err := db.db.QueryContext(ctx, ` + SELECT id, target, client, internal_msgid + FROM "DeliveryReceipt" + WHERE network = $1`, networkID) + if err != nil { + return nil, err + } + defer rows.Close() + + var receipts []DeliveryReceipt + for rows.Next() { + var rcpt DeliveryReceipt + if err := rows.Scan(&rcpt.ID, &rcpt.Target, &rcpt.Client, &rcpt.InternalMsgID); err != nil { + return nil, err + } + receipts = append(receipts, rcpt) + } + if err := rows.Err(); err != nil { + return nil, err + } + + return receipts, nil +} + +func (db *PostgresDB) StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + tx, err := db.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + _, err = tx.ExecContext(ctx, + `DELETE FROM "DeliveryReceipt" WHERE network = $1 AND client = $2`, + networkID, client) + if err != nil { + return err + } + + stmt, err := tx.PrepareContext(ctx, ` + INSERT INTO "DeliveryReceipt" (network, target, client, internal_msgid) + VALUES ($1, $2, $3, $4) + RETURNING id`) + if err != nil { + return err + } + defer stmt.Close() + + for i := range receipts { + rcpt := &receipts[i] + err := stmt. + QueryRowContext(ctx, networkID, rcpt.Target, client, rcpt.InternalMsgID). + Scan(&rcpt.ID) + if err != nil { + return err + } + } + + return tx.Commit() +} + +func (db *PostgresDB) GetReadReceipt(ctx context.Context, networkID int64, name string) (*ReadReceipt, error) { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + receipt := &ReadReceipt{ + Target: name, + } + + row := db.db.QueryRowContext(ctx, + `SELECT id, timestamp FROM "ReadReceipt" WHERE network = $1 AND target = $2`, + networkID, name) + if err := row.Scan(&receipt.ID, &receipt.Timestamp); err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, err + } + return receipt, nil +} + +func (db *PostgresDB) StoreReadReceipt(ctx context.Context, networkID int64, receipt *ReadReceipt) error { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + var err error + if receipt.ID != 0 { + _, err = db.db.ExecContext(ctx, ` + UPDATE "ReadReceipt" + SET timestamp = $1 + WHERE id = $2`, + receipt.Timestamp, receipt.ID) + } else { + err = db.db.QueryRowContext(ctx, ` + INSERT INTO "ReadReceipt" (network, target, timestamp) + VALUES ($1, $2, $3) + RETURNING id`, + networkID, receipt.Target, receipt.Timestamp).Scan(&receipt.ID) + } + return err +} diff --git a/branches/origin/db_postgres_test.go b/branches/origin/db_postgres_test.go new file mode 100644 index 0000000..7692f31 --- /dev/null +++ b/branches/origin/db_postgres_test.go @@ -0,0 +1,104 @@ +package suika + +import ( + "database/sql" + "os" + "testing" +) + +// PostgreSQL version 0 schema. DO NOT EDIT. +const postgresV0Schema = ` +CREATE TABLE "Config" ( + id SMALLINT PRIMARY KEY, + version INTEGER NOT NULL, + CHECK(id = 1) +); + +INSERT INTO "Config" (id, version) VALUES (1, 1); + +CREATE TABLE "User" ( + id SERIAL PRIMARY KEY, + username VARCHAR(255) NOT NULL UNIQUE, + password VARCHAR(255), + admin BOOLEAN NOT NULL DEFAULT FALSE, + realname VARCHAR(255) +); + +CREATE TABLE "Network" ( + id SERIAL PRIMARY KEY, + name VARCHAR(255), + "user" INTEGER NOT NULL REFERENCES "User"(id) ON DELETE CASCADE, + addr VARCHAR(255) NOT NULL, + nick VARCHAR(255) NOT NULL, + username VARCHAR(255), + realname VARCHAR(255), + pass VARCHAR(255), + connect_commands VARCHAR(1023), + sasl_mechanism VARCHAR(255), + sasl_plain_username VARCHAR(255), + sasl_plain_password VARCHAR(255), + sasl_external_cert BYTEA DEFAULT NULL, + sasl_external_key BYTEA DEFAULT NULL, + enabled BOOLEAN NOT NULL DEFAULT TRUE, + UNIQUE("user", addr, nick), + UNIQUE("user", name) +); + +CREATE TABLE "Channel" ( + id SERIAL PRIMARY KEY, + network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE, + name VARCHAR(255) NOT NULL, + key VARCHAR(255), + detached BOOLEAN NOT NULL DEFAULT FALSE, + detached_internal_msgid VARCHAR(255), + relay_detached INTEGER NOT NULL DEFAULT 0, + reattach_on INTEGER NOT NULL DEFAULT 0, + detach_after INTEGER NOT NULL DEFAULT 0, + detach_on INTEGER NOT NULL DEFAULT 0, + UNIQUE(network, name) +); + +CREATE TABLE "DeliveryReceipt" ( + id SERIAL PRIMARY KEY, + network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE, + target VARCHAR(255) NOT NULL, + client VARCHAR(255) NOT NULL DEFAULT '', + internal_msgid VARCHAR(255) NOT NULL, + UNIQUE(network, target, client) +); +` + +func openTempPostgresDB(t *testing.T) *sql.DB { + source, ok := os.LookupEnv("SOJU_TEST_POSTGRES") + if !ok { + t.Skip("set SOJU_TEST_POSTGRES to a connection string to execute PostgreSQL tests") + } + + db, err := sql.Open("postgres", source) + if err != nil { + t.Fatalf("failed to connect to PostgreSQL: %v", err) + } + + // Store all tables in a temporary schema which will be dropped when the + // connection to PostgreSQL is closed. + db.SetMaxOpenConns(1) + if _, err := db.Exec("SET search_path TO pg_temp"); err != nil { + t.Fatalf("failed to set PostgreSQL search_path: %v", err) + } + + return db +} + +func TestPostgresMigrations(t *testing.T) { + sqlDB := openTempPostgresDB(t) + if _, err := sqlDB.Exec(postgresV0Schema); err != nil { + t.Fatalf("DB.Exec() failed for v0 schema: %v", err) + } + + db := &PostgresDB{db: sqlDB} + defer db.Close() + + if err := db.upgrade(); err != nil { + t.Fatalf("PostgresDB.Upgrade() failed: %v", err) + } +} diff --git a/branches/origin/db_sqlite.go b/branches/origin/db_sqlite.go new file mode 100644 index 0000000..7ebebb0 --- /dev/null +++ b/branches/origin/db_sqlite.go @@ -0,0 +1,755 @@ +package suika + +import ( + "context" + "database/sql" + _ "embed" + "fmt" + "math" + "strings" + "sync" + "time" + + _ "modernc.org/sqlite" +) + +const sqliteQueryTimeout = 5 * time.Second + +//go:embed suika_sqlite_schema.sql +var sqliteSchema string + +var sqliteMigrations = []string{ + "", // migration #0 is reserved for schema initialization + "ALTER TABLE Network ADD COLUMN connect_commands VARCHAR(1023)", + "ALTER TABLE Channel ADD COLUMN detached INTEGER NOT NULL DEFAULT 0", + "ALTER TABLE Network ADD COLUMN sasl_external_cert BLOB DEFAULT NULL", + "ALTER TABLE Network ADD COLUMN sasl_external_key BLOB DEFAULT NULL", + "ALTER TABLE User ADD COLUMN admin INTEGER NOT NULL DEFAULT 0", + ` + CREATE TABLE IF NOT EXISTS UserNew ( + id INTEGER PRIMARY KEY, + username VARCHAR(255) NOT NULL UNIQUE, + password VARCHAR(255), + admin INTEGER NOT NULL DEFAULT 0 + ); + INSERT INTO UserNew SELECT rowid, username, password, admin FROM User; + DROP TABLE User; + ALTER TABLE UserNew RENAME TO User; + `, + ` + CREATE TABLE IF NOT EXISTS NetworkNew ( + id INTEGER PRIMARY KEY, + name VARCHAR(255), + user INTEGER NOT NULL, + addr VARCHAR(255) NOT NULL, + nick VARCHAR(255) NOT NULL, + username VARCHAR(255), + realname VARCHAR(255), + pass VARCHAR(255), + connect_commands VARCHAR(1023), + sasl_mechanism VARCHAR(255), + sasl_plain_username VARCHAR(255), + sasl_plain_password VARCHAR(255), + sasl_external_cert BLOB DEFAULT NULL, + sasl_external_key BLOB DEFAULT NULL, + FOREIGN KEY(user) REFERENCES User(id), + UNIQUE(user, addr, nick), + UNIQUE(user, name) + ); + INSERT INTO NetworkNew + SELECT Network.id, name, User.id as user, addr, nick, + Network.username, realname, pass, connect_commands, + sasl_mechanism, sasl_plain_username, sasl_plain_password, + sasl_external_cert, sasl_external_key + FROM Network + JOIN User ON Network.user = User.username; + DROP TABLE Network; + ALTER TABLE NetworkNew RENAME TO Network; + `, + ` + ALTER TABLE Channel ADD COLUMN relay_detached INTEGER NOT NULL DEFAULT 0; + ALTER TABLE Channel ADD COLUMN reattach_on INTEGER NOT NULL DEFAULT 0; + ALTER TABLE Channel ADD COLUMN detach_after INTEGER NOT NULL DEFAULT 0; + ALTER TABLE Channel ADD COLUMN detach_on INTEGER NOT NULL DEFAULT 0; + `, + ` + CREATE TABLE IF NOT EXISTS DeliveryReceipt ( + id INTEGER PRIMARY KEY, + network INTEGER NOT NULL, + target VARCHAR(255) NOT NULL, + client VARCHAR(255), + internal_msgid VARCHAR(255) NOT NULL, + FOREIGN KEY(network) REFERENCES Network(id), + UNIQUE(network, target, client) + ); + `, + "ALTER TABLE Channel ADD COLUMN detached_internal_msgid VARCHAR(255)", + "ALTER TABLE Network ADD COLUMN enabled INTEGER NOT NULL DEFAULT 1", + "ALTER TABLE User ADD COLUMN realname VARCHAR(255)", + ` + CREATE TABLE IF NOT EXISTS NetworkNew ( + id INTEGER PRIMARY KEY, + name TEXT, + user INTEGER NOT NULL, + addr TEXT NOT NULL, + nick TEXT, + username TEXT, + realname TEXT, + pass TEXT, + connect_commands TEXT, + sasl_mechanism TEXT, + sasl_plain_username TEXT, + sasl_plain_password TEXT, + sasl_external_cert BLOB, + sasl_external_key BLOB, + enabled INTEGER NOT NULL DEFAULT 1, + FOREIGN KEY(user) REFERENCES User(id), + UNIQUE(user, addr, nick), + UNIQUE(user, name) + ); + INSERT INTO NetworkNew + SELECT id, name, user, addr, nick, username, realname, pass, + connect_commands, sasl_mechanism, sasl_plain_username, + sasl_plain_password, sasl_external_cert, sasl_external_key, + enabled + FROM Network; + DROP TABLE Network; + ALTER TABLE NetworkNew RENAME TO Network; + `, + ` + CREATE TABLE IF NOT EXISTS ReadReceipt ( + id INTEGER PRIMARY KEY, + network INTEGER NOT NULL, + target TEXT NOT NULL, + timestamp TEXT NOT NULL, + FOREIGN KEY(network) REFERENCES Network(id), + UNIQUE(network, target) + ); + `, +} + +type SqliteDB struct { + lock sync.RWMutex + db *sql.DB +} + +func OpenSqliteDB(source string) (Database, error) { + sqlSqliteDB, err := sql.Open("sqlite", source) + if err != nil { + return nil, err + } + + db := &SqliteDB{db: sqlSqliteDB} + if err := db.upgrade(); err != nil { + sqlSqliteDB.Close() + return nil, err + } + + return db, nil +} + +func (db *SqliteDB) Close() error { + db.lock.Lock() + defer db.lock.Unlock() + return db.db.Close() +} + +func (db *SqliteDB) upgrade() error { + db.lock.Lock() + defer db.lock.Unlock() + + var version int + if err := db.db.QueryRow("PRAGMA user_version").Scan(&version); err != nil { + return fmt.Errorf("failed to query schema version: %v", err) + } + + if version == len(sqliteMigrations) { + return nil + } else if version > len(sqliteMigrations) { + return fmt.Errorf("suika (version %d) older than schema (version %d)", len(sqliteMigrations), version) + } + + tx, err := db.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + if version == 0 { + if _, err := tx.Exec(sqliteSchema); err != nil { + return fmt.Errorf("failed to initialize schema: %v", err) + } + } else { + for i := version; i < len(sqliteMigrations); i++ { + if _, err := tx.Exec(sqliteMigrations[i]); err != nil { + return fmt.Errorf("failed to execute migration #%v: %v", i, err) + } + } + } + + // For some reason prepared statements don't work here + _, err = tx.Exec(fmt.Sprintf("PRAGMA user_version = %d", len(sqliteMigrations))) + if err != nil { + return fmt.Errorf("failed to bump schema version: %v", err) + } + + return tx.Commit() +} + +func (db *SqliteDB) Stats(ctx context.Context) (*DatabaseStats, error) { + db.lock.RLock() + defer db.lock.RUnlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + var stats DatabaseStats + row := db.db.QueryRowContext(ctx, `SELECT + (SELECT COUNT(*) FROM User) AS users, + (SELECT COUNT(*) FROM Network) AS networks, + (SELECT COUNT(*) FROM Channel) AS channels`) + if err := row.Scan(&stats.Users, &stats.Networks, &stats.Channels); err != nil { + return nil, err + } + + return &stats, nil +} + +func toNullString(s string) sql.NullString { + return sql.NullString{ + String: s, + Valid: s != "", + } +} + +func (db *SqliteDB) ListUsers(ctx context.Context) ([]User, error) { + db.lock.RLock() + defer db.lock.RUnlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + rows, err := db.db.QueryContext(ctx, + "SELECT id, username, password, admin, realname FROM User") + if err != nil { + return nil, err + } + defer rows.Close() + + var users []User + for rows.Next() { + var user User + var password, realname sql.NullString + if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &realname); err != nil { + return nil, err + } + user.Password = password.String + user.Realname = realname.String + users = append(users, user) + } + if err := rows.Err(); err != nil { + return nil, err + } + + return users, nil +} + +func (db *SqliteDB) GetUser(ctx context.Context, username string) (*User, error) { + db.lock.RLock() + defer db.lock.RUnlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + user := &User{Username: username} + + var password, realname sql.NullString + row := db.db.QueryRowContext(ctx, + "SELECT id, password, admin, realname FROM User WHERE username = ?", + username) + if err := row.Scan(&user.ID, &password, &user.Admin, &realname); err != nil { + return nil, err + } + user.Password = password.String + user.Realname = realname.String + return user, nil +} + +func (db *SqliteDB) StoreUser(ctx context.Context, user *User) error { + db.lock.Lock() + defer db.lock.Unlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + args := []interface{}{ + sql.Named("username", user.Username), + sql.Named("password", toNullString(user.Password)), + sql.Named("admin", user.Admin), + sql.Named("realname", toNullString(user.Realname)), + } + + var err error + if user.ID != 0 { + _, err = db.db.ExecContext(ctx, ` + UPDATE User SET password = :password, admin = :admin, + realname = :realname WHERE username = :username`, + args...) + } else { + var res sql.Result + res, err = db.db.ExecContext(ctx, ` + INSERT INTO + User(username, password, admin, realname) + VALUES (:username, :password, :admin, :realname)`, + args...) + if err != nil { + return err + } + user.ID, err = res.LastInsertId() + } + + return err +} + +func (db *SqliteDB) DeleteUser(ctx context.Context, id int64) error { + db.lock.Lock() + defer db.lock.Unlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + tx, err := db.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + _, err = tx.ExecContext(ctx, `DELETE FROM DeliveryReceipt + WHERE id IN ( + SELECT DeliveryReceipt.id + FROM DeliveryReceipt + JOIN Network ON DeliveryReceipt.network = Network.id + WHERE Network.user = ? + )`, id) + if err != nil { + return err + } + + _, err = tx.ExecContext(ctx, `DELETE FROM ReadReceipt + WHERE id IN ( + SELECT ReadReceipt.id + FROM ReadReceipt + JOIN Network ON ReadReceipt.network = Network.id + WHERE Network.user = ? + )`, id) + if err != nil { + return err + } + + _, err = tx.ExecContext(ctx, `DELETE FROM Channel + WHERE id IN ( + SELECT Channel.id + FROM Channel + JOIN Network ON Channel.network = Network.id + WHERE Network.user = ? + )`, id) + if err != nil { + return err + } + + _, err = tx.ExecContext(ctx, "DELETE FROM Network WHERE user = ?", id) + if err != nil { + return err + } + + _, err = tx.ExecContext(ctx, "DELETE FROM User WHERE id = ?", id) + if err != nil { + return err + } + + return tx.Commit() +} + +func (db *SqliteDB) ListNetworks(ctx context.Context, userID int64) ([]Network, error) { + db.lock.RLock() + defer db.lock.RUnlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + rows, err := db.db.QueryContext(ctx, ` + SELECT id, name, addr, nick, username, realname, pass, + connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password, + sasl_external_cert, sasl_external_key, enabled + FROM Network + WHERE user = ?`, + userID) + if err != nil { + return nil, err + } + defer rows.Close() + + var networks []Network + for rows.Next() { + var net Network + var name, nick, username, realname, pass, connectCommands sql.NullString + var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString + err := rows.Scan(&net.ID, &name, &net.Addr, &nick, &username, &realname, + &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword, + &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob, &net.Enabled) + if err != nil { + return nil, err + } + net.Name = name.String + net.Nick = nick.String + net.Username = username.String + net.Realname = realname.String + net.Pass = pass.String + if connectCommands.Valid { + net.ConnectCommands = strings.Split(connectCommands.String, "\r\n") + } + net.SASL.Mechanism = saslMechanism.String + net.SASL.Plain.Username = saslPlainUsername.String + net.SASL.Plain.Password = saslPlainPassword.String + networks = append(networks, net) + } + if err := rows.Err(); err != nil { + return nil, err + } + + return networks, nil +} + +func (db *SqliteDB) StoreNetwork(ctx context.Context, userID int64, network *Network) error { + db.lock.Lock() + defer db.lock.Unlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString + if network.SASL.Mechanism != "" { + saslMechanism = toNullString(network.SASL.Mechanism) + switch network.SASL.Mechanism { + case "PLAIN": + saslPlainUsername = toNullString(network.SASL.Plain.Username) + saslPlainPassword = toNullString(network.SASL.Plain.Password) + network.SASL.External.CertBlob = nil + network.SASL.External.PrivKeyBlob = nil + case "EXTERNAL": + // keep saslPlain* nil + default: + return fmt.Errorf("suika: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism) + } + } + + args := []interface{}{ + sql.Named("name", toNullString(network.Name)), + sql.Named("addr", network.Addr), + sql.Named("nick", toNullString(network.Nick)), + sql.Named("username", toNullString(network.Username)), + sql.Named("realname", toNullString(network.Realname)), + sql.Named("pass", toNullString(network.Pass)), + sql.Named("connect_commands", toNullString(strings.Join(network.ConnectCommands, "\r\n"))), + sql.Named("sasl_mechanism", saslMechanism), + sql.Named("sasl_plain_username", saslPlainUsername), + sql.Named("sasl_plain_password", saslPlainPassword), + sql.Named("sasl_external_cert", network.SASL.External.CertBlob), + sql.Named("sasl_external_key", network.SASL.External.PrivKeyBlob), + sql.Named("enabled", network.Enabled), + + sql.Named("id", network.ID), // only for UPDATE + sql.Named("user", userID), // only for INSERT + } + + var err error + if network.ID != 0 { + _, err = db.db.ExecContext(ctx, ` + UPDATE Network + SET name = :name, addr = :addr, nick = :nick, username = :username, + realname = :realname, pass = :pass, connect_commands = :connect_commands, + sasl_mechanism = :sasl_mechanism, sasl_plain_username = :sasl_plain_username, sasl_plain_password = :sasl_plain_password, + sasl_external_cert = :sasl_external_cert, sasl_external_key = :sasl_external_key, + enabled = :enabled + WHERE id = :id`, args...) + } else { + var res sql.Result + res, err = db.db.ExecContext(ctx, ` + INSERT INTO Network(user, name, addr, nick, username, realname, pass, + connect_commands, sasl_mechanism, sasl_plain_username, + sasl_plain_password, sasl_external_cert, sasl_external_key, enabled) + VALUES (:user, :name, :addr, :nick, :username, :realname, :pass, + :connect_commands, :sasl_mechanism, :sasl_plain_username, + :sasl_plain_password, :sasl_external_cert, :sasl_external_key, :enabled)`, + args...) + if err != nil { + return err + } + network.ID, err = res.LastInsertId() + } + return err +} + +func (db *SqliteDB) DeleteNetwork(ctx context.Context, id int64) error { + db.lock.Lock() + defer db.lock.Unlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + tx, err := db.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + _, err = tx.ExecContext(ctx, "DELETE FROM DeliveryReceipt WHERE network = ?", id) + if err != nil { + return err + } + + _, err = tx.ExecContext(ctx, "DELETE FROM ReadReceipt WHERE network = ?", id) + if err != nil { + return err + } + + _, err = tx.ExecContext(ctx, "DELETE FROM Channel WHERE network = ?", id) + if err != nil { + return err + } + + _, err = tx.ExecContext(ctx, "DELETE FROM Network WHERE id = ?", id) + if err != nil { + return err + } + + return tx.Commit() +} + +func (db *SqliteDB) ListChannels(ctx context.Context, networkID int64) ([]Channel, error) { + db.lock.RLock() + defer db.lock.RUnlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + rows, err := db.db.QueryContext(ctx, `SELECT + id, name, key, detached, detached_internal_msgid, + relay_detached, reattach_on, detach_after, detach_on + FROM Channel + WHERE network = ?`, networkID) + if err != nil { + return nil, err + } + defer rows.Close() + + var channels []Channel + for rows.Next() { + var ch Channel + var key, detachedInternalMsgID sql.NullString + var detachAfter int64 + if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached, &detachedInternalMsgID, &ch.RelayDetached, &ch.ReattachOn, &detachAfter, &ch.DetachOn); err != nil { + return nil, err + } + ch.Key = key.String + ch.DetachedInternalMsgID = detachedInternalMsgID.String + ch.DetachAfter = time.Duration(detachAfter) * time.Second + channels = append(channels, ch) + } + if err := rows.Err(); err != nil { + return nil, err + } + + return channels, nil +} + +func (db *SqliteDB) StoreChannel(ctx context.Context, networkID int64, ch *Channel) error { + db.lock.Lock() + defer db.lock.Unlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + args := []interface{}{ + sql.Named("network", networkID), + sql.Named("name", ch.Name), + sql.Named("key", toNullString(ch.Key)), + sql.Named("detached", ch.Detached), + sql.Named("detached_internal_msgid", toNullString(ch.DetachedInternalMsgID)), + sql.Named("relay_detached", ch.RelayDetached), + sql.Named("reattach_on", ch.ReattachOn), + sql.Named("detach_after", int64(math.Ceil(ch.DetachAfter.Seconds()))), + sql.Named("detach_on", ch.DetachOn), + + sql.Named("id", ch.ID), // only for UPDATE + } + + var err error + if ch.ID != 0 { + _, err = db.db.ExecContext(ctx, `UPDATE Channel + SET network = :network, name = :name, key = :key, detached = :detached, + detached_internal_msgid = :detached_internal_msgid, relay_detached = :relay_detached, + reattach_on = :reattach_on, detach_after = :detach_after, detach_on = :detach_on + WHERE id = :id`, args...) + } else { + var res sql.Result + res, err = db.db.ExecContext(ctx, `INSERT INTO Channel(network, name, key, detached, detached_internal_msgid, relay_detached, reattach_on, detach_after, detach_on) + VALUES (:network, :name, :key, :detached, :detached_internal_msgid, :relay_detached, :reattach_on, :detach_after, :detach_on)`, args...) + if err != nil { + return err + } + ch.ID, err = res.LastInsertId() + } + return err +} + +func (db *SqliteDB) DeleteChannel(ctx context.Context, id int64) error { + db.lock.Lock() + defer db.lock.Unlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + _, err := db.db.ExecContext(ctx, "DELETE FROM Channel WHERE id = ?", id) + return err +} + +func (db *SqliteDB) ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error) { + db.lock.RLock() + defer db.lock.RUnlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + rows, err := db.db.QueryContext(ctx, ` + SELECT id, target, client, internal_msgid + FROM DeliveryReceipt + WHERE network = ?`, networkID) + if err != nil { + return nil, err + } + defer rows.Close() + + var receipts []DeliveryReceipt + for rows.Next() { + var rcpt DeliveryReceipt + var client sql.NullString + if err := rows.Scan(&rcpt.ID, &rcpt.Target, &client, &rcpt.InternalMsgID); err != nil { + return nil, err + } + rcpt.Client = client.String + receipts = append(receipts, rcpt) + } + if err := rows.Err(); err != nil { + return nil, err + } + + return receipts, nil +} + +func (db *SqliteDB) StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error { + db.lock.Lock() + defer db.lock.Unlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + tx, err := db.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + _, err = tx.ExecContext(ctx, "DELETE FROM DeliveryReceipt WHERE network = ? AND client IS ?", + networkID, toNullString(client)) + if err != nil { + return err + } + + for i := range receipts { + rcpt := &receipts[i] + + res, err := tx.ExecContext(ctx, ` + INSERT INTO DeliveryReceipt(network, target, client, internal_msgid) + VALUES (:network, :target, :client, :internal_msgid)`, + sql.Named("network", networkID), + sql.Named("target", rcpt.Target), + sql.Named("client", toNullString(client)), + sql.Named("internal_msgid", rcpt.InternalMsgID)) + if err != nil { + return err + } + rcpt.ID, err = res.LastInsertId() + if err != nil { + return err + } + } + + return tx.Commit() +} + +func (db *SqliteDB) GetReadReceipt(ctx context.Context, networkID int64, name string) (*ReadReceipt, error) { + db.lock.RLock() + defer db.lock.RUnlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + receipt := &ReadReceipt{ + Target: name, + } + + row := db.db.QueryRowContext(ctx, ` + SELECT id, timestamp FROM ReadReceipt WHERE network = :network AND target = :target`, + sql.Named("network", networkID), + sql.Named("target", name), + ) + var timestamp string + if err := row.Scan(&receipt.ID, ×tamp); err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, err + } + if t, err := time.Parse(serverTimeLayout, timestamp); err != nil { + return nil, err + } else { + receipt.Timestamp = t + } + return receipt, nil +} + +func (db *SqliteDB) StoreReadReceipt(ctx context.Context, networkID int64, receipt *ReadReceipt) error { + db.lock.Lock() + defer db.lock.Unlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + args := []interface{}{ + sql.Named("id", receipt.ID), + sql.Named("timestamp", formatServerTime(receipt.Timestamp)), + sql.Named("network", networkID), + sql.Named("target", receipt.Target), + } + + var err error + if receipt.ID != 0 { + _, err = db.db.ExecContext(ctx, ` + UPDATE ReadReceipt SET timestamp = :timestamp WHERE id = :id`, + args...) + } else { + var res sql.Result + res, err = db.db.ExecContext(ctx, ` + INSERT INTO + ReadReceipt(network, target, timestamp) + VALUES (:network, :target, :timestamp)`, + args...) + if err != nil { + return err + } + receipt.ID, err = res.LastInsertId() + } + + return err +} diff --git a/branches/origin/db_sqlite_test.go b/branches/origin/db_sqlite_test.go new file mode 100644 index 0000000..a7c9794 --- /dev/null +++ b/branches/origin/db_sqlite_test.go @@ -0,0 +1,59 @@ +package suika + +import ( + "database/sql" + "testing" +) + +// SQLite version 0 schema. DO NOT EDIT. +const sqliteV0Schema = ` +CREATE TABLE User ( + username VARCHAR(255) NOT NULL UNIQUE, + password VARCHAR(255) +); + +CREATE TABLE Network ( + id INTEGER PRIMARY KEY, + name VARCHAR(255), + user VARCHAR(255) NOT NULL, + addr VARCHAR(255) NOT NULL, + nick VARCHAR(255) NOT NULL, + username VARCHAR(255), + realname VARCHAR(255), + pass VARCHAR(255), + sasl_mechanism VARCHAR(255), + sasl_plain_username VARCHAR(255), + sasl_plain_password VARCHAR(255), + UNIQUE(user, addr, nick), + UNIQUE(user, name) +); + +CREATE TABLE Channel ( + id INTEGER PRIMARY KEY, + network INTEGER NOT NULL, + name VARCHAR(255) NOT NULL, + key VARCHAR(255), + FOREIGN KEY(network) REFERENCES Network(id), + UNIQUE(network, name) +); + +PRAGMA user_version = 1; +` + +func TestSqliteMigrations(t *testing.T) { + sqlDB, err := sql.Open("sqlite", ":memory:") + if err != nil { + t.Fatalf("failed to create temporary SQLite database: %v", err) + } + + if _, err := sqlDB.Exec(sqliteV0Schema); err != nil { + t.Fatalf("DB.Exec() failed for v0 schema: %v", err) + } + + db := &SqliteDB{db: sqlDB} + defer db.Close() + + if err := db.upgrade(); err != nil { + t.Fatalf("SqliteDB.Upgrade() failed: %v", err) + } +} diff --git a/branches/origin/doc.go b/branches/origin/doc.go new file mode 100644 index 0000000..3f7af51 --- /dev/null +++ b/branches/origin/doc.go @@ -0,0 +1,20 @@ +// Package suika is a hard-fork of the 0.3 series of soju, an user-friendly IRC bouncer in Go. +// +// # Copyright (C) 2020 The soju Contributors +// # Copyright (C) 2023-present Izuru Yakumo et al. +// +// suika is covered by the AGPLv3 license: +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . +package suika diff --git a/branches/origin/doc/suika-config.5 b/branches/origin/doc/suika-config.5 new file mode 100644 index 0000000..ac0e9ef --- /dev/null +++ b/branches/origin/doc/suika-config.5 @@ -0,0 +1,70 @@ +.Dd $Mdocdate$ +.Dt SUIKA-CONFIG 5 +.Os +.Sh NAME +.Nm suika-config +.Nd Configuration file for the IRC bouncer +.Sh SYNOPSIS +.Bk -words +listen ircs:// +.Pp +tls cert.pem key.pem +.Pp +hostname example.org +.Ek +.Sh DESCRIPTION +This document describes the format of the configuration +file used by +.Xr suika 1 +.Sh OPTIONS +.Bl -tag -width Ds +.It listen Ar uri +With this you can control on what +ports/protocols +.Xr suika 1 +listens on, it supports +irc (cleartext IRC), ircs (IRC with TLS), and unix +(IRC over Unix domain sockets) +.It hostname Ar hostname +Server hostname, if unset, the system one is used. +.It title Ar title +Server title, this will be sent as the ISUPPORT NETWORK value when +clients don't select a specific network. +.It tls Ar cert Ar key +Enable TLS support, the certificate and key files must be +PEM-encoded. +.It db Ar driver Ar path +Set the database driver for user, network and channel storage. +By default a SQLite 3 database is opened in +.Pa ./suika.db +Supported drivers are sqlite and postgres, the former +expects a path to the database file, and the latter +a space-separated list of key=value parameters, +e.g. host=localhost dbname=suika +.It log fs Ar path +Path to the bouncer logs directory, or empty to disable +logging. +By default, logging is disabled. +.It max-user-networks Ar limit +Maximum number of networks per user, by default +there is no limit. +.It motd Ar path +Path to the MOTD file, its contents are sent to clients +which aren't bound to a particular network. +By default, no MOTD is sent. +.It multi-upstream-mode Ar bool +Globally enable or disable multi-upstream mode. +By default, it is enabled. +.It upstream-user-ip Ar cidr +Enable per-user-IP addresses. +One IPv4 and/or one IPv6 range can be specified in CIDR notation. +One IP address per range will be assigned to each user as the +source address when connecting to an upstream network. +This can be useful to avoid having the whole bouncer banned from +an upstream network because of one malicious user. +.El +.Sh AUTHORS +.An Simon Ser Aq Mt contact@emersion.fr +.An The soju Contributors +.Sh MAINTAINERS +.An Izuru Yakumo Aq Mt yakumo.izuru@chaotic.ninja diff --git a/branches/origin/doc/suika-znc-import.1 b/branches/origin/doc/suika-znc-import.1 new file mode 100644 index 0000000..7acc618 --- /dev/null +++ b/branches/origin/doc/suika-znc-import.1 @@ -0,0 +1,34 @@ +.Dd $Mdocdate$ +.Dt SUIKA-ZNC-IMPORT 1 +.Os +.Sh NAME +.Nm suika-znc-import +.Nd Migration utility for moving from ZNC +.Sh SYNOPSIS +.Nm +.Op Fl config Ar suika config file +.Op Fl user Ar username +.Op Fl network Ar name +.Sh DESCRIPTION +Imports configuration from a ZNC file. +Users and networks are merged if they already exist in the +.Xr suika 1 +database. +ZNC settings overwrite existing +.Xr suika 1 +settings +.Sh OPTIONS +.Bl -tag -width Ds +.It config Ar suika config file +Path to +.Xr suika-config 5 +.It user Ar username +Limit import to username, may be specified multiple times. +It network Ar name +Limit import to network, may be specified multiple times. +.El +.Sh AUTHORS +.An Simon Ser Aq Mt contact@emersion.fr +.An The soju Contributors +.Sh MAINTAINERS +.An Izuru Yakumo Aq Mt yakumo.izuru@chaotic.ninja diff --git a/branches/origin/doc/suika.1 b/branches/origin/doc/suika.1 new file mode 100644 index 0000000..afe6607 --- /dev/null +++ b/branches/origin/doc/suika.1 @@ -0,0 +1,68 @@ +.Dd $Mdocdate$ +.Dt SUIKA 1 +.Os +.Sh NAME +.Nm suika +.Nd A drunk as hell IRC bouncer, named after Suika Ibuki from Touhou Project +.Sh SYNOPSIS +.Nm +.Op Fl config Ar path +.Op Fl debug +.Op Fl listen Ar uri +.Sh DESCRIPTION +.Nm +is an user-friendly IRC bouncer, it connects to upstream +IRC servers on behalf of the user to provide extra features. +.Bl -tag -width 6n +.It Multiple separate users sharing the same bouncer +.It Clients connecting to multiple upstream servers (via a single connection) +.It Sending the backlog with per-client buffers +.El +.Pp +When joining a channel, the channel will be saved +and automatically joined on the next connection. +When registering or authenticating with NickServ, the credentials will be saved +and automatically used on the next connection if the server supports SASL. +When parting a channel with the reason "detach", the channel will be +detached instead of being left. +When all clients are disconnected from the bouncer, +the user is automatically marked as away. +.Pp +.Nm +supports two connection modes: +.Bl -tag -width 6n +.It Single upstream mode +One downstream connection maps to one upstream connection +.Pp +To enable this mode, connect to the bouncer +with the username "/". +.Pp +If the bouncer isn't connected to the upstream server, +it will get automatically added. +.Pp +Then channels can be joined and parted as if +you were directly connected to the upstream server. +.It Multiple upstream mode +One downstream connection maps to multiple upstream connections. +Channels and nicks are suffixed with the network name. +To join a channel, you need to use the suffix too: /join #channel/network. +Same applies to messages sent to users. +.El +.Pp +For per-client history to work, clients need to indicate their name. +This can be done by adding a "@" suffix to the username. +.Pp +.Nm +will reload the configuration file, the TLS certificate/key and +the MOTD file when it receives the HUP signal. +The configuration options listen, db and log cannot be reloaded. +.Pp +Administrators can broadcast a message to all bouncer users via +/notice $ , or via /notice $ in multi-upstream mode. +All currently connected bouncer users will receive the message +from the special BouncerServ service. +.Sh AUTHORS +.An Simon Ser Aq Mt contact@emersion.fr +.An The soju Contributors +.Sh MAINTAINERS +.An Izuru Yakumo Aq Mt yakumo.izuru@chaotic.ninja diff --git a/branches/origin/doc/suikadb.1 b/branches/origin/doc/suikadb.1 new file mode 100644 index 0000000..5ff4769 --- /dev/null +++ b/branches/origin/doc/suikadb.1 @@ -0,0 +1,16 @@ +.Dd $Mdocdate$ +.Dt SUIKADB 1 +.Os +.Sh NAME +.Nm suikadb +.Nd Basic user manipulation for +.Xr suika 1 +.Sh SYNOPSIS +.Nm +.Op create-user +.Op change-password +.Sh AUTHORS +.An Simon Ser Aq Mt contact@emersion.fr +.An The soju Contributors +.Sh MAINTAINERS +.An Izuru Yakumo Aq Mt yakumo.izuru@chaotic.ninja diff --git a/branches/origin/downstream.go b/branches/origin/downstream.go new file mode 100644 index 0000000..3b9adac --- /dev/null +++ b/branches/origin/downstream.go @@ -0,0 +1,3047 @@ +package suika + +import ( + "bytes" + "context" + "crypto/tls" + "encoding/base64" + "errors" + "fmt" + "io" + "net" + "strconv" + "strings" + "time" + + "github.com/emersion/go-sasl" + "golang.org/x/crypto/bcrypt" + "gopkg.in/irc.v3" +) + +type ircError struct { + Message *irc.Message +} + +func (err ircError) Error() string { + return err.Message.String() +} + +func newUnknownCommandError(cmd string) ircError { + return ircError{&irc.Message{ + Command: irc.ERR_UNKNOWNCOMMAND, + Params: []string{ + "*", + cmd, + "Unknown command", + }, + }} +} + +func newNeedMoreParamsError(cmd string) ircError { + return ircError{&irc.Message{ + Command: irc.ERR_NEEDMOREPARAMS, + Params: []string{ + "*", + cmd, + "Not enough parameters", + }, + }} +} + +func newChatHistoryError(subcommand string, target string) ircError { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"CHATHISTORY", "MESSAGE_ERROR", subcommand, target, "Messages could not be retrieved"}, + }} +} + +// authError is an authentication error. +type authError struct { + // Internal error cause. This will not be revealed to the user. + err error + // Error cause which can safely be sent to the user without compromising + // security. + reason string +} + +func (err *authError) Error() string { + return err.err.Error() +} + +func (err *authError) Unwrap() error { + return err.err +} + +// authErrorReason returns the user-friendly reason of an authentication +// failure. +func authErrorReason(err error) string { + if authErr, ok := err.(*authError); ok { + return authErr.reason + } else { + return "Authentication failed" + } +} + +func newInvalidUsernameOrPasswordError(err error) error { + return &authError{ + err: err, + reason: "Invalid username or password", + } +} + +func parseBouncerNetID(subcommand, s string) (int64, error) { + id, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return 0, ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"BOUNCER", "INVALID_NETID", subcommand, s, "Invalid network ID"}, + }} + } + return id, nil +} + +func fillNetworkAddrAttrs(attrs irc.Tags, network *Network) { + u, err := network.URL() + if err != nil { + return + } + + hasHostPort := true + switch u.Scheme { + case "ircs": + attrs["tls"] = irc.TagValue("1") + case "irc": + attrs["tls"] = irc.TagValue("0") + default: // e.g. unix:// + hasHostPort = false + } + if host, port, err := net.SplitHostPort(u.Host); err == nil && hasHostPort { + attrs["host"] = irc.TagValue(host) + attrs["port"] = irc.TagValue(port) + } else if hasHostPort { + attrs["host"] = irc.TagValue(u.Host) + } +} + +func getNetworkAttrs(network *network) irc.Tags { + state := "disconnected" + if uc := network.conn; uc != nil { + state = "connected" + } + + attrs := irc.Tags{ + "name": irc.TagValue(network.GetName()), + "state": irc.TagValue(state), + "nickname": irc.TagValue(GetNick(&network.user.User, &network.Network)), + } + + if network.Username != "" { + attrs["username"] = irc.TagValue(network.Username) + } + if realname := GetRealname(&network.user.User, &network.Network); realname != "" { + attrs["realname"] = irc.TagValue(realname) + } + + fillNetworkAddrAttrs(attrs, &network.Network) + + return attrs +} + +func networkAddrFromAttrs(attrs irc.Tags) string { + host, ok := attrs.GetTag("host") + if !ok { + return "" + } + + addr := host + if port, ok := attrs.GetTag("port"); ok { + addr += ":" + port + } + + if tlsStr, ok := attrs.GetTag("tls"); ok && tlsStr == "0" { + addr = "irc://" + tlsStr + } + + return addr +} + +func updateNetworkAttrs(record *Network, attrs irc.Tags, subcommand string) error { + addrAttrs := irc.Tags{} + fillNetworkAddrAttrs(addrAttrs, record) + + updateAddr := false + for k, v := range attrs { + s := string(v) + switch k { + case "host", "port", "tls": + updateAddr = true + addrAttrs[k] = v + case "name": + record.Name = s + case "nickname": + record.Nick = s + case "username": + record.Username = s + case "realname": + record.Realname = s + case "pass": + record.Pass = s + default: + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"BOUNCER", "UNKNOWN_ATTRIBUTE", subcommand, k, "Unknown attribute"}, + }} + } + } + + if updateAddr { + record.Addr = networkAddrFromAttrs(addrAttrs) + if record.Addr == "" { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"BOUNCER", "NEED_ATTRIBUTE", subcommand, "host", "Missing required host attribute"}, + }} + } + } + + return nil +} + +// illegalNickChars is the list of characters forbidden in a nickname. +// +// ' ' and ':' break the IRC message wire format +// '@' and '!' break prefixes +// '*' breaks masks and is the reserved nickname for registration +// '?' breaks masks +// '$' breaks server masks in PRIVMSG/NOTICE +// ',' breaks lists +// '.' is reserved for server names +const illegalNickChars = " :@!*?$,." + +// permanentDownstreamCaps is the list of always-supported downstream +// capabilities. +var permanentDownstreamCaps = map[string]string{ + "batch": "", + "cap-notify": "", + "echo-message": "", + "invite-notify": "", + "message-tags": "", + "server-time": "", + "setname": "", + + "soju.im/bouncer-networks": "", + "soju.im/bouncer-networks-notify": "", + "soju.im/read": "", +} + +// needAllDownstreamCaps is the list of downstream capabilities that +// require support from all upstreams to be enabled +var needAllDownstreamCaps = map[string]string{ + "account-notify": "", + "account-tag": "", + "away-notify": "", + "extended-join": "", + "multi-prefix": "", + + "draft/extended-monitor": "", +} + +// passthroughIsupport is the set of ISUPPORT tokens that are directly passed +// through from the upstream server to downstream clients. +// +// This is only effective in single-upstream mode. +var passthroughIsupport = map[string]bool{ + "AWAYLEN": true, + "BOT": true, + "CHANLIMIT": true, + "CHANMODES": true, + "CHANNELLEN": true, + "CHANTYPES": true, + "CLIENTTAGDENY": true, + "ELIST": true, + "EXCEPTS": true, + "EXTBAN": true, + "HOSTLEN": true, + "INVEX": true, + "KICKLEN": true, + "MAXLIST": true, + "MAXTARGETS": true, + "MODES": true, + "MONITOR": true, + "NAMELEN": true, + "NETWORK": true, + "NICKLEN": true, + "PREFIX": true, + "SAFELIST": true, + "TARGMAX": true, + "TOPICLEN": true, + "USERLEN": true, + "UTF8ONLY": true, + "WHOX": true, +} + +type downstreamSASL struct { + server sasl.Server + plainUsername, plainPassword string + pendingResp bytes.Buffer +} + +type downstreamConn struct { + conn + + id uint64 + + registered bool + user *user + nick string + nickCM string + rawUsername string + networkName string + clientName string + realname string + hostname string + account string // RPL_LOGGEDIN/OUT state + password string // empty after authentication + network *network // can be nil + isMultiUpstream bool + + negotiatingCaps bool + capVersion int + supportedCaps map[string]string + caps map[string]bool + sasl *downstreamSASL + + lastBatchRef uint64 + + monitored casemapMap +} + +func newDownstreamConn(srv *Server, ic ircConn, id uint64) *downstreamConn { + remoteAddr := ic.RemoteAddr().String() + logger := &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", remoteAddr)} + options := connOptions{Logger: logger} + dc := &downstreamConn{ + conn: *newConn(srv, ic, &options), + id: id, + nick: "*", + nickCM: "*", + supportedCaps: make(map[string]string), + caps: make(map[string]bool), + monitored: newCasemapMap(0), + } + dc.hostname = remoteAddr + if host, _, err := net.SplitHostPort(dc.hostname); err == nil { + dc.hostname = host + } + for k, v := range permanentDownstreamCaps { + dc.supportedCaps[k] = v + } + dc.supportedCaps["sasl"] = "PLAIN" + // TODO: this is racy, we should only enable chathistory after + // authentication and then check that user.msgStore implements + // chatHistoryMessageStore + if srv.Config().LogPath != "" { + dc.supportedCaps["draft/chathistory"] = "" + } + return dc +} + +func (dc *downstreamConn) prefix() *irc.Prefix { + return &irc.Prefix{ + Name: dc.nick, + User: dc.user.Username, + Host: dc.hostname, + } +} + +func (dc *downstreamConn) forEachNetwork(f func(*network)) { + if dc.network != nil { + f(dc.network) + } else if dc.isMultiUpstream { + for _, network := range dc.user.networks { + f(network) + } + } +} + +func (dc *downstreamConn) forEachUpstream(f func(*upstreamConn)) { + if dc.network == nil && !dc.isMultiUpstream { + return + } + dc.user.forEachUpstream(func(uc *upstreamConn) { + if dc.network != nil && uc.network != dc.network { + return + } + f(uc) + }) +} + +// upstream returns the upstream connection, if any. If there are zero or if +// there are multiple upstream connections, it returns nil. +func (dc *downstreamConn) upstream() *upstreamConn { + if dc.network == nil { + return nil + } + return dc.network.conn +} + +func isOurNick(net *network, nick string) bool { + // TODO: this doesn't account for nick changes + if net.conn != nil { + return net.casemap(nick) == net.conn.nickCM + } + // We're not currently connected to the upstream connection, so we don't + // know whether this name is our nickname. Best-effort: use the network's + // configured nickname and hope it was the one being used when we were + // connected. + return net.casemap(nick) == net.casemap(GetNick(&net.user.User, &net.Network)) +} + +// marshalEntity converts an upstream entity name (ie. channel or nick) into a +// downstream entity name. +// +// This involves adding a "/" suffix if the entity isn't the current +// user. +func (dc *downstreamConn) marshalEntity(net *network, name string) string { + if isOurNick(net, name) { + return dc.nick + } + name = partialCasemap(net.casemap, name) + if dc.network != nil { + if dc.network != net { + panic("suika: tried to marshal an entity for another network") + } + return name + } + return name + "/" + net.GetName() +} + +func (dc *downstreamConn) marshalUserPrefix(net *network, prefix *irc.Prefix) *irc.Prefix { + if isOurNick(net, prefix.Name) { + return dc.prefix() + } + prefix.Name = partialCasemap(net.casemap, prefix.Name) + if dc.network != nil { + if dc.network != net { + panic("suika: tried to marshal a user prefix for another network") + } + return prefix + } + return &irc.Prefix{ + Name: prefix.Name + "/" + net.GetName(), + User: prefix.User, + Host: prefix.Host, + } +} + +// unmarshalEntityNetwork converts a downstream entity name (ie. channel or +// nick) into an upstream entity name. +// +// This involves removing the "/" suffix. +func (dc *downstreamConn) unmarshalEntityNetwork(name string) (*network, string, error) { + if dc.network != nil { + return dc.network, name, nil + } + if !dc.isMultiUpstream { + return nil, "", ircError{&irc.Message{ + Command: irc.ERR_NOSUCHCHANNEL, + Params: []string{dc.nick, name, "Cannot interact with channels and users on the bouncer connection. Did you mean to use a specific network?"}, + }} + } + + var net *network + if i := strings.LastIndexByte(name, '/'); i >= 0 { + network := name[i+1:] + name = name[:i] + + for _, n := range dc.user.networks { + if network == n.GetName() { + net = n + break + } + } + } + + if net == nil { + return nil, "", ircError{&irc.Message{ + Command: irc.ERR_NOSUCHCHANNEL, + Params: []string{dc.nick, name, "Missing network suffix in name"}, + }} + } + + return net, name, nil +} + +// unmarshalEntity is the same as unmarshalEntityNetwork, but returns the +// upstream connection and fails if the upstream is disconnected. +func (dc *downstreamConn) unmarshalEntity(name string) (*upstreamConn, string, error) { + net, name, err := dc.unmarshalEntityNetwork(name) + if err != nil { + return nil, "", err + } + + if net.conn == nil { + return nil, "", ircError{&irc.Message{ + Command: irc.ERR_NOSUCHCHANNEL, + Params: []string{dc.nick, name, "Disconnected from upstream network"}, + }} + } + + return net.conn, name, nil +} + +func (dc *downstreamConn) unmarshalText(uc *upstreamConn, text string) string { + if dc.upstream() != nil { + return text + } + // TODO: smarter parsing that ignores URLs + return strings.ReplaceAll(text, "/"+uc.network.GetName(), "") +} + +func (dc *downstreamConn) ReadMessage() (*irc.Message, error) { + msg, err := dc.conn.ReadMessage() + if err != nil { + return nil, err + } + return msg, nil +} + +func (dc *downstreamConn) readMessages(ch chan<- event) error { + for { + msg, err := dc.ReadMessage() + if errors.Is(err, io.EOF) { + break + } else if err != nil { + return fmt.Errorf("failed to read IRC command: %v", err) + } + + ch <- eventDownstreamMessage{msg, dc} + } + + return nil +} + +// SendMessage sends an outgoing message. +// +// This can only called from the user goroutine. +func (dc *downstreamConn) SendMessage(msg *irc.Message) { + if !dc.caps["message-tags"] { + if msg.Command == "TAGMSG" { + return + } + msg = msg.Copy() + for name := range msg.Tags { + supported := false + switch name { + case "time": + supported = dc.caps["server-time"] + case "account": + supported = dc.caps["account"] + } + if !supported { + delete(msg.Tags, name) + } + } + } + if !dc.caps["batch"] && msg.Tags["batch"] != "" { + msg = msg.Copy() + delete(msg.Tags, "batch") + } + if msg.Command == "JOIN" && !dc.caps["extended-join"] { + msg.Params = msg.Params[:1] + } + if msg.Command == "SETNAME" && !dc.caps["setname"] { + return + } + if msg.Command == "AWAY" && !dc.caps["away-notify"] { + return + } + if msg.Command == "ACCOUNT" && !dc.caps["account-notify"] { + return + } + if msg.Command == "READ" && !dc.caps["soju.im/read"] { + return + } + + dc.conn.SendMessage(context.TODO(), msg) +} + +func (dc *downstreamConn) SendBatch(typ string, params []string, tags irc.Tags, f func(batchRef irc.TagValue)) { + dc.lastBatchRef++ + ref := fmt.Sprintf("%v", dc.lastBatchRef) + + if dc.caps["batch"] { + dc.SendMessage(&irc.Message{ + Tags: tags, + Prefix: dc.srv.prefix(), + Command: "BATCH", + Params: append([]string{"+" + ref, typ}, params...), + }) + } + + f(irc.TagValue(ref)) + + if dc.caps["batch"] { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "BATCH", + Params: []string{"-" + ref}, + }) + } +} + +// sendMessageWithID sends an outgoing message with the specified internal ID. +func (dc *downstreamConn) sendMessageWithID(msg *irc.Message, id string) { + dc.SendMessage(msg) + + if id == "" || !dc.messageSupportsBacklog(msg) { + return + } + + dc.sendPing(id) +} + +// advanceMessageWithID advances history to the specified message ID without +// sending a message. This is useful e.g. for self-messages when echo-message +// isn't enabled. +func (dc *downstreamConn) advanceMessageWithID(msg *irc.Message, id string) { + if id == "" || !dc.messageSupportsBacklog(msg) { + return + } + + dc.sendPing(id) +} + +// ackMsgID acknowledges that a message has been received. +func (dc *downstreamConn) ackMsgID(id string) { + netID, entity, err := parseMsgID(id, nil) + if err != nil { + dc.logger.Printf("failed to ACK message ID %q: %v", id, err) + return + } + + network := dc.user.getNetworkByID(netID) + if network == nil { + return + } + + network.delivered.StoreID(entity, dc.clientName, id) +} + +func (dc *downstreamConn) sendPing(msgID string) { + token := "suika-msgid-" + msgID + dc.SendMessage(&irc.Message{ + Command: "PING", + Params: []string{token}, + }) +} + +func (dc *downstreamConn) handlePong(token string) { + if !strings.HasPrefix(token, "suika-msgid-") { + dc.logger.Printf("received unrecognized PONG token %q", token) + return + } + msgID := strings.TrimPrefix(token, "suika-msgid-") + dc.ackMsgID(msgID) +} + +// marshalMessage re-formats a message coming from an upstream connection so +// that it's suitable for being sent on this downstream connection. Only +// messages that may appear in logs are supported, except MODE messages which +// may only appear in single-upstream mode. +func (dc *downstreamConn) marshalMessage(msg *irc.Message, net *network) *irc.Message { + msg = msg.Copy() + msg.Prefix = dc.marshalUserPrefix(net, msg.Prefix) + + if dc.network != nil { + return msg + } + + switch msg.Command { + case "PRIVMSG", "NOTICE", "TAGMSG": + msg.Params[0] = dc.marshalEntity(net, msg.Params[0]) + case "NICK": + // Nick change for another user + msg.Params[0] = dc.marshalEntity(net, msg.Params[0]) + case "JOIN", "PART": + msg.Params[0] = dc.marshalEntity(net, msg.Params[0]) + case "KICK": + msg.Params[0] = dc.marshalEntity(net, msg.Params[0]) + msg.Params[1] = dc.marshalEntity(net, msg.Params[1]) + case "TOPIC": + msg.Params[0] = dc.marshalEntity(net, msg.Params[0]) + case "QUIT", "SETNAME": + // This space is intentionally left blank + default: + panic(fmt.Sprintf("unexpected %q message", msg.Command)) + } + + return msg +} + +func (dc *downstreamConn) handleMessage(ctx context.Context, msg *irc.Message) error { + ctx, cancel := dc.conn.NewContext(ctx) + defer cancel() + + ctx, cancel = context.WithTimeout(ctx, handleDownstreamMessageTimeout) + defer cancel() + + switch msg.Command { + case "QUIT": + return dc.Close() + default: + if dc.registered { + return dc.handleMessageRegistered(ctx, msg) + } else { + return dc.handleMessageUnregistered(ctx, msg) + } + } +} + +func (dc *downstreamConn) handleMessageUnregistered(ctx context.Context, msg *irc.Message) error { + switch msg.Command { + case "NICK": + var nick string + if err := parseMessageParams(msg, &nick); err != nil { + return err + } + if nick == "" || strings.ContainsAny(nick, illegalNickChars) { + return ircError{&irc.Message{ + Command: irc.ERR_ERRONEUSNICKNAME, + Params: []string{dc.nick, nick, "contains illegal characters"}, + }} + } + nickCM := casemapASCII(nick) + if nickCM == serviceNickCM { + return ircError{&irc.Message{ + Command: irc.ERR_NICKNAMEINUSE, + Params: []string{dc.nick, nick, "Nickname reserved for bouncer service"}, + }} + } + dc.nick = nick + dc.nickCM = nickCM + case "USER": + if err := parseMessageParams(msg, &dc.rawUsername, nil, nil, &dc.realname); err != nil { + return err + } + case "PASS": + if err := parseMessageParams(msg, &dc.password); err != nil { + return err + } + case "CAP": + var subCmd string + if err := parseMessageParams(msg, &subCmd); err != nil { + return err + } + if err := dc.handleCapCommand(subCmd, msg.Params[1:]); err != nil { + return err + } + case "AUTHENTICATE": + credentials, err := dc.handleAuthenticateCommand(msg) + if err != nil { + return err + } else if credentials == nil { + break + } + + if err := dc.authenticate(ctx, credentials.plainUsername, credentials.plainPassword); err != nil { + dc.logger.Printf("SASL authentication error for user %q: %v", credentials.plainUsername, err) + dc.endSASL(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_SASLFAIL, + Params: []string{dc.nick, authErrorReason(err)}, + }) + break + } + + // Technically we should send RPL_LOGGEDIN here. However we use + // RPL_LOGGEDIN to mirror the upstream connection status. Let's + // see how many clients that breaks. See: + // https://github.com/ircv3/ircv3-specifications/pull/476 + dc.endSASL(nil) + case "BOUNCER": + var subcommand string + if err := parseMessageParams(msg, &subcommand); err != nil { + return err + } + + switch strings.ToUpper(subcommand) { + case "BIND": + var idStr string + if err := parseMessageParams(msg, nil, &idStr); err != nil { + return err + } + + if dc.user == nil { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"BOUNCER", "ACCOUNT_REQUIRED", "BIND", "Authentication needed to bind to bouncer network"}, + }} + } + + id, err := parseBouncerNetID(subcommand, idStr) + if err != nil { + return err + } + + var match *network + for _, net := range dc.user.networks { + if net.ID == id { + match = net + break + } + } + if match == nil { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"BOUNCER", "INVALID_NETID", idStr, "Unknown network ID"}, + }} + } + + dc.networkName = match.GetName() + } + default: + dc.logger.Printf("unhandled message: %v", msg) + return newUnknownCommandError(msg.Command) + } + if dc.rawUsername != "" && dc.nick != "*" && !dc.negotiatingCaps { + return dc.register(ctx) + } + return nil +} + +func (dc *downstreamConn) handleCapCommand(cmd string, args []string) error { + cmd = strings.ToUpper(cmd) + + switch cmd { + case "LS": + if len(args) > 0 { + var err error + if dc.capVersion, err = strconv.Atoi(args[0]); err != nil { + return err + } + } + if !dc.registered && dc.capVersion >= 302 { + // Let downstream show everything it supports, and trim + // down the available capabilities when upstreams are + // known. + for k, v := range needAllDownstreamCaps { + dc.supportedCaps[k] = v + } + } + + caps := make([]string, 0, len(dc.supportedCaps)) + for k, v := range dc.supportedCaps { + if dc.capVersion >= 302 && v != "" { + caps = append(caps, k+"="+v) + } else { + caps = append(caps, k) + } + } + + // TODO: multi-line replies + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "CAP", + Params: []string{dc.nick, "LS", strings.Join(caps, " ")}, + }) + + if dc.capVersion >= 302 { + // CAP version 302 implicitly enables cap-notify + dc.caps["cap-notify"] = true + } + + if !dc.registered { + dc.negotiatingCaps = true + } + case "LIST": + var caps []string + for name, enabled := range dc.caps { + if enabled { + caps = append(caps, name) + } + } + + // TODO: multi-line replies + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "CAP", + Params: []string{dc.nick, "LIST", strings.Join(caps, " ")}, + }) + case "REQ": + if len(args) == 0 { + return ircError{&irc.Message{ + Command: err_invalidcapcmd, + Params: []string{dc.nick, cmd, "Missing argument in CAP REQ command"}, + }} + } + + // TODO: atomically ack/nak the whole capability set + caps := strings.Fields(args[0]) + ack := true + for _, name := range caps { + name = strings.ToLower(name) + enable := !strings.HasPrefix(name, "-") + if !enable { + name = strings.TrimPrefix(name, "-") + } + + if enable == dc.caps[name] { + continue + } + + _, ok := dc.supportedCaps[name] + if !ok { + ack = false + break + } + + if name == "cap-notify" && dc.capVersion >= 302 && !enable { + // cap-notify cannot be disabled with CAP version 302 + ack = false + break + } + + dc.caps[name] = enable + } + + reply := "NAK" + if ack { + reply = "ACK" + } + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "CAP", + Params: []string{dc.nick, reply, args[0]}, + }) + + if !dc.registered { + dc.negotiatingCaps = true + } + case "END": + dc.negotiatingCaps = false + default: + return ircError{&irc.Message{ + Command: err_invalidcapcmd, + Params: []string{dc.nick, cmd, "Unknown CAP command"}, + }} + } + return nil +} + +func (dc *downstreamConn) handleAuthenticateCommand(msg *irc.Message) (result *downstreamSASL, err error) { + defer func() { + if err != nil { + dc.sasl = nil + } + }() + + if !dc.caps["sasl"] { + return nil, ircError{&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_SASLFAIL, + Params: []string{dc.nick, "AUTHENTICATE requires the \"sasl\" capability to be enabled"}, + }} + } + if len(msg.Params) == 0 { + return nil, ircError{&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_SASLFAIL, + Params: []string{dc.nick, "Missing AUTHENTICATE argument"}, + }} + } + if msg.Params[0] == "*" { + return nil, ircError{&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_SASLABORTED, + Params: []string{dc.nick, "SASL authentication aborted"}, + }} + } + + var resp []byte + if dc.sasl == nil { + mech := strings.ToUpper(msg.Params[0]) + var server sasl.Server + switch mech { + case "PLAIN": + server = sasl.NewPlainServer(sasl.PlainAuthenticator(func(identity, username, password string) error { + dc.sasl.plainUsername = username + dc.sasl.plainPassword = password + return nil + })) + default: + return nil, ircError{&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_SASLFAIL, + Params: []string{dc.nick, fmt.Sprintf("Unsupported SASL mechanism %q", mech)}, + }} + } + + dc.sasl = &downstreamSASL{server: server} + } else { + chunk := msg.Params[0] + if chunk == "+" { + chunk = "" + } + + if dc.sasl.pendingResp.Len()+len(chunk) > 10*1024 { + return nil, ircError{&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_SASLFAIL, + Params: []string{dc.nick, "Response too long"}, + }} + } + + dc.sasl.pendingResp.WriteString(chunk) + + if len(chunk) == maxSASLLength { + return nil, nil // Multi-line response, wait for the next command + } + + resp, err = base64.StdEncoding.DecodeString(dc.sasl.pendingResp.String()) + if err != nil { + return nil, ircError{&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_SASLFAIL, + Params: []string{dc.nick, "Invalid base64-encoded response"}, + }} + } + + dc.sasl.pendingResp.Reset() + } + + challenge, done, err := dc.sasl.server.Next(resp) + if err != nil { + return nil, err + } else if done { + return dc.sasl, nil + } else { + challengeStr := "+" + if len(challenge) > 0 { + challengeStr = base64.StdEncoding.EncodeToString(challenge) + } + + // TODO: multi-line messages + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "AUTHENTICATE", + Params: []string{challengeStr}, + }) + return nil, nil + } +} + +func (dc *downstreamConn) endSASL(msg *irc.Message) { + if dc.sasl == nil { + return + } + + dc.sasl = nil + + if msg != nil { + dc.SendMessage(msg) + } else { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_SASLSUCCESS, + Params: []string{dc.nick, "SASL authentication successful"}, + }) + } +} + +func (dc *downstreamConn) setSupportedCap(name, value string) { + prevValue, hasPrev := dc.supportedCaps[name] + changed := !hasPrev || prevValue != value + dc.supportedCaps[name] = value + + if !dc.caps["cap-notify"] || !changed { + return + } + + cap := name + if value != "" && dc.capVersion >= 302 { + cap = name + "=" + value + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "CAP", + Params: []string{dc.nick, "NEW", cap}, + }) +} + +func (dc *downstreamConn) unsetSupportedCap(name string) { + _, hasPrev := dc.supportedCaps[name] + delete(dc.supportedCaps, name) + delete(dc.caps, name) + + if !dc.caps["cap-notify"] || !hasPrev { + return + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "CAP", + Params: []string{dc.nick, "DEL", name}, + }) +} + +func (dc *downstreamConn) updateSupportedCaps() { + supportedCaps := make(map[string]bool) + for cap := range needAllDownstreamCaps { + supportedCaps[cap] = true + } + dc.forEachUpstream(func(uc *upstreamConn) { + for cap, supported := range supportedCaps { + supportedCaps[cap] = supported && uc.caps[cap] + } + }) + + for cap, supported := range supportedCaps { + if supported { + dc.setSupportedCap(cap, needAllDownstreamCaps[cap]) + } else { + dc.unsetSupportedCap(cap) + } + } + + if uc := dc.upstream(); uc != nil && uc.supportsSASL("PLAIN") { + dc.setSupportedCap("sasl", "PLAIN") + } else if dc.network != nil { + dc.unsetSupportedCap("sasl") + } + + if uc := dc.upstream(); uc != nil && uc.caps["draft/account-registration"] { + // Strip "before-connect", because we require downstreams to be fully + // connected before attempting account registration. + values := strings.Split(uc.supportedCaps["draft/account-registration"], ",") + for i, v := range values { + if v == "before-connect" { + values = append(values[:i], values[i+1:]...) + break + } + } + dc.setSupportedCap("draft/account-registration", strings.Join(values, ",")) + } else { + dc.unsetSupportedCap("draft/account-registration") + } + + if _, ok := dc.user.msgStore.(chatHistoryMessageStore); ok && dc.network != nil { + dc.setSupportedCap("draft/event-playback", "") + } else { + dc.unsetSupportedCap("draft/event-playback") + } +} + +func (dc *downstreamConn) updateNick() { + if uc := dc.upstream(); uc != nil && uc.nick != dc.nick { + dc.SendMessage(&irc.Message{ + Prefix: dc.prefix(), + Command: "NICK", + Params: []string{uc.nick}, + }) + dc.nick = uc.nick + dc.nickCM = casemapASCII(dc.nick) + } +} + +func (dc *downstreamConn) updateRealname() { + if uc := dc.upstream(); uc != nil && uc.realname != dc.realname && dc.caps["setname"] { + dc.SendMessage(&irc.Message{ + Prefix: dc.prefix(), + Command: "SETNAME", + Params: []string{uc.realname}, + }) + dc.realname = uc.realname + } +} + +func (dc *downstreamConn) updateAccount() { + var account string + if dc.network == nil { + account = dc.user.Username + } else if uc := dc.upstream(); uc != nil { + account = uc.account + } else { + return + } + + if dc.account == account || !dc.caps["sasl"] { + return + } + + if account != "" { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_LOGGEDIN, + Params: []string{dc.nick, dc.prefix().String(), account, "You are logged in as " + account}, + }) + } else { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_LOGGEDOUT, + Params: []string{dc.nick, dc.prefix().String(), "You are logged out"}, + }) + } + + dc.account = account +} + +func sanityCheckServer(ctx context.Context, addr string) error { + ctx, cancel := context.WithTimeout(ctx, 15*time.Second) + defer cancel() + + conn, err := new(tls.Dialer).DialContext(ctx, "tcp", addr) + if err != nil { + return err + } + + return conn.Close() +} + +func unmarshalUsername(rawUsername string) (username, client, network string) { + username = rawUsername + + i := strings.IndexAny(username, "/@") + j := strings.LastIndexAny(username, "/@") + if i >= 0 { + username = rawUsername[:i] + } + if j >= 0 { + if rawUsername[j] == '@' { + client = rawUsername[j+1:] + } else { + network = rawUsername[j+1:] + } + } + if i >= 0 && j >= 0 && i < j { + if rawUsername[i] == '@' { + client = rawUsername[i+1 : j] + } else { + network = rawUsername[i+1 : j] + } + } + + return username, client, network +} + +func (dc *downstreamConn) authenticate(ctx context.Context, username, password string) error { + username, clientName, networkName := unmarshalUsername(username) + + u, err := dc.srv.db.GetUser(ctx, username) + if err != nil { + return newInvalidUsernameOrPasswordError(fmt.Errorf("user not found: %w", err)) + } + + // Password auth disabled + if u.Password == "" { + return newInvalidUsernameOrPasswordError(fmt.Errorf("password auth disabled")) + } + + err = bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(password)) + if err != nil { + return newInvalidUsernameOrPasswordError(fmt.Errorf("wrong password")) + } + + dc.user = dc.srv.getUser(username) + if dc.user == nil { + return fmt.Errorf("user not active") + } + dc.clientName = clientName + dc.networkName = networkName + return nil +} + +func (dc *downstreamConn) register(ctx context.Context) error { + if dc.registered { + panic("tried to register twice") + } + + if dc.sasl != nil { + dc.endSASL(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_SASLABORTED, + Params: []string{dc.nick, "SASL authentication aborted"}, + }) + } + + password := dc.password + dc.password = "" + if dc.user == nil { + if password == "" { + if dc.caps["sasl"] { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"*", "ACCOUNT_REQUIRED", "Authentication required"}, + }} + } else { + return ircError{&irc.Message{ + Command: irc.ERR_PASSWDMISMATCH, + Params: []string{dc.nick, "Authentication required"}, + }} + } + } + + if err := dc.authenticate(ctx, dc.rawUsername, password); err != nil { + dc.logger.Printf("PASS authentication error for user %q: %v", dc.rawUsername, err) + return ircError{&irc.Message{ + Command: irc.ERR_PASSWDMISMATCH, + Params: []string{dc.nick, authErrorReason(err)}, + }} + } + } + + _, fallbackClientName, fallbackNetworkName := unmarshalUsername(dc.rawUsername) + if dc.clientName == "" { + dc.clientName = fallbackClientName + } else if fallbackClientName != "" && dc.clientName != fallbackClientName { + return ircError{&irc.Message{ + Command: irc.ERR_ERRONEUSNICKNAME, + Params: []string{dc.nick, "Client name mismatch in usernames"}, + }} + } + if dc.networkName == "" { + dc.networkName = fallbackNetworkName + } else if fallbackNetworkName != "" && dc.networkName != fallbackNetworkName { + return ircError{&irc.Message{ + Command: irc.ERR_ERRONEUSNICKNAME, + Params: []string{dc.nick, "Network name mismatch in usernames"}, + }} + } + + dc.registered = true + dc.logger.Printf("registration complete for user %q", dc.user.Username) + return nil +} + +func (dc *downstreamConn) loadNetwork(ctx context.Context) error { + if dc.networkName == "" { + return nil + } + + network := dc.user.getNetwork(dc.networkName) + if network == nil { + addr := dc.networkName + if !strings.ContainsRune(addr, ':') { + addr = addr + ":6697" + } + + dc.logger.Printf("trying to connect to new network %q", addr) + if err := sanityCheckServer(ctx, addr); err != nil { + dc.logger.Printf("failed to connect to %q: %v", addr, err) + return ircError{&irc.Message{ + Command: irc.ERR_PASSWDMISMATCH, + Params: []string{dc.nick, fmt.Sprintf("Failed to connect to %q", dc.networkName)}, + }} + } + + // Some clients only allow specifying the nickname (and use the + // nickname as a username too). Strip the network name from the + // nickname when auto-saving networks. + nick, _, _ := unmarshalUsername(dc.nick) + + dc.logger.Printf("auto-saving network %q", dc.networkName) + var err error + network, err = dc.user.createNetwork(ctx, &Network{ + Addr: dc.networkName, + Nick: nick, + Enabled: true, + }) + if err != nil { + return err + } + } + + dc.network = network + return nil +} + +func (dc *downstreamConn) welcome(ctx context.Context) error { + if dc.user == nil || !dc.registered { + panic("tried to welcome an unregistered connection") + } + + remoteAddr := dc.conn.RemoteAddr().String() + dc.logger = &prefixLogger{dc.srv.Logger, fmt.Sprintf("user %q: downstream %q: ", dc.user.Username, remoteAddr)} + + // TODO: doing this might take some time. We should do it in dc.register + // instead, but we'll potentially be adding a new network and this must be + // done in the user goroutine. + if err := dc.loadNetwork(ctx); err != nil { + return err + } + + if dc.network == nil && !dc.caps["soju.im/bouncer-networks"] && dc.srv.Config().MultiUpstream { + dc.isMultiUpstream = true + } + + dc.updateSupportedCaps() + + isupport := []string{ + fmt.Sprintf("CHATHISTORY=%v", chatHistoryLimit), + "CASEMAPPING=ascii", + } + + if dc.network != nil { + isupport = append(isupport, fmt.Sprintf("BOUNCER_NETID=%v", dc.network.ID)) + } + if title := dc.srv.Config().Title; dc.network == nil && title != "" { + isupport = append(isupport, "NETWORK="+encodeISUPPORT(title)) + } + if dc.network == nil && !dc.isMultiUpstream { + isupport = append(isupport, "WHOX") + } + + if uc := dc.upstream(); uc != nil { + for k := range passthroughIsupport { + v, ok := uc.isupport[k] + if !ok { + continue + } + if v != nil { + isupport = append(isupport, fmt.Sprintf("%v=%v", k, *v)) + } else { + isupport = append(isupport, k) + } + } + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_WELCOME, + Params: []string{dc.nick, "Welcome to suika, " + dc.nick}, + }) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_YOURHOST, + Params: []string{dc.nick, "Your host is " + dc.srv.Config().Hostname}, + }) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_MYINFO, + Params: []string{dc.nick, dc.srv.Config().Hostname, "suika", "aiwroO", "OovaimnqpsrtklbeI"}, + }) + for _, msg := range generateIsupport(dc.srv.prefix(), dc.nick, isupport) { + dc.SendMessage(msg) + } + if uc := dc.upstream(); uc != nil { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_UMODEIS, + Params: []string{dc.nick, "+" + string(uc.modes)}, + }) + } + if dc.network == nil && !dc.isMultiUpstream && dc.user.Admin { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_UMODEIS, + Params: []string{dc.nick, "+o"}, + }) + } + + dc.updateNick() + dc.updateRealname() + dc.updateAccount() + + if motd := dc.user.srv.Config().MOTD; motd != "" && dc.network == nil { + for _, msg := range generateMOTD(dc.srv.prefix(), dc.nick, motd) { + dc.SendMessage(msg) + } + } else { + motdHint := "No MOTD" + if dc.network != nil { + motdHint = "Use /motd to read the message of the day" + } + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_NOMOTD, + Params: []string{dc.nick, motdHint}, + }) + } + + if dc.caps["soju.im/bouncer-networks-notify"] { + dc.SendBatch("soju.im/bouncer-networks", nil, nil, func(batchRef irc.TagValue) { + for _, network := range dc.user.networks { + idStr := fmt.Sprintf("%v", network.ID) + attrs := getNetworkAttrs(network) + dc.SendMessage(&irc.Message{ + Tags: irc.Tags{"batch": batchRef}, + Prefix: dc.srv.prefix(), + Command: "BOUNCER", + Params: []string{"NETWORK", idStr, attrs.String()}, + }) + } + }) + } + + dc.forEachUpstream(func(uc *upstreamConn) { + for _, entry := range uc.channels.innerMap { + ch := entry.value.(*upstreamChannel) + if !ch.complete { + continue + } + record := uc.network.channels.Value(ch.Name) + if record != nil && record.Detached { + continue + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.prefix(), + Command: "JOIN", + Params: []string{dc.marshalEntity(ch.conn.network, ch.Name)}, + }) + + forwardChannel(ctx, dc, ch) + } + }) + + dc.forEachNetwork(func(net *network) { + if dc.caps["draft/chathistory"] || dc.user.msgStore == nil { + return + } + + // Only send history if we're the first connected client with that name + // for the network + firstClient := true + dc.user.forEachDownstream(func(c *downstreamConn) { + if c != dc && c.clientName == dc.clientName && c.network == dc.network { + firstClient = false + } + }) + if firstClient { + net.delivered.ForEachTarget(func(target string) { + lastDelivered := net.delivered.LoadID(target, dc.clientName) + if lastDelivered == "" { + return + } + + dc.sendTargetBacklog(ctx, net, target, lastDelivered) + + // Fast-forward history to last message + targetCM := net.casemap(target) + lastID, err := dc.user.msgStore.LastMsgID(&net.Network, targetCM, time.Now()) + if err != nil { + dc.logger.Printf("failed to get last message ID: %v", err) + return + } + net.delivered.StoreID(target, dc.clientName, lastID) + }) + } + }) + + return nil +} + +// messageSupportsBacklog checks whether the provided message can be sent as +// part of an history batch. +func (dc *downstreamConn) messageSupportsBacklog(msg *irc.Message) bool { + // Don't replay all messages, because that would mess up client + // state. For instance we just sent the list of users, sending + // PART messages for one of these users would be incorrect. + switch msg.Command { + case "PRIVMSG", "NOTICE": + return true + } + return false +} + +func (dc *downstreamConn) sendTargetBacklog(ctx context.Context, net *network, target, msgID string) { + if dc.caps["draft/chathistory"] || dc.user.msgStore == nil { + return + } + + ch := net.channels.Value(target) + + ctx, cancel := context.WithTimeout(ctx, backlogTimeout) + defer cancel() + + targetCM := net.casemap(target) + history, err := dc.user.msgStore.LoadLatestID(ctx, &net.Network, targetCM, msgID, backlogLimit) + if err != nil { + dc.logger.Printf("failed to send backlog for %q: %v", target, err) + return + } + + dc.SendBatch("chathistory", []string{dc.marshalEntity(net, target)}, nil, func(batchRef irc.TagValue) { + for _, msg := range history { + if ch != nil && ch.Detached { + if net.detachedMessageNeedsRelay(ch, msg) { + dc.relayDetachedMessage(net, msg) + } + } else { + msg.Tags["batch"] = batchRef + dc.SendMessage(dc.marshalMessage(msg, net)) + } + } + }) +} + +func (dc *downstreamConn) relayDetachedMessage(net *network, msg *irc.Message) { + if msg.Command != "PRIVMSG" && msg.Command != "NOTICE" { + return + } + + sender := msg.Prefix.Name + target, text := msg.Params[0], msg.Params[1] + if net.isHighlight(msg) { + sendServiceNOTICE(dc, fmt.Sprintf("highlight in %v: <%v> %v", dc.marshalEntity(net, target), sender, text)) + } else { + sendServiceNOTICE(dc, fmt.Sprintf("message in %v: <%v> %v", dc.marshalEntity(net, target), sender, text)) + } +} + +func (dc *downstreamConn) runUntilRegistered() error { + ctx, cancel := context.WithTimeout(context.TODO(), downstreamRegisterTimeout) + defer cancel() + + // Close the connection with an error if the deadline is exceeded + go func() { + <-ctx.Done() + if err := ctx.Err(); err == context.DeadlineExceeded { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "ERROR", + Params: []string{"Connection registration timed out"}, + }) + dc.Close() + } + }() + + for !dc.registered { + msg, err := dc.ReadMessage() + if err != nil { + return fmt.Errorf("failed to read IRC command: %w", err) + } + + err = dc.handleMessage(ctx, msg) + if ircErr, ok := err.(ircError); ok { + ircErr.Message.Prefix = dc.srv.prefix() + dc.SendMessage(ircErr.Message) + } else if err != nil { + return fmt.Errorf("failed to handle IRC command %q: %v", msg, err) + } + } + + return nil +} + +func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.Message) error { + switch msg.Command { + case "CAP": + var subCmd string + if err := parseMessageParams(msg, &subCmd); err != nil { + return err + } + if err := dc.handleCapCommand(subCmd, msg.Params[1:]); err != nil { + return err + } + case "PING": + var source, destination string + if err := parseMessageParams(msg, &source); err != nil { + return err + } + if len(msg.Params) > 1 { + destination = msg.Params[1] + } + hostname := dc.srv.Config().Hostname + if destination != "" && destination != hostname { + return ircError{&irc.Message{ + Command: irc.ERR_NOSUCHSERVER, + Params: []string{dc.nick, destination, "No such server"}, + }} + } + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "PONG", + Params: []string{hostname, source}, + }) + return nil + case "PONG": + if len(msg.Params) == 0 { + return newNeedMoreParamsError(msg.Command) + } + token := msg.Params[len(msg.Params)-1] + dc.handlePong(token) + case "USER": + return ircError{&irc.Message{ + Command: irc.ERR_ALREADYREGISTERED, + Params: []string{dc.nick, "You may not reregister"}, + }} + case "NICK": + var rawNick string + if err := parseMessageParams(msg, &rawNick); err != nil { + return err + } + + nick := rawNick + var upstream *upstreamConn + if dc.upstream() == nil { + uc, unmarshaledNick, err := dc.unmarshalEntity(nick) + if err == nil { // NICK nick/network: NICK only on a specific upstream + upstream = uc + nick = unmarshaledNick + } + } + + if nick == "" || strings.ContainsAny(nick, illegalNickChars) { + return ircError{&irc.Message{ + Command: irc.ERR_ERRONEUSNICKNAME, + Params: []string{dc.nick, rawNick, "contains illegal characters"}, + }} + } + if casemapASCII(nick) == serviceNickCM { + return ircError{&irc.Message{ + Command: irc.ERR_NICKNAMEINUSE, + Params: []string{dc.nick, rawNick, "Nickname reserved for bouncer service"}, + }} + } + + var err error + dc.forEachNetwork(func(n *network) { + if err != nil || (upstream != nil && upstream.network != n) { + return + } + n.Nick = nick + err = dc.srv.db.StoreNetwork(ctx, dc.user.ID, &n.Network) + }) + if err != nil { + return err + } + + dc.forEachUpstream(func(uc *upstreamConn) { + if upstream != nil && upstream != uc { + return + } + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + Command: "NICK", + Params: []string{nick}, + }) + }) + + if dc.upstream() == nil && upstream == nil && dc.nick != nick { + dc.SendMessage(&irc.Message{ + Prefix: dc.prefix(), + Command: "NICK", + Params: []string{nick}, + }) + dc.nick = nick + dc.nickCM = casemapASCII(dc.nick) + } + case "SETNAME": + var realname string + if err := parseMessageParams(msg, &realname); err != nil { + return err + } + + // If the client just resets to the default, just wipe the per-network + // preference + storeRealname := realname + if realname == dc.user.Realname { + storeRealname = "" + } + + var storeErr error + var needUpdate []Network + dc.forEachNetwork(func(n *network) { + // We only need to call updateNetwork for upstreams that don't + // support setname + if uc := n.conn; uc != nil && uc.caps["setname"] { + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + Command: "SETNAME", + Params: []string{realname}, + }) + + n.Realname = storeRealname + if err := dc.srv.db.StoreNetwork(ctx, dc.user.ID, &n.Network); err != nil { + dc.logger.Printf("failed to store network realname: %v", err) + storeErr = err + } + return + } + + record := n.Network // copy network record because we'll mutate it + record.Realname = storeRealname + needUpdate = append(needUpdate, record) + }) + + // Walk the network list as a second step, because updateNetwork + // mutates the original list + for _, record := range needUpdate { + if _, err := dc.user.updateNetwork(ctx, &record); err != nil { + dc.logger.Printf("failed to update network realname: %v", err) + storeErr = err + } + } + if storeErr != nil { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"SETNAME", "CANNOT_CHANGE_REALNAME", "Failed to update realname"}, + }} + } + + if dc.upstream() == nil { + dc.SendMessage(&irc.Message{ + Prefix: dc.prefix(), + Command: "SETNAME", + Params: []string{realname}, + }) + } + case "JOIN": + var namesStr string + if err := parseMessageParams(msg, &namesStr); err != nil { + return err + } + + var keys []string + if len(msg.Params) > 1 { + keys = strings.Split(msg.Params[1], ",") + } + + for i, name := range strings.Split(namesStr, ",") { + uc, upstreamName, err := dc.unmarshalEntity(name) + if err != nil { + return err + } + + var key string + if len(keys) > i { + key = keys[i] + } + + if !uc.isChannel(upstreamName) { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_NOSUCHCHANNEL, + Params: []string{name, "Not a channel name"}, + }) + continue + } + + // Most servers ignore duplicate JOIN messages. We ignore them here + // because some clients automatically send JOIN messages in bulk + // when reconnecting to the bouncer. We don't want to flood the + // upstream connection with these. + if !uc.channels.Has(upstreamName) { + params := []string{upstreamName} + if key != "" { + params = append(params, key) + } + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + Command: "JOIN", + Params: params, + }) + } + + ch := uc.network.channels.Value(upstreamName) + if ch != nil { + // Don't clear the channel key if there's one set + // TODO: add a way to unset the channel key + if key != "" { + ch.Key = key + } + uc.network.attach(ctx, ch) + } else { + ch = &Channel{ + Name: upstreamName, + Key: key, + } + uc.network.channels.SetValue(upstreamName, ch) + } + if err := dc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil { + dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err) + } + } + case "PART": + var namesStr string + if err := parseMessageParams(msg, &namesStr); err != nil { + return err + } + + var reason string + if len(msg.Params) > 1 { + reason = msg.Params[1] + } + + for _, name := range strings.Split(namesStr, ",") { + uc, upstreamName, err := dc.unmarshalEntity(name) + if err != nil { + return err + } + + if strings.EqualFold(reason, "detach") { + ch := uc.network.channels.Value(upstreamName) + if ch != nil { + uc.network.detach(ch) + } else { + ch = &Channel{ + Name: name, + Detached: true, + } + uc.network.channels.SetValue(upstreamName, ch) + } + if err := dc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil { + dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err) + } + } else { + params := []string{upstreamName} + if reason != "" { + params = append(params, reason) + } + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + Command: "PART", + Params: params, + }) + + if err := uc.network.deleteChannel(ctx, upstreamName); err != nil { + dc.logger.Printf("failed to delete channel %q: %v", upstreamName, err) + } + } + } + case "KICK": + var channelStr, userStr string + if err := parseMessageParams(msg, &channelStr, &userStr); err != nil { + return err + } + + channels := strings.Split(channelStr, ",") + users := strings.Split(userStr, ",") + + var reason string + if len(msg.Params) > 2 { + reason = msg.Params[2] + } + + if len(channels) != 1 && len(channels) != len(users) { + return ircError{&irc.Message{ + Command: irc.ERR_BADCHANMASK, + Params: []string{dc.nick, channelStr, "Bad channel mask"}, + }} + } + + for i, user := range users { + var channel string + if len(channels) == 1 { + channel = channels[0] + } else { + channel = channels[i] + } + + ucChannel, upstreamChannel, err := dc.unmarshalEntity(channel) + if err != nil { + return err + } + + ucUser, upstreamUser, err := dc.unmarshalEntity(user) + if err != nil { + return err + } + + if ucChannel != ucUser { + return ircError{&irc.Message{ + Command: irc.ERR_USERNOTINCHANNEL, + Params: []string{dc.nick, user, channel, "They are on another network"}, + }} + } + uc := ucChannel + + params := []string{upstreamChannel, upstreamUser} + if reason != "" { + params = append(params, reason) + } + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + Command: "KICK", + Params: params, + }) + } + case "MODE": + var name string + if err := parseMessageParams(msg, &name); err != nil { + return err + } + + var modeStr string + if len(msg.Params) > 1 { + modeStr = msg.Params[1] + } + + if casemapASCII(name) == dc.nickCM { + if modeStr != "" { + if uc := dc.upstream(); uc != nil { + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + Command: "MODE", + Params: []string{uc.nick, modeStr}, + }) + } else { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_UMODEUNKNOWNFLAG, + Params: []string{dc.nick, "Cannot change user mode in multi-upstream mode"}, + }) + } + } else { + var userMode string + if uc := dc.upstream(); uc != nil { + userMode = string(uc.modes) + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_UMODEIS, + Params: []string{dc.nick, "+" + userMode}, + }) + } + return nil + } + + uc, upstreamName, err := dc.unmarshalEntity(name) + if err != nil { + return err + } + + if !uc.isChannel(upstreamName) { + return ircError{&irc.Message{ + Command: irc.ERR_USERSDONTMATCH, + Params: []string{dc.nick, "Cannot change mode for other users"}, + }} + } + + if modeStr != "" { + params := []string{upstreamName, modeStr} + params = append(params, msg.Params[2:]...) + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + Command: "MODE", + Params: params, + }) + } else { + ch := uc.channels.Value(upstreamName) + if ch == nil { + return ircError{&irc.Message{ + Command: irc.ERR_NOSUCHCHANNEL, + Params: []string{dc.nick, name, "No such channel"}, + }} + } + + if ch.modes == nil { + // we haven't received the initial RPL_CHANNELMODEIS yet + // ignore the request, we will broadcast the modes later when we receive RPL_CHANNELMODEIS + return nil + } + + modeStr, modeParams := ch.modes.Format() + params := []string{dc.nick, name, modeStr} + params = append(params, modeParams...) + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_CHANNELMODEIS, + Params: params, + }) + if ch.creationTime != "" { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: rpl_creationtime, + Params: []string{dc.nick, name, ch.creationTime}, + }) + } + } + case "TOPIC": + var channel string + if err := parseMessageParams(msg, &channel); err != nil { + return err + } + + uc, upstreamName, err := dc.unmarshalEntity(channel) + if err != nil { + return err + } + + if len(msg.Params) > 1 { // setting topic + topic := msg.Params[1] + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + Command: "TOPIC", + Params: []string{upstreamName, topic}, + }) + } else { // getting topic + ch := uc.channels.Value(upstreamName) + if ch == nil { + return ircError{&irc.Message{ + Command: irc.ERR_NOSUCHCHANNEL, + Params: []string{dc.nick, upstreamName, "No such channel"}, + }} + } + sendTopic(dc, ch) + } + case "LIST": + network := dc.network + if network == nil && len(msg.Params) > 0 { + var err error + network, msg.Params[0], err = dc.unmarshalEntityNetwork(msg.Params[0]) + if err != nil { + return err + } + } + if network == nil { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_LISTEND, + Params: []string{dc.nick, "LIST without a network suffix is not supported in multi-upstream mode"}, + }) + return nil + } + + uc := network.conn + if uc == nil { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_LISTEND, + Params: []string{dc.nick, "Disconnected from upstream server"}, + }) + return nil + } + + uc.enqueueCommand(dc, msg) + case "NAMES": + if len(msg.Params) == 0 { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_ENDOFNAMES, + Params: []string{dc.nick, "*", "End of /NAMES list"}, + }) + return nil + } + + channels := strings.Split(msg.Params[0], ",") + for _, channel := range channels { + uc, upstreamName, err := dc.unmarshalEntity(channel) + if err != nil { + return err + } + + ch := uc.channels.Value(upstreamName) + if ch != nil { + sendNames(dc, ch) + } else { + // NAMES on a channel we have not joined, ask upstream + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + Command: "NAMES", + Params: []string{upstreamName}, + }) + } + } + // For WHOX docs, see: + // - http://faerion.sourceforge.net/doc/irc/whox.var + // - https://github.com/quakenet/snircd/blob/master/doc/readme.who + // Note, many features aren't widely implemented, such as flags and mask2 + case "WHO": + if len(msg.Params) == 0 { + // TODO: support WHO without parameters + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_ENDOFWHO, + Params: []string{dc.nick, "*", "End of /WHO list"}, + }) + return nil + } + + // Clients will use the first mask to match RPL_ENDOFWHO + endOfWhoToken := msg.Params[0] + + // TODO: add support for WHOX mask2 + mask := msg.Params[0] + var options string + if len(msg.Params) > 1 { + options = msg.Params[1] + } + + optionsParts := strings.SplitN(options, "%", 2) + // TODO: add support for WHOX flags in optionsParts[0] + var fields, whoxToken string + if len(optionsParts) == 2 { + optionsParts := strings.SplitN(optionsParts[1], ",", 2) + fields = strings.ToLower(optionsParts[0]) + if len(optionsParts) == 2 && strings.Contains(fields, "t") { + whoxToken = optionsParts[1] + } + } + + // TODO: support mixed bouncer/upstream WHO queries + maskCM := casemapASCII(mask) + if dc.network == nil && maskCM == dc.nickCM { + // TODO: support AWAY (H/G) in self WHO reply + flags := "H" + if dc.user.Admin { + flags += "*" + } + info := whoxInfo{ + Token: whoxToken, + Username: dc.user.Username, + Hostname: dc.hostname, + Server: dc.srv.Config().Hostname, + Nickname: dc.nick, + Flags: flags, + Account: dc.user.Username, + Realname: dc.realname, + } + dc.SendMessage(generateWHOXReply(dc.srv.prefix(), dc.nick, fields, &info)) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_ENDOFWHO, + Params: []string{dc.nick, endOfWhoToken, "End of /WHO list"}, + }) + return nil + } + if maskCM == serviceNickCM { + info := whoxInfo{ + Token: whoxToken, + Username: servicePrefix.User, + Hostname: servicePrefix.Host, + Server: dc.srv.Config().Hostname, + Nickname: serviceNick, + Flags: "H*", + Account: serviceNick, + Realname: serviceRealname, + } + dc.SendMessage(generateWHOXReply(dc.srv.prefix(), dc.nick, fields, &info)) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_ENDOFWHO, + Params: []string{dc.nick, endOfWhoToken, "End of /WHO list"}, + }) + return nil + } + + // TODO: properly support WHO masks + uc, upstreamMask, err := dc.unmarshalEntity(mask) + if err != nil { + return err + } + + params := []string{upstreamMask} + if options != "" { + params = append(params, options) + } + + uc.enqueueCommand(dc, &irc.Message{ + Command: "WHO", + Params: params, + }) + case "WHOIS": + if len(msg.Params) == 0 { + return ircError{&irc.Message{ + Command: irc.ERR_NONICKNAMEGIVEN, + Params: []string{dc.nick, "No nickname given"}, + }} + } + + var target, mask string + if len(msg.Params) == 1 { + target = "" + mask = msg.Params[0] + } else { + target = msg.Params[0] + mask = msg.Params[1] + } + // TODO: support multiple WHOIS users + if i := strings.IndexByte(mask, ','); i >= 0 { + mask = mask[:i] + } + + if dc.network == nil && casemapASCII(mask) == dc.nickCM { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_WHOISUSER, + Params: []string{dc.nick, dc.nick, dc.user.Username, dc.hostname, "*", dc.realname}, + }) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_WHOISSERVER, + Params: []string{dc.nick, dc.nick, dc.srv.Config().Hostname, "suika"}, + }) + if dc.user.Admin { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_WHOISOPERATOR, + Params: []string{dc.nick, dc.nick, "is a bouncer administrator"}, + }) + } + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: rpl_whoisaccount, + Params: []string{dc.nick, dc.nick, dc.user.Username, "is logged in as"}, + }) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_ENDOFWHOIS, + Params: []string{dc.nick, dc.nick, "End of /WHOIS list"}, + }) + return nil + } + if casemapASCII(mask) == serviceNickCM { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_WHOISUSER, + Params: []string{dc.nick, serviceNick, servicePrefix.User, servicePrefix.Host, "*", serviceRealname}, + }) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_WHOISSERVER, + Params: []string{dc.nick, serviceNick, dc.srv.Config().Hostname, "suika"}, + }) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_WHOISOPERATOR, + Params: []string{dc.nick, serviceNick, "is the bouncer service"}, + }) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: rpl_whoisaccount, + Params: []string{dc.nick, serviceNick, serviceNick, "is logged in as"}, + }) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_ENDOFWHOIS, + Params: []string{dc.nick, serviceNick, "End of /WHOIS list"}, + }) + return nil + } + + // TODO: support WHOIS masks + uc, upstreamNick, err := dc.unmarshalEntity(mask) + if err != nil { + return err + } + + var params []string + if target != "" { + if target == mask { // WHOIS nick nick + params = []string{upstreamNick, upstreamNick} + } else { + params = []string{target, upstreamNick} + } + } else { + params = []string{upstreamNick} + } + + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + Command: "WHOIS", + Params: params, + }) + case "PRIVMSG", "NOTICE": + var targetsStr, text string + if err := parseMessageParams(msg, &targetsStr, &text); err != nil { + return err + } + tags := copyClientTags(msg.Tags) + + for _, name := range strings.Split(targetsStr, ",") { + if name == "$"+dc.srv.Config().Hostname || (name == "$*" && dc.network == nil) { + // "$" means a server mask follows. If it's the bouncer's + // hostname, broadcast the message to all bouncer users. + if !dc.user.Admin { + return ircError{&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_BADMASK, + Params: []string{dc.nick, name, "Permission denied to broadcast message to all bouncer users"}, + }} + } + + dc.logger.Printf("broadcasting bouncer-wide %v: %v", msg.Command, text) + + broadcastTags := tags.Copy() + broadcastTags["time"] = irc.TagValue(formatServerTime(time.Now())) + broadcastMsg := &irc.Message{ + Tags: broadcastTags, + Prefix: servicePrefix, + Command: msg.Command, + Params: []string{name, text}, + } + dc.srv.forEachUser(func(u *user) { + u.events <- eventBroadcast{broadcastMsg} + }) + continue + } + + if dc.network == nil && casemapASCII(name) == dc.nickCM { + dc.SendMessage(&irc.Message{ + Tags: msg.Tags.Copy(), + Prefix: dc.prefix(), + Command: msg.Command, + Params: []string{name, text}, + }) + continue + } + + if msg.Command == "PRIVMSG" && casemapASCII(name) == serviceNickCM { + if dc.caps["echo-message"] { + echoTags := tags.Copy() + echoTags["time"] = irc.TagValue(formatServerTime(time.Now())) + dc.SendMessage(&irc.Message{ + Tags: echoTags, + Prefix: dc.prefix(), + Command: msg.Command, + Params: []string{name, text}, + }) + } + handleServicePRIVMSG(ctx, dc, text) + continue + } + + uc, upstreamName, err := dc.unmarshalEntity(name) + if err != nil { + return err + } + + if msg.Command == "PRIVMSG" && uc.network.casemap(upstreamName) == "nickserv" { + dc.handleNickServPRIVMSG(ctx, uc, text) + } + + unmarshaledText := text + if uc.isChannel(upstreamName) { + unmarshaledText = dc.unmarshalText(uc, text) + } + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + Tags: tags, + Command: msg.Command, + Params: []string{upstreamName, unmarshaledText}, + }) + + echoTags := tags.Copy() + echoTags["time"] = irc.TagValue(formatServerTime(time.Now())) + if uc.account != "" { + echoTags["account"] = irc.TagValue(uc.account) + } + echoMsg := &irc.Message{ + Tags: echoTags, + Prefix: &irc.Prefix{Name: uc.nick}, + Command: msg.Command, + Params: []string{upstreamName, text}, + } + uc.produce(upstreamName, echoMsg, dc) + + uc.updateChannelAutoDetach(upstreamName) + } + case "TAGMSG": + var targetsStr string + if err := parseMessageParams(msg, &targetsStr); err != nil { + return err + } + tags := copyClientTags(msg.Tags) + + for _, name := range strings.Split(targetsStr, ",") { + if dc.network == nil && casemapASCII(name) == dc.nickCM { + dc.SendMessage(&irc.Message{ + Tags: msg.Tags.Copy(), + Prefix: dc.prefix(), + Command: "TAGMSG", + Params: []string{name}, + }) + continue + } + + if casemapASCII(name) == serviceNickCM { + continue + } + + uc, upstreamName, err := dc.unmarshalEntity(name) + if err != nil { + return err + } + if _, ok := uc.caps["message-tags"]; !ok { + continue + } + + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + Tags: tags, + Command: "TAGMSG", + Params: []string{upstreamName}, + }) + + echoTags := tags.Copy() + echoTags["time"] = irc.TagValue(formatServerTime(time.Now())) + if uc.account != "" { + echoTags["account"] = irc.TagValue(uc.account) + } + echoMsg := &irc.Message{ + Tags: echoTags, + Prefix: &irc.Prefix{Name: uc.nick}, + Command: "TAGMSG", + Params: []string{upstreamName}, + } + uc.produce(upstreamName, echoMsg, dc) + + uc.updateChannelAutoDetach(upstreamName) + } + case "INVITE": + var user, channel string + if err := parseMessageParams(msg, &user, &channel); err != nil { + return err + } + + ucChannel, upstreamChannel, err := dc.unmarshalEntity(channel) + if err != nil { + return err + } + + ucUser, upstreamUser, err := dc.unmarshalEntity(user) + if err != nil { + return err + } + + if ucChannel != ucUser { + return ircError{&irc.Message{ + Command: irc.ERR_USERNOTINCHANNEL, + Params: []string{dc.nick, user, channel, "They are on another network"}, + }} + } + uc := ucChannel + + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + Command: "INVITE", + Params: []string{upstreamUser, upstreamChannel}, + }) + case "AUTHENTICATE": + // Post-connection-registration AUTHENTICATE is unsupported in + // multi-upstream mode, or if the upstream doesn't support SASL + uc := dc.upstream() + if uc == nil || !uc.caps["sasl"] { + return ircError{&irc.Message{ + Command: irc.ERR_SASLFAIL, + Params: []string{dc.nick, "Upstream network authentication not supported"}, + }} + } + + credentials, err := dc.handleAuthenticateCommand(msg) + if err != nil { + return err + } + + if credentials != nil { + if uc.saslClient != nil { + dc.endSASL(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_SASLFAIL, + Params: []string{dc.nick, "Another authentication attempt is already in progress"}, + }) + return nil + } + + uc.logger.Printf("starting post-registration SASL PLAIN authentication with username %q", credentials.plainUsername) + uc.saslClient = sasl.NewPlainClient("", credentials.plainUsername, credentials.plainPassword) + uc.enqueueCommand(dc, &irc.Message{ + Command: "AUTHENTICATE", + Params: []string{"PLAIN"}, + }) + } + case "REGISTER", "VERIFY": + // Check number of params here, since we'll use that to save the + // credentials on command success + if (msg.Command == "REGISTER" && len(msg.Params) < 3) || (msg.Command == "VERIFY" && len(msg.Params) < 2) { + return newNeedMoreParamsError(msg.Command) + } + + uc := dc.upstream() + if uc == nil || !uc.caps["draft/account-registration"] { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{msg.Command, "TEMPORARILY_UNAVAILABLE", "*", "Upstream network account registration not supported"}, + }} + } + + uc.logger.Printf("starting %v with account name %v", msg.Command, msg.Params[0]) + uc.enqueueCommand(dc, msg) + case "MONITOR": + // MONITOR is unsupported in multi-upstream mode + uc := dc.upstream() + if uc == nil { + return newUnknownCommandError(msg.Command) + } + if _, ok := uc.isupport["MONITOR"]; !ok { + return newUnknownCommandError(msg.Command) + } + + var subcommand string + if err := parseMessageParams(msg, &subcommand); err != nil { + return err + } + + switch strings.ToUpper(subcommand) { + case "+", "-": + var targets string + if err := parseMessageParams(msg, nil, &targets); err != nil { + return err + } + for _, target := range strings.Split(targets, ",") { + if subcommand == "+" { + // Hard limit, just to avoid having downstreams fill our map + if len(dc.monitored.innerMap) >= 1000 { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_MONLISTFULL, + Params: []string{dc.nick, "1000", target, "Bouncer monitor list is full"}, + }) + continue + } + + dc.monitored.SetValue(target, nil) + + if uc.monitored.Has(target) { + cmd := irc.RPL_MONOFFLINE + if online := uc.monitored.Value(target); online { + cmd = irc.RPL_MONONLINE + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: cmd, + Params: []string{dc.nick, target}, + }) + } + } else { + dc.monitored.Delete(target) + } + } + uc.updateMonitor() + case "C": // clear + dc.monitored = newCasemapMap(0) + uc.updateMonitor() + case "L": // list + // TODO: be less lazy and pack the list + for _, entry := range dc.monitored.innerMap { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_MONLIST, + Params: []string{dc.nick, entry.originalKey}, + }) + } + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_ENDOFMONLIST, + Params: []string{dc.nick, "End of MONITOR list"}, + }) + case "S": // status + // TODO: be less lazy and pack the lists + for _, entry := range dc.monitored.innerMap { + target := entry.originalKey + + cmd := irc.RPL_MONOFFLINE + if online := uc.monitored.Value(target); online { + cmd = irc.RPL_MONONLINE + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: cmd, + Params: []string{dc.nick, target}, + }) + } + } + case "CHATHISTORY": + var subcommand string + if err := parseMessageParams(msg, &subcommand); err != nil { + return err + } + var target, limitStr string + var boundsStr [2]string + switch subcommand { + case "AFTER", "BEFORE", "LATEST": + if err := parseMessageParams(msg, nil, &target, &boundsStr[0], &limitStr); err != nil { + return err + } + case "BETWEEN": + if err := parseMessageParams(msg, nil, &target, &boundsStr[0], &boundsStr[1], &limitStr); err != nil { + return err + } + case "TARGETS": + if dc.network == nil { + // Either an unbound bouncer network, in which case we should return no targets, + // or a multi-upstream downstream, but we don't support CHATHISTORY TARGETS for those yet. + dc.SendBatch("draft/chathistory-targets", nil, nil, func(batchRef irc.TagValue) {}) + return nil + } + if err := parseMessageParams(msg, nil, &boundsStr[0], &boundsStr[1], &limitStr); err != nil { + return err + } + default: + // TODO: support AROUND + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"CHATHISTORY", "INVALID_PARAMS", subcommand, "Unknown command"}, + }} + } + + // We don't save history for our service + if casemapASCII(target) == serviceNickCM { + dc.SendBatch("chathistory", []string{target}, nil, func(batchRef irc.TagValue) {}) + return nil + } + + store, ok := dc.user.msgStore.(chatHistoryMessageStore) + if !ok { + return ircError{&irc.Message{ + Command: irc.ERR_UNKNOWNCOMMAND, + Params: []string{dc.nick, "CHATHISTORY", "Unknown command"}, + }} + } + + network, entity, err := dc.unmarshalEntityNetwork(target) + if err != nil { + return err + } + entity = network.casemap(entity) + + // TODO: support msgid criteria + var bounds [2]time.Time + bounds[0] = parseChatHistoryBound(boundsStr[0]) + if subcommand == "LATEST" && boundsStr[0] == "*" { + bounds[0] = time.Now() + } else if bounds[0].IsZero() { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"CHATHISTORY", "INVALID_PARAMS", subcommand, boundsStr[0], "Invalid first bound"}, + }} + } + + if boundsStr[1] != "" { + bounds[1] = parseChatHistoryBound(boundsStr[1]) + if bounds[1].IsZero() { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"CHATHISTORY", "INVALID_PARAMS", subcommand, boundsStr[1], "Invalid second bound"}, + }} + } + } + + limit, err := strconv.Atoi(limitStr) + if err != nil || limit < 0 || limit > chatHistoryLimit { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"CHATHISTORY", "INVALID_PARAMS", subcommand, limitStr, "Invalid limit"}, + }} + } + + eventPlayback := dc.caps["draft/event-playback"] + + var history []*irc.Message + switch subcommand { + case "BEFORE", "LATEST": + history, err = store.LoadBeforeTime(ctx, &network.Network, entity, bounds[0], time.Time{}, limit, eventPlayback) + case "AFTER": + history, err = store.LoadAfterTime(ctx, &network.Network, entity, bounds[0], time.Now(), limit, eventPlayback) + case "BETWEEN": + if bounds[0].Before(bounds[1]) { + history, err = store.LoadAfterTime(ctx, &network.Network, entity, bounds[0], bounds[1], limit, eventPlayback) + } else { + history, err = store.LoadBeforeTime(ctx, &network.Network, entity, bounds[0], bounds[1], limit, eventPlayback) + } + case "TARGETS": + // TODO: support TARGETS in multi-upstream mode + targets, err := store.ListTargets(ctx, &network.Network, bounds[0], bounds[1], limit, eventPlayback) + if err != nil { + dc.logger.Printf("failed fetching targets for chathistory: %v", err) + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"CHATHISTORY", "MESSAGE_ERROR", subcommand, "Failed to retrieve targets"}, + }} + } + + dc.SendBatch("draft/chathistory-targets", nil, nil, func(batchRef irc.TagValue) { + for _, target := range targets { + if ch := network.channels.Value(target.Name); ch != nil && ch.Detached { + continue + } + + dc.SendMessage(&irc.Message{ + Tags: irc.Tags{"batch": batchRef}, + Prefix: dc.srv.prefix(), + Command: "CHATHISTORY", + Params: []string{"TARGETS", target.Name, formatServerTime(target.LatestMessage)}, + }) + } + }) + + return nil + } + if err != nil { + dc.logger.Printf("failed fetching %q messages for chathistory: %v", target, err) + return newChatHistoryError(subcommand, target) + } + + dc.SendBatch("chathistory", []string{target}, nil, func(batchRef irc.TagValue) { + for _, msg := range history { + msg.Tags["batch"] = batchRef + dc.SendMessage(dc.marshalMessage(msg, network)) + } + }) + case "READ": + var target, criteria string + if err := parseMessageParams(msg, &target); err != nil { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"READ", "NEED_MORE_PARAMS", "Missing parameters"}, + }} + } + if len(msg.Params) > 1 { + criteria = msg.Params[1] + } + + // We don't save read receipts for our service + if casemapASCII(target) == serviceNickCM { + dc.SendMessage(&irc.Message{ + Prefix: dc.prefix(), + Command: "READ", + Params: []string{target, "*"}, + }) + return nil + } + + uc, entity, err := dc.unmarshalEntity(target) + if err != nil { + return err + } + entityCM := uc.network.casemap(entity) + + r, err := dc.srv.db.GetReadReceipt(ctx, uc.network.ID, entityCM) + if err != nil { + dc.logger.Printf("failed to get the read receipt for %q: %v", entity, err) + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"READ", "INTERNAL_ERROR", target, "Internal error"}, + }} + } else if r == nil { + r = &ReadReceipt{ + Target: entityCM, + } + } + + broadcast := false + if len(criteria) > 0 { + // TODO: support msgid criteria + criteriaParts := strings.SplitN(criteria, "=", 2) + if len(criteriaParts) != 2 || criteriaParts[0] != "timestamp" { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"READ", "INVALID_PARAMS", criteria, "Unknown criteria"}, + }} + } + + timestamp, err := time.Parse(serverTimeLayout, criteriaParts[1]) + if err != nil { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"READ", "INVALID_PARAMS", criteria, "Invalid criteria"}, + }} + } + now := time.Now() + if timestamp.After(now) { + timestamp = now + } + if r.Timestamp.Before(timestamp) { + r.Timestamp = timestamp + if err := dc.srv.db.StoreReadReceipt(ctx, uc.network.ID, r); err != nil { + dc.logger.Printf("failed to store receipt for %q: %v", entity, err) + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"READ", "INTERNAL_ERROR", target, "Internal error"}, + }} + } + broadcast = true + } + } + + timestampStr := "*" + if !r.Timestamp.IsZero() { + timestampStr = fmt.Sprintf("timestamp=%s", formatServerTime(r.Timestamp)) + } + uc.forEachDownstream(func(d *downstreamConn) { + if broadcast || dc.id == d.id { + d.SendMessage(&irc.Message{ + Prefix: d.prefix(), + Command: "READ", + Params: []string{d.marshalEntity(uc.network, entity), timestampStr}, + }) + } + }) + case "BOUNCER": + var subcommand string + if err := parseMessageParams(msg, &subcommand); err != nil { + return err + } + + switch strings.ToUpper(subcommand) { + case "BIND": + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"BOUNCER", "REGISTRATION_IS_COMPLETED", "BIND", "Cannot bind to a network after registration"}, + }} + case "LISTNETWORKS": + dc.SendBatch("soju.im/bouncer-networks", nil, nil, func(batchRef irc.TagValue) { + for _, network := range dc.user.networks { + idStr := fmt.Sprintf("%v", network.ID) + attrs := getNetworkAttrs(network) + dc.SendMessage(&irc.Message{ + Tags: irc.Tags{"batch": batchRef}, + Prefix: dc.srv.prefix(), + Command: "BOUNCER", + Params: []string{"NETWORK", idStr, attrs.String()}, + }) + } + }) + case "ADDNETWORK": + var attrsStr string + if err := parseMessageParams(msg, nil, &attrsStr); err != nil { + return err + } + attrs := irc.ParseTags(attrsStr) + + record := &Network{Nick: dc.nick, Enabled: true} + if err := updateNetworkAttrs(record, attrs, subcommand); err != nil { + return err + } + + if record.Nick == dc.user.Username { + record.Nick = "" + } + if record.Realname == dc.user.Realname { + record.Realname = "" + } + + network, err := dc.user.createNetwork(ctx, record) + if err != nil { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"BOUNCER", "UNKNOWN_ERROR", subcommand, fmt.Sprintf("Failed to create network: %v", err)}, + }} + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "BOUNCER", + Params: []string{"ADDNETWORK", fmt.Sprintf("%v", network.ID)}, + }) + case "CHANGENETWORK": + var idStr, attrsStr string + if err := parseMessageParams(msg, nil, &idStr, &attrsStr); err != nil { + return err + } + id, err := parseBouncerNetID(subcommand, idStr) + if err != nil { + return err + } + attrs := irc.ParseTags(attrsStr) + + net := dc.user.getNetworkByID(id) + if net == nil { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"BOUNCER", "INVALID_NETID", subcommand, idStr, "Invalid network ID"}, + }} + } + + record := net.Network // copy network record because we'll mutate it + if err := updateNetworkAttrs(&record, attrs, subcommand); err != nil { + return err + } + + if record.Nick == dc.user.Username { + record.Nick = "" + } + if record.Realname == dc.user.Realname { + record.Realname = "" + } + + _, err = dc.user.updateNetwork(ctx, &record) + if err != nil { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"BOUNCER", "UNKNOWN_ERROR", subcommand, fmt.Sprintf("Failed to update network: %v", err)}, + }} + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "BOUNCER", + Params: []string{"CHANGENETWORK", idStr}, + }) + case "DELNETWORK": + var idStr string + if err := parseMessageParams(msg, nil, &idStr); err != nil { + return err + } + id, err := parseBouncerNetID(subcommand, idStr) + if err != nil { + return err + } + + net := dc.user.getNetworkByID(id) + if net == nil { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"BOUNCER", "INVALID_NETID", subcommand, idStr, "Invalid network ID"}, + }} + } + + if err := dc.user.deleteNetwork(ctx, net.ID); err != nil { + return err + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "BOUNCER", + Params: []string{"DELNETWORK", idStr}, + }) + default: + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"BOUNCER", "UNKNOWN_COMMAND", subcommand, "Unknown subcommand"}, + }} + } + default: + dc.logger.Printf("unhandled message: %v", msg) + + // Only forward unknown commands in single-upstream mode + uc := dc.upstream() + if uc == nil { + return newUnknownCommandError(msg.Command) + } + + uc.SendMessageLabeled(ctx, dc.id, msg) + } + return nil +} + +func (dc *downstreamConn) handleNickServPRIVMSG(ctx context.Context, uc *upstreamConn, text string) { + username, password, ok := parseNickServCredentials(text, uc.nick) + if ok { + uc.network.autoSaveSASLPlain(ctx, username, password) + } +} + +func parseNickServCredentials(text, nick string) (username, password string, ok bool) { + fields := strings.Fields(text) + if len(fields) < 2 { + return "", "", false + } + cmd := strings.ToUpper(fields[0]) + params := fields[1:] + switch cmd { + case "REGISTER": + username = nick + password = params[0] + case "IDENTIFY": + if len(params) == 1 { + username = nick + password = params[0] + } else { + username = params[0] + password = params[1] + } + case "SET": + if len(params) == 2 && strings.EqualFold(params[0], "PASSWORD") { + username = nick + password = params[1] + } + default: + return "", "", false + } + return username, password, true +} diff --git a/branches/origin/go.mod b/branches/origin/go.mod new file mode 100644 index 0000000..cc08167 --- /dev/null +++ b/branches/origin/go.mod @@ -0,0 +1,39 @@ +module marisa.chaotic.ninja/suika + +go 1.20 + +require ( + git.sr.ht/~emersion/go-scfg v0.0.0-20211215104734-c2c7a15d6c99 + git.sr.ht/~sircmpwn/go-bare v0.0.0-20210406120253-ab86bc2846d9 + github.com/emersion/go-sasl v0.0.0-20220912192320-0145f2c60ead + github.com/lib/pq v1.10.7 + golang.org/x/crypto v0.7.0 + golang.org/x/term v0.6.0 + golang.org/x/time v0.3.0 + gopkg.in/irc.v3 v3.1.4 + modernc.org/sqlite v1.21.0 +) + +require ( + github.com/dustin/go-humanize v1.0.0 // indirect + github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect + github.com/google/uuid v1.3.0 // indirect + github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect + github.com/mattn/go-isatty v0.0.16 // indirect + github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + github.com/stretchr/testify v1.8.0 // indirect + golang.org/x/mod v0.3.0 // indirect + golang.org/x/sys v0.6.0 // indirect + golang.org/x/tools v0.0.0-20201124115921-2c860bdd6e78 // indirect + golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect + lukechampine.com/uint128 v1.2.0 // indirect + modernc.org/cc/v3 v3.40.0 // indirect + modernc.org/ccgo/v3 v3.16.13 // indirect + modernc.org/libc v1.22.3 // indirect + modernc.org/mathutil v1.5.0 // indirect + modernc.org/memory v1.5.0 // indirect + modernc.org/opt v0.1.3 // indirect + modernc.org/strutil v1.1.3 // indirect + modernc.org/token v1.0.1 // indirect +) diff --git a/branches/origin/go.sum b/branches/origin/go.sum new file mode 100644 index 0000000..090246d --- /dev/null +++ b/branches/origin/go.sum @@ -0,0 +1,105 @@ +git.sr.ht/~emersion/go-scfg v0.0.0-20211215104734-c2c7a15d6c99 h1:1s8n5uisqkR+BzPgaum6xxIjKmzGrTykJdh+Y3f5Xao= +git.sr.ht/~emersion/go-scfg v0.0.0-20211215104734-c2c7a15d6c99/go.mod h1:t+Ww6SR24yYnXzEWiNlOY0AFo5E9B73X++10lrSpp4U= +git.sr.ht/~sircmpwn/getopt v0.0.0-20191230200459-23622cc906b3/go.mod h1:wMEGFFFNuPos7vHmWXfszqImLppbc0wEhh6JBfJIUgw= +git.sr.ht/~sircmpwn/go-bare v0.0.0-20210406120253-ab86bc2846d9 h1:Ahny8Ud1LjVMMAlt8utUFKhhxJtwBAualvsbc/Sk7cE= +git.sr.ht/~sircmpwn/go-bare v0.0.0-20210406120253-ab86bc2846d9/go.mod h1:BVJwbDfVjCjoFiKrhkei6NdGcZYpkDkdyCdg1ukytRA= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo= +github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= +github.com/emersion/go-sasl v0.0.0-20220912192320-0145f2c60ead h1:fI1Jck0vUrXT8bnphprS1EoVRe2Q5CKCX8iDlpqjQ/Y= +github.com/emersion/go-sasl v0.0.0-20220912192320-0145f2c60ead/go.mod h1:iL2twTeMvZnrg54ZoPDNfJaJaqy0xIQFuBdrLsmspwQ= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ= +github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4= +github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs= +github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= +github.com/lib/pq v1.10.7 h1:p7ZhMD+KsSRozJr34udlUrhboJwWAgCg34+/ZZNvZZw= +github.com/lib/pq v1.10.7/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/mattn/go-isatty v0.0.16 h1:bq3VjFmv/sOjHtdEhmkEV4x1AJtvUvOJ2PFAZ5+peKQ= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A= +golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= +golang.org/x/mod v0.3.0 h1:RM4zey1++hCTbCVQfnWeKs9/IEsaBLA8vTkd0WVtmH4= +golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.6.0 h1:clScbb1cHjoCkyRbWwBEUZ5H/tIFu5TAXIqaZD0Gcjw= +golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= +golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20201124115921-2c860bdd6e78 h1:M8tBwCtWD/cZV9DZpFYRUgaymAYAr+aIUTWzDaM3uPs= +golang.org/x/tools v0.0.0-20201124115921-2c860bdd6e78/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/irc.v3 v3.1.4 h1:DYGMRFbtseXEh+NadmMUFzMraqyuUj4I3iWYFEzDZPc= +gopkg.in/irc.v3 v3.1.4/go.mod h1:shO2gz8+PVeS+4E6GAny88Z0YVVQSxQghdrMVGQsR9s= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +lukechampine.com/uint128 v1.2.0 h1:mBi/5l91vocEN8otkC5bDLhi2KdCticRiwbdB0O+rjI= +lukechampine.com/uint128 v1.2.0/go.mod h1:c4eWIwlEGaxC/+H1VguhU4PHXNWDCDMUlWdIWl2j1gk= +modernc.org/cc/v3 v3.40.0 h1:P3g79IUS/93SYhtoeaHW+kRCIrYaxJ27MFPv+7kaTOw= +modernc.org/cc/v3 v3.40.0/go.mod h1:/bTg4dnWkSXowUO6ssQKnOV0yMVxDYNIsIrzqTFDGH0= +modernc.org/ccgo/v3 v3.16.13 h1:Mkgdzl46i5F/CNR/Kj80Ri59hC8TKAhZrYSaqvkwzUw= +modernc.org/ccgo/v3 v3.16.13/go.mod h1:2Quk+5YgpImhPjv2Qsob1DnZ/4som1lJTodubIcoUkY= +modernc.org/ccorpus v1.11.6 h1:J16RXiiqiCgua6+ZvQot4yUuUy8zxgqbqEEUuGPlISk= +modernc.org/httpfs v1.0.6 h1:AAgIpFZRXuYnkjftxTAZwMIiwEqAfk8aVB2/oA6nAeM= +modernc.org/libc v1.22.3 h1:D/g6O5ftAfavceqlLOFwaZuA5KYafKwmr30A6iSqoyY= +modernc.org/libc v1.22.3/go.mod h1:MQrloYP209xa2zHome2a8HLiLm6k0UT8CoHpV74tOFw= +modernc.org/mathutil v1.5.0 h1:rV0Ko/6SfM+8G+yKiyI830l3Wuz1zRutdslNoQ0kfiQ= +modernc.org/mathutil v1.5.0/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E= +modernc.org/memory v1.5.0 h1:N+/8c5rE6EqugZwHii4IFsaJ7MUhoWX07J5tC/iI5Ds= +modernc.org/memory v1.5.0/go.mod h1:PkUhL0Mugw21sHPeskwZW4D6VscE/GQJOnIpCnW6pSU= +modernc.org/opt v0.1.3 h1:3XOZf2yznlhC+ibLltsDGzABUGVx8J6pnFMS3E4dcq4= +modernc.org/opt v0.1.3/go.mod h1:WdSiB5evDcignE70guQKxYUl14mgWtbClRi5wmkkTX0= +modernc.org/sqlite v1.21.0 h1:4aP4MdUf15i3R3M2mx6Q90WHKz3nZLoz96zlB6tNdow= +modernc.org/sqlite v1.21.0/go.mod h1:XwQ0wZPIh1iKb5mkvCJ3szzbhk+tykC8ZWqTRTgYRwI= +modernc.org/strutil v1.1.3 h1:fNMm+oJklMGYfU9Ylcywl0CO5O6nTfaowNsh2wpPjzY= +modernc.org/strutil v1.1.3/go.mod h1:MEHNA7PdEnEwLvspRMtWTNnp2nnyvMfkimT1NKNAGbw= +modernc.org/tcl v1.15.1 h1:mOQwiEK4p7HruMZcwKTZPw/aqtGM4aY00uzWhlKKYws= +modernc.org/token v1.0.1 h1:A3qvTqOwexpfZZeyI0FeGPDlSWX5pjZu9hF4lU+EKWg= +modernc.org/token v1.0.1/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= +modernc.org/z v1.7.0 h1:xkDw/KepgEjeizO2sNco+hqYkU12taxQFqPEmgm1GWE= diff --git a/branches/origin/irc.go b/branches/origin/irc.go new file mode 100644 index 0000000..1799e32 --- /dev/null +++ b/branches/origin/irc.go @@ -0,0 +1,811 @@ +package suika + +import ( + "fmt" + "sort" + "strings" + "time" + "unicode" + "unicode/utf8" + + "gopkg.in/irc.v3" +) + +const ( + rpl_statsping = "246" + rpl_localusers = "265" + rpl_globalusers = "266" + rpl_creationtime = "329" + rpl_topicwhotime = "333" + rpl_whospcrpl = "354" + rpl_whoisaccount = "330" + err_invalidcapcmd = "410" +) + +const ( + maxMessageLength = 512 + maxMessageParams = 15 + maxSASLLength = 400 +) + +// The server-time layout, as defined in the IRCv3 spec. +const serverTimeLayout = "2006-01-02T15:04:05.000Z" + +func formatServerTime(t time.Time) string { + return t.UTC().Format(serverTimeLayout) +} + +type userModes string + +func (ms userModes) Has(c byte) bool { + return strings.IndexByte(string(ms), c) >= 0 +} + +func (ms *userModes) Add(c byte) { + if !ms.Has(c) { + *ms += userModes(c) + } +} + +func (ms *userModes) Del(c byte) { + i := strings.IndexByte(string(*ms), c) + if i >= 0 { + *ms = (*ms)[:i] + (*ms)[i+1:] + } +} + +func (ms *userModes) Apply(s string) error { + var plusMinus byte + for i := 0; i < len(s); i++ { + switch c := s[i]; c { + case '+', '-': + plusMinus = c + default: + switch plusMinus { + case '+': + ms.Add(c) + case '-': + ms.Del(c) + default: + return fmt.Errorf("malformed modestring %q: missing plus/minus", s) + } + } + } + return nil +} + +type channelModeType byte + +// standard channel mode types, as explained in https://modern.ircdocs.horse/#mode-message +const ( + // modes that add or remove an address to or from a list + modeTypeA channelModeType = iota + // modes that change a setting on a channel, and must always have a parameter + modeTypeB + // modes that change a setting on a channel, and must have a parameter when being set, and no parameter when being unset + modeTypeC + // modes that change a setting on a channel, and must not have a parameter + modeTypeD +) + +var stdChannelModes = map[byte]channelModeType{ + 'b': modeTypeA, // ban list + 'e': modeTypeA, // ban exception list + 'I': modeTypeA, // invite exception list + 'k': modeTypeB, // channel key + 'l': modeTypeC, // channel user limit + 'i': modeTypeD, // channel is invite-only + 'm': modeTypeD, // channel is moderated + 'n': modeTypeD, // channel has no external messages + 's': modeTypeD, // channel is secret + 't': modeTypeD, // channel has protected topic +} + +type channelModes map[byte]string + +// applyChannelModes parses a mode string and mode arguments from a MODE message, +// and applies the corresponding channel mode and user membership changes on that channel. +// +// If ch.modes is nil, channel modes are not updated. +// +// needMarshaling is a list of indexes of mode arguments that represent entities +// that must be marshaled when sent downstream. +func applyChannelModes(ch *upstreamChannel, modeStr string, arguments []string) (needMarshaling map[int]struct{}, err error) { + needMarshaling = make(map[int]struct{}, len(arguments)) + nextArgument := 0 + var plusMinus byte +outer: + for i := 0; i < len(modeStr); i++ { + mode := modeStr[i] + if mode == '+' || mode == '-' { + plusMinus = mode + continue + } + if plusMinus != '+' && plusMinus != '-' { + return nil, fmt.Errorf("malformed modestring %q: missing plus/minus", modeStr) + } + + for _, membership := range ch.conn.availableMemberships { + if membership.Mode == mode { + if nextArgument >= len(arguments) { + return nil, fmt.Errorf("malformed modestring %q: missing mode argument for %c%c", modeStr, plusMinus, mode) + } + member := arguments[nextArgument] + m := ch.Members.Value(member) + if m != nil { + if plusMinus == '+' { + m.Add(ch.conn.availableMemberships, membership) + } else { + // TODO: for upstreams without multi-prefix, query the user modes again + m.Remove(membership) + } + } + needMarshaling[nextArgument] = struct{}{} + nextArgument++ + continue outer + } + } + + mt, ok := ch.conn.availableChannelModes[mode] + if !ok { + continue + } + if mt == modeTypeA { + nextArgument++ + } else if mt == modeTypeB || (mt == modeTypeC && plusMinus == '+') { + if plusMinus == '+' { + var argument string + // some sentitive arguments (such as channel keys) can be omitted for privacy + // (this will only happen for RPL_CHANNELMODEIS, never for MODE messages) + if nextArgument < len(arguments) { + argument = arguments[nextArgument] + } + if ch.modes != nil { + ch.modes[mode] = argument + } + } else { + delete(ch.modes, mode) + } + nextArgument++ + } else if mt == modeTypeC || mt == modeTypeD { + if plusMinus == '+' { + if ch.modes != nil { + ch.modes[mode] = "" + } + } else { + delete(ch.modes, mode) + } + } + } + return needMarshaling, nil +} + +func (cm channelModes) Format() (modeString string, parameters []string) { + var modesWithValues strings.Builder + var modesWithoutValues strings.Builder + parameters = make([]string, 0, 16) + for mode, value := range cm { + if value != "" { + modesWithValues.WriteString(string(mode)) + parameters = append(parameters, value) + } else { + modesWithoutValues.WriteString(string(mode)) + } + } + modeString = "+" + modesWithValues.String() + modesWithoutValues.String() + return +} + +const stdChannelTypes = "#&+!" + +type channelStatus byte + +const ( + channelPublic channelStatus = '=' + channelSecret channelStatus = '@' + channelPrivate channelStatus = '*' +) + +func parseChannelStatus(s string) (channelStatus, error) { + if len(s) > 1 { + return 0, fmt.Errorf("invalid channel status %q: more than one character", s) + } + switch cs := channelStatus(s[0]); cs { + case channelPublic, channelSecret, channelPrivate: + return cs, nil + default: + return 0, fmt.Errorf("invalid channel status %q: unknown status", s) + } +} + +type membership struct { + Mode byte + Prefix byte +} + +var stdMemberships = []membership{ + {'q', '~'}, // founder + {'a', '&'}, // protected + {'o', '@'}, // operator + {'h', '%'}, // halfop + {'v', '+'}, // voice +} + +// memberships always sorted by descending membership rank +type memberships []membership + +func (m *memberships) Add(availableMemberships []membership, newMembership membership) { + l := *m + i := 0 + for _, availableMembership := range availableMemberships { + if i >= len(l) { + break + } + if l[i] == availableMembership { + if availableMembership == newMembership { + // we already have this membership + return + } + i++ + continue + } + if availableMembership == newMembership { + break + } + } + // insert newMembership at i + l = append(l, membership{}) + copy(l[i+1:], l[i:]) + l[i] = newMembership + *m = l +} + +func (m *memberships) Remove(oldMembership membership) { + l := *m + for i, currentMembership := range l { + if currentMembership == oldMembership { + *m = append(l[:i], l[i+1:]...) + return + } + } +} + +func (m memberships) Format(dc *downstreamConn) string { + if !dc.caps["multi-prefix"] { + if len(m) == 0 { + return "" + } + return string(m[0].Prefix) + } + prefixes := make([]byte, len(m)) + for i, membership := range m { + prefixes[i] = membership.Prefix + } + return string(prefixes) +} + +func parseMessageParams(msg *irc.Message, out ...*string) error { + if len(msg.Params) < len(out) { + return newNeedMoreParamsError(msg.Command) + } + for i := range out { + if out[i] != nil { + *out[i] = msg.Params[i] + } + } + return nil +} + +func copyClientTags(tags irc.Tags) irc.Tags { + t := make(irc.Tags, len(tags)) + for k, v := range tags { + if strings.HasPrefix(k, "+") { + t[k] = v + } + } + return t +} + +type batch struct { + Type string + Params []string + Outer *batch // if not-nil, this batch is nested in Outer + Label string +} + +func join(channels, keys []string) []*irc.Message { + // Put channels with a key first + js := joinSorter{channels, keys} + sort.Sort(&js) + + // Two spaces because there are three words (JOIN, channels and keys) + maxLength := maxMessageLength - (len("JOIN") + 2) + + var msgs []*irc.Message + var channelsBuf, keysBuf strings.Builder + for i, channel := range channels { + key := keys[i] + + n := channelsBuf.Len() + keysBuf.Len() + 1 + len(channel) + if key != "" { + n += 1 + len(key) + } + + if channelsBuf.Len() > 0 && n > maxLength { + // No room for the new channel in this message + params := []string{channelsBuf.String()} + if keysBuf.Len() > 0 { + params = append(params, keysBuf.String()) + } + msgs = append(msgs, &irc.Message{Command: "JOIN", Params: params}) + channelsBuf.Reset() + keysBuf.Reset() + } + + if channelsBuf.Len() > 0 { + channelsBuf.WriteByte(',') + } + channelsBuf.WriteString(channel) + if key != "" { + if keysBuf.Len() > 0 { + keysBuf.WriteByte(',') + } + keysBuf.WriteString(key) + } + } + if channelsBuf.Len() > 0 { + params := []string{channelsBuf.String()} + if keysBuf.Len() > 0 { + params = append(params, keysBuf.String()) + } + msgs = append(msgs, &irc.Message{Command: "JOIN", Params: params}) + } + + return msgs +} + +func generateIsupport(prefix *irc.Prefix, nick string, tokens []string) []*irc.Message { + maxTokens := maxMessageParams - 2 // 2 reserved params: nick + text + + var msgs []*irc.Message + for len(tokens) > 0 { + var msgTokens []string + if len(tokens) > maxTokens { + msgTokens = tokens[:maxTokens] + tokens = tokens[maxTokens:] + } else { + msgTokens = tokens + tokens = nil + } + + msgs = append(msgs, &irc.Message{ + Prefix: prefix, + Command: irc.RPL_ISUPPORT, + Params: append(append([]string{nick}, msgTokens...), "are supported"), + }) + } + + return msgs +} + +func generateMOTD(prefix *irc.Prefix, nick string, motd string) []*irc.Message { + var msgs []*irc.Message + msgs = append(msgs, &irc.Message{ + Prefix: prefix, + Command: irc.RPL_MOTDSTART, + Params: []string{nick, fmt.Sprintf("- Message of the Day -")}, + }) + + for _, l := range strings.Split(motd, "\n") { + msgs = append(msgs, &irc.Message{ + Prefix: prefix, + Command: irc.RPL_MOTD, + Params: []string{nick, l}, + }) + } + + msgs = append(msgs, &irc.Message{ + Prefix: prefix, + Command: irc.RPL_ENDOFMOTD, + Params: []string{nick, "End of /MOTD command."}, + }) + + return msgs +} + +func generateMonitor(subcmd string, targets []string) []*irc.Message { + maxLength := maxMessageLength - len("MONITOR "+subcmd+" ") + + var msgs []*irc.Message + var buf []string + n := 0 + for _, target := range targets { + if n+len(target)+1 > maxLength { + msgs = append(msgs, &irc.Message{ + Command: "MONITOR", + Params: []string{subcmd, strings.Join(buf, ",")}, + }) + buf = buf[:0] + n = 0 + } + + buf = append(buf, target) + n += len(target) + 1 + } + + if len(buf) > 0 { + msgs = append(msgs, &irc.Message{ + Command: "MONITOR", + Params: []string{subcmd, strings.Join(buf, ",")}, + }) + } + + return msgs +} + +type joinSorter struct { + channels []string + keys []string +} + +func (js *joinSorter) Len() int { + return len(js.channels) +} + +func (js *joinSorter) Less(i, j int) bool { + if (js.keys[i] != "") != (js.keys[j] != "") { + // Only one of the channels has a key + return js.keys[i] != "" + } + return js.channels[i] < js.channels[j] +} + +func (js *joinSorter) Swap(i, j int) { + js.channels[i], js.channels[j] = js.channels[j], js.channels[i] + js.keys[i], js.keys[j] = js.keys[j], js.keys[i] +} + +// parseCTCPMessage parses a CTCP message. CTCP is defined in +// https://tools.ietf.org/html/draft-oakley-irc-ctcp-02 +func parseCTCPMessage(msg *irc.Message) (cmd string, params string, ok bool) { + if (msg.Command != "PRIVMSG" && msg.Command != "NOTICE") || len(msg.Params) < 2 { + return "", "", false + } + text := msg.Params[1] + + if !strings.HasPrefix(text, "\x01") { + return "", "", false + } + text = strings.Trim(text, "\x01") + + words := strings.SplitN(text, " ", 2) + cmd = strings.ToUpper(words[0]) + if len(words) > 1 { + params = words[1] + } + + return cmd, params, true +} + +type casemapping func(string) string + +func casemapNone(name string) string { + return name +} + +// CasemapASCII of name is the canonical representation of name according to the +// ascii casemapping. +func casemapASCII(name string) string { + nameBytes := []byte(name) + for i, r := range nameBytes { + if 'A' <= r && r <= 'Z' { + nameBytes[i] = r + 'a' - 'A' + } + } + return string(nameBytes) +} + +// casemapRFC1459 of name is the canonical representation of name according to the +// rfc1459 casemapping. +func casemapRFC1459(name string) string { + nameBytes := []byte(name) + for i, r := range nameBytes { + if 'A' <= r && r <= 'Z' { + nameBytes[i] = r + 'a' - 'A' + } else if r == '{' { + nameBytes[i] = '[' + } else if r == '}' { + nameBytes[i] = ']' + } else if r == '\\' { + nameBytes[i] = '|' + } else if r == '~' { + nameBytes[i] = '^' + } + } + return string(nameBytes) +} + +// casemapRFC1459Strict of name is the canonical representation of name +// according to the rfc1459-strict casemapping. +func casemapRFC1459Strict(name string) string { + nameBytes := []byte(name) + for i, r := range nameBytes { + if 'A' <= r && r <= 'Z' { + nameBytes[i] = r + 'a' - 'A' + } else if r == '{' { + nameBytes[i] = '[' + } else if r == '}' { + nameBytes[i] = ']' + } else if r == '\\' { + nameBytes[i] = '|' + } + } + return string(nameBytes) +} + +func parseCasemappingToken(tokenValue string) (casemap casemapping, ok bool) { + switch tokenValue { + case "ascii": + casemap = casemapASCII + case "rfc1459": + casemap = casemapRFC1459 + case "rfc1459-strict": + casemap = casemapRFC1459Strict + default: + return nil, false + } + return casemap, true +} + +func partialCasemap(higher casemapping, name string) string { + nameFullyCM := []byte(higher(name)) + nameBytes := []byte(name) + for i, r := range nameBytes { + if !('A' <= r && r <= 'Z') && !('a' <= r && r <= 'z') { + nameBytes[i] = nameFullyCM[i] + } + } + return string(nameBytes) +} + +type casemapMap struct { + innerMap map[string]casemapEntry + casemap casemapping +} + +type casemapEntry struct { + originalKey string + value interface{} +} + +func newCasemapMap(size int) casemapMap { + return casemapMap{ + innerMap: make(map[string]casemapEntry, size), + casemap: casemapNone, + } +} + +func (cm *casemapMap) OriginalKey(name string) (key string, ok bool) { + entry, ok := cm.innerMap[cm.casemap(name)] + if !ok { + return "", false + } + return entry.originalKey, true +} + +func (cm *casemapMap) Has(name string) bool { + _, ok := cm.innerMap[cm.casemap(name)] + return ok +} + +func (cm *casemapMap) Len() int { + return len(cm.innerMap) +} + +func (cm *casemapMap) SetValue(name string, value interface{}) { + nameCM := cm.casemap(name) + entry, ok := cm.innerMap[nameCM] + if !ok { + cm.innerMap[nameCM] = casemapEntry{ + originalKey: name, + value: value, + } + return + } + entry.value = value + cm.innerMap[nameCM] = entry +} + +func (cm *casemapMap) Delete(name string) { + delete(cm.innerMap, cm.casemap(name)) +} + +func (cm *casemapMap) SetCasemapping(newCasemap casemapping) { + cm.casemap = newCasemap + newInnerMap := make(map[string]casemapEntry, len(cm.innerMap)) + for _, entry := range cm.innerMap { + newInnerMap[cm.casemap(entry.originalKey)] = entry + } + cm.innerMap = newInnerMap +} + +type upstreamChannelCasemapMap struct{ casemapMap } + +func (cm *upstreamChannelCasemapMap) Value(name string) *upstreamChannel { + entry, ok := cm.innerMap[cm.casemap(name)] + if !ok { + return nil + } + return entry.value.(*upstreamChannel) +} + +type channelCasemapMap struct{ casemapMap } + +func (cm *channelCasemapMap) Value(name string) *Channel { + entry, ok := cm.innerMap[cm.casemap(name)] + if !ok { + return nil + } + return entry.value.(*Channel) +} + +type membershipsCasemapMap struct{ casemapMap } + +func (cm *membershipsCasemapMap) Value(name string) *memberships { + entry, ok := cm.innerMap[cm.casemap(name)] + if !ok { + return nil + } + return entry.value.(*memberships) +} + +type deliveredCasemapMap struct{ casemapMap } + +func (cm *deliveredCasemapMap) Value(name string) deliveredClientMap { + entry, ok := cm.innerMap[cm.casemap(name)] + if !ok { + return nil + } + return entry.value.(deliveredClientMap) +} + +type monitorCasemapMap struct{ casemapMap } + +func (cm *monitorCasemapMap) Value(name string) (online bool) { + entry, ok := cm.innerMap[cm.casemap(name)] + if !ok { + return false + } + return entry.value.(bool) +} + +func isWordBoundary(r rune) bool { + switch r { + case '-', '_', '|': // inspired from weechat.look.highlight_regex + return false + default: + return !unicode.IsLetter(r) && !unicode.IsNumber(r) + } +} + +func isHighlight(text, nick string) bool { + for { + i := strings.Index(text, nick) + if i < 0 { + return false + } + + left, _ := utf8.DecodeLastRuneInString(text[:i]) + right, _ := utf8.DecodeRuneInString(text[i+len(nick):]) + if isWordBoundary(left) && isWordBoundary(right) { + return true + } + + text = text[i+len(nick):] + } +} + +// parseChatHistoryBound parses the given CHATHISTORY parameter as a bound. +// The zero time is returned on error. +func parseChatHistoryBound(param string) time.Time { + parts := strings.SplitN(param, "=", 2) + if len(parts) != 2 { + return time.Time{} + } + switch parts[0] { + case "timestamp": + timestamp, err := time.Parse(serverTimeLayout, parts[1]) + if err != nil { + return time.Time{} + } + return timestamp + default: + return time.Time{} + } +} + +// whoxFields is the list of all WHOX field letters, by order of appearance in +// RPL_WHOSPCRPL messages. +var whoxFields = []byte("tcuihsnfdlaor") + +type whoxInfo struct { + Token string + Username string + Hostname string + Server string + Nickname string + Flags string + Account string + Realname string +} + +func (info *whoxInfo) get(field byte) string { + switch field { + case 't': + return info.Token + case 'c': + return "*" + case 'u': + return info.Username + case 'i': + return "255.255.255.255" + case 'h': + return info.Hostname + case 's': + return info.Server + case 'n': + return info.Nickname + case 'f': + return info.Flags + case 'd': + return "0" + case 'l': // idle time + return "0" + case 'a': + account := "0" // WHOX uses "0" to mean "no account" + if info.Account != "" && info.Account != "*" { + account = info.Account + } + return account + case 'o': + return "0" + case 'r': + return info.Realname + } + return "" +} + +func generateWHOXReply(prefix *irc.Prefix, nick, fields string, info *whoxInfo) *irc.Message { + if fields == "" { + return &irc.Message{ + Prefix: prefix, + Command: irc.RPL_WHOREPLY, + Params: []string{nick, "*", info.Username, info.Hostname, info.Server, info.Nickname, info.Flags, "0 " + info.Realname}, + } + } + + fieldSet := make(map[byte]bool) + for i := 0; i < len(fields); i++ { + fieldSet[fields[i]] = true + } + + var values []string + for _, field := range whoxFields { + if !fieldSet[field] { + continue + } + values = append(values, info.get(field)) + } + + return &irc.Message{ + Prefix: prefix, + Command: rpl_whospcrpl, + Params: append([]string{nick}, values...), + } +} + +var isupportEncoder = strings.NewReplacer(" ", "\\x20", "\\", "\\x5C") + +func encodeISUPPORT(s string) string { + return isupportEncoder.Replace(s) +} diff --git a/branches/origin/irc_test.go b/branches/origin/irc_test.go new file mode 100644 index 0000000..8757bc8 --- /dev/null +++ b/branches/origin/irc_test.go @@ -0,0 +1,34 @@ +package suika + +import ( + "testing" +) + +func TestIsHighlight(t *testing.T) { + nick := "SojuUser" + testCases := []struct { + name string + text string + hl bool + }{ + {"noContains", "hi there Soju User!", false}, + {"middle", "hi there SojuUser!", true}, + {"start", "SojuUser: how are you doing?", true}, + {"end", "maybe ask SojuUser", true}, + {"inWord", "but OtherSojuUserSan is a different nick", false}, + {"startWord", "and OtherSojuUser is another different nick", false}, + {"endWord", "and SojuUserSan is yet a different nick", false}, + {"underscore", "and SojuUser_san has nothing to do with me", false}, + {"zeroWidthSpace", "writing S\u200BojuUser shouldn't trigger a highlight", false}, + } + + for _, tc := range testCases { + tc := tc // capture range variable + t.Run(tc.name, func(t *testing.T) { + hl := isHighlight(tc.text, nick) + if hl != tc.hl { + t.Errorf("isHighlight(%q, %q) = %v, but want %v", tc.text, nick, hl, tc.hl) + } + }) + } +} diff --git a/branches/origin/msgstore.go b/branches/origin/msgstore.go new file mode 100644 index 0000000..ed57429 --- /dev/null +++ b/branches/origin/msgstore.go @@ -0,0 +1,123 @@ +package suika + +import ( + "bytes" + "context" + "encoding/base64" + "fmt" + "time" + + "git.sr.ht/~sircmpwn/go-bare" + "gopkg.in/irc.v3" +) + +// messageStore is a per-user store for IRC messages. +type messageStore interface { + Close() error + // LastMsgID queries the last message ID for the given network, entity and + // date. The message ID returned may not refer to a valid message, but can be + // used in history queries. + LastMsgID(network *Network, entity string, t time.Time) (string, error) + // LoadLatestID queries the latest non-event messages for the given network, + // entity and date, up to a count of limit messages, sorted from oldest to newest. + LoadLatestID(ctx context.Context, network *Network, entity, id string, limit int) ([]*irc.Message, error) + Append(network *Network, entity string, msg *irc.Message) (id string, err error) +} + +type chatHistoryTarget struct { + Name string + LatestMessage time.Time +} + +// chatHistoryMessageStore is a message store that supports chat history +// operations. +type chatHistoryMessageStore interface { + messageStore + + // ListTargets lists channels and nicknames by time of the latest message. + // It returns up to limit targets, starting from start and ending on end, + // both excluded. end may be before or after start. + // If events is false, only PRIVMSG/NOTICE messages are considered. + ListTargets(ctx context.Context, network *Network, start, end time.Time, limit int, events bool) ([]chatHistoryTarget, error) + // LoadBeforeTime loads up to limit messages before start down to end. The + // returned messages must be between and excluding the provided bounds. + // end is before start. + // If events is false, only PRIVMSG/NOTICE messages are considered. + LoadBeforeTime(ctx context.Context, network *Network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error) + // LoadBeforeTime loads up to limit messages after start up to end. The + // returned messages must be between and excluding the provided bounds. + // end is after start. + // If events is false, only PRIVMSG/NOTICE messages are considered. + LoadAfterTime(ctx context.Context, network *Network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error) +} + +type msgIDType uint + +const ( + msgIDNone msgIDType = iota + msgIDMemory + msgIDFS +) + +const msgIDVersion uint = 0 + +type msgIDHeader struct { + Version uint + Network bare.Int + Target string + Type msgIDType +} + +type msgIDBody interface { + msgIDType() msgIDType +} + +func formatMsgID(netID int64, target string, body msgIDBody) string { + var buf bytes.Buffer + w := bare.NewWriter(&buf) + + header := msgIDHeader{ + Version: msgIDVersion, + Network: bare.Int(netID), + Target: target, + Type: body.msgIDType(), + } + if err := bare.MarshalWriter(w, &header); err != nil { + panic(err) + } + if err := bare.MarshalWriter(w, body); err != nil { + panic(err) + } + return base64.RawURLEncoding.EncodeToString(buf.Bytes()) +} + +func parseMsgID(s string, body msgIDBody) (netID int64, target string, err error) { + b, err := base64.RawURLEncoding.DecodeString(s) + if err != nil { + return 0, "", fmt.Errorf("invalid internal message ID: %v", err) + } + + r := bare.NewReader(bytes.NewReader(b)) + + var header msgIDHeader + if err := bare.UnmarshalBareReader(r, &header); err != nil { + return 0, "", fmt.Errorf("invalid internal message ID: %v", err) + } + + if header.Version != msgIDVersion { + return 0, "", fmt.Errorf("invalid internal message ID: got version %v, want %v", header.Version, msgIDVersion) + } + + if body != nil { + typ := body.msgIDType() + if header.Type != typ { + return 0, "", fmt.Errorf("invalid internal message ID: got type %v, want %v", header.Type, typ) + } + + if err := bare.UnmarshalBareReader(r, body); err != nil { + return 0, "", fmt.Errorf("invalid internal message ID: %v", err) + } + } + + return int64(header.Network), header.Target, nil +} diff --git a/branches/origin/msgstore_fs.go b/branches/origin/msgstore_fs.go new file mode 100644 index 0000000..90bd76a --- /dev/null +++ b/branches/origin/msgstore_fs.go @@ -0,0 +1,698 @@ +package suika + +import ( + "bufio" + "context" + "fmt" + "io" + "os" + "path/filepath" + "sort" + "strings" + "time" + + "git.sr.ht/~sircmpwn/go-bare" + "gopkg.in/irc.v3" +) + +const ( + fsMessageStoreMaxFiles = 20 + fsMessageStoreMaxTries = 100 +) + +func escapeFilename(unsafe string) (safe string) { + if unsafe == "." { + return "-" + } else if unsafe == ".." { + return "--" + } else { + return strings.NewReplacer("/", "-", "\\", "-").Replace(unsafe) + } +} + +type date struct { + Year, Month, Day int +} + +func newDate(t time.Time) date { + year, month, day := t.Date() + return date{year, int(month), day} +} + +func (d date) Time() time.Time { + return time.Date(d.Year, time.Month(d.Month), d.Day, 0, 0, 0, 0, time.Local) +} + +type fsMsgID struct { + Date date + Offset bare.Int +} + +func (fsMsgID) msgIDType() msgIDType { + return msgIDFS +} + +func parseFSMsgID(s string) (netID int64, entity string, t time.Time, offset int64, err error) { + var id fsMsgID + netID, entity, err = parseMsgID(s, &id) + if err != nil { + return 0, "", time.Time{}, 0, err + } + return netID, entity, id.Date.Time(), int64(id.Offset), nil +} + +func formatFSMsgID(netID int64, entity string, t time.Time, offset int64) string { + id := fsMsgID{ + Date: newDate(t), + Offset: bare.Int(offset), + } + return formatMsgID(netID, entity, &id) +} + +type fsMessageStoreFile struct { + *os.File + lastUse time.Time +} + +// fsMessageStore is a per-user on-disk store for IRC messages. +// +// It mimicks the ZNC log layout and format. See the ZNC source: +// https://github.com/znc/znc/blob/master/modules/log.cpp +type fsMessageStore struct { + root string + user *User + + // Write-only files used by Append + files map[string]*fsMessageStoreFile // indexed by entity +} + +var _ messageStore = (*fsMessageStore)(nil) +var _ chatHistoryMessageStore = (*fsMessageStore)(nil) + +func newFSMessageStore(root string, user *User) *fsMessageStore { + return &fsMessageStore{ + root: filepath.Join(root, escapeFilename(user.Username)), + user: user, + files: make(map[string]*fsMessageStoreFile), + } +} + +func (ms *fsMessageStore) logPath(network *Network, entity string, t time.Time) string { + year, month, day := t.Date() + filename := fmt.Sprintf("%04d-%02d-%02d.log", year, month, day) + return filepath.Join(ms.root, escapeFilename(network.GetName()), escapeFilename(entity), filename) +} + +// nextMsgID queries the message ID for the next message to be written to f. +func nextFSMsgID(network *Network, entity string, t time.Time, f *os.File) (string, error) { + offset, err := f.Seek(0, io.SeekEnd) + if err != nil { + return "", fmt.Errorf("failed to query next FS message ID: %v", err) + } + return formatFSMsgID(network.ID, entity, t, offset), nil +} + +func (ms *fsMessageStore) LastMsgID(network *Network, entity string, t time.Time) (string, error) { + p := ms.logPath(network, entity, t) + fi, err := os.Stat(p) + if os.IsNotExist(err) { + return formatFSMsgID(network.ID, entity, t, -1), nil + } else if err != nil { + return "", fmt.Errorf("failed to query last FS message ID: %v", err) + } + return formatFSMsgID(network.ID, entity, t, fi.Size()-1), nil +} + +func (ms *fsMessageStore) Append(network *Network, entity string, msg *irc.Message) (string, error) { + s := formatMessage(msg) + if s == "" { + return "", nil + } + + var t time.Time + if tag, ok := msg.Tags["time"]; ok { + var err error + t, err = time.Parse(serverTimeLayout, string(tag)) + if err != nil { + return "", fmt.Errorf("failed to parse message time tag: %v", err) + } + t = t.In(time.Local) + } else { + t = time.Now() + } + + f := ms.files[entity] + + // TODO: handle non-monotonic clock behaviour + path := ms.logPath(network, entity, t) + if f == nil || f.Name() != path { + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0750); err != nil { + return "", fmt.Errorf("failed to create message logs directory %q: %v", dir, err) + } + + ff, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0640) + if err != nil { + return "", fmt.Errorf("failed to open message log file %q: %v", path, err) + } + + if f != nil { + f.Close() + } + f = &fsMessageStoreFile{File: ff} + ms.files[entity] = f + } + + f.lastUse = time.Now() + + if len(ms.files) > fsMessageStoreMaxFiles { + entities := make([]string, 0, len(ms.files)) + for name := range ms.files { + entities = append(entities, name) + } + sort.Slice(entities, func(i, j int) bool { + a, b := entities[i], entities[j] + return ms.files[a].lastUse.Before(ms.files[b].lastUse) + }) + entities = entities[0 : len(entities)-fsMessageStoreMaxFiles] + for _, name := range entities { + ms.files[name].Close() + delete(ms.files, name) + } + } + + msgID, err := nextFSMsgID(network, entity, t, f.File) + if err != nil { + return "", fmt.Errorf("failed to generate message ID: %v", err) + } + + _, err = fmt.Fprintf(f, "[%02d:%02d:%02d] %s\n", t.Hour(), t.Minute(), t.Second(), s) + if err != nil { + return "", fmt.Errorf("failed to log message to %q: %v", f.Name(), err) + } + + return msgID, nil +} + +func (ms *fsMessageStore) Close() error { + var closeErr error + for _, f := range ms.files { + if err := f.Close(); err != nil { + closeErr = fmt.Errorf("failed to close message store: %v", err) + } + } + return closeErr +} + +// formatMessage formats a message log line. It assumes a well-formed IRC +// message. +func formatMessage(msg *irc.Message) string { + switch strings.ToUpper(msg.Command) { + case "NICK": + return fmt.Sprintf("*** %s is now known as %s", msg.Prefix.Name, msg.Params[0]) + case "JOIN": + return fmt.Sprintf("*** Joins: %s (%s@%s)", msg.Prefix.Name, msg.Prefix.User, msg.Prefix.Host) + case "PART": + var reason string + if len(msg.Params) > 1 { + reason = msg.Params[1] + } + return fmt.Sprintf("*** Parts: %s (%s@%s) (%s)", msg.Prefix.Name, msg.Prefix.User, msg.Prefix.Host, reason) + case "KICK": + nick := msg.Params[1] + var reason string + if len(msg.Params) > 2 { + reason = msg.Params[2] + } + return fmt.Sprintf("*** %s was kicked by %s (%s)", nick, msg.Prefix.Name, reason) + case "QUIT": + var reason string + if len(msg.Params) > 0 { + reason = msg.Params[0] + } + return fmt.Sprintf("*** Quits: %s (%s@%s) (%s)", msg.Prefix.Name, msg.Prefix.User, msg.Prefix.Host, reason) + case "TOPIC": + var topic string + if len(msg.Params) > 1 { + topic = msg.Params[1] + } + return fmt.Sprintf("*** %s changes topic to '%s'", msg.Prefix.Name, topic) + case "MODE": + return fmt.Sprintf("*** %s sets mode: %s", msg.Prefix.Name, strings.Join(msg.Params[1:], " ")) + case "NOTICE": + return fmt.Sprintf("-%s- %s", msg.Prefix.Name, msg.Params[1]) + case "PRIVMSG": + if cmd, params, ok := parseCTCPMessage(msg); ok && cmd == "ACTION" { + return fmt.Sprintf("* %s %s", msg.Prefix.Name, params) + } else { + return fmt.Sprintf("<%s> %s", msg.Prefix.Name, msg.Params[1]) + } + default: + return "" + } +} + +func (ms *fsMessageStore) parseMessage(line string, network *Network, entity string, ref time.Time, events bool) (*irc.Message, time.Time, error) { + var hour, minute, second int + _, err := fmt.Sscanf(line, "[%02d:%02d:%02d] ", &hour, &minute, &second) + if err != nil { + return nil, time.Time{}, fmt.Errorf("malformed timestamp prefix: %v", err) + } + line = line[11:] + + var cmd string + var prefix *irc.Prefix + var params []string + if events && strings.HasPrefix(line, "*** ") { + parts := strings.SplitN(line[4:], " ", 2) + if len(parts) != 2 { + return nil, time.Time{}, nil + } + switch parts[0] { + case "Joins:", "Parts:", "Quits:": + args := strings.SplitN(parts[1], " ", 3) + if len(args) < 2 { + return nil, time.Time{}, nil + } + nick := args[0] + mask := strings.TrimSuffix(strings.TrimPrefix(args[1], "("), ")") + maskParts := strings.SplitN(mask, "@", 2) + if len(maskParts) != 2 { + return nil, time.Time{}, nil + } + prefix = &irc.Prefix{ + Name: nick, + User: maskParts[0], + Host: maskParts[1], + } + var reason string + if len(args) > 2 { + reason = strings.TrimSuffix(strings.TrimPrefix(args[2], "("), ")") + } + switch parts[0] { + case "Joins:": + cmd = "JOIN" + params = []string{entity} + case "Parts:": + cmd = "PART" + if reason != "" { + params = []string{entity, reason} + } else { + params = []string{entity} + } + case "Quits:": + cmd = "QUIT" + if reason != "" { + params = []string{reason} + } + } + default: + nick := parts[0] + rem := parts[1] + if r := strings.TrimPrefix(rem, "is now known as "); r != rem { + cmd = "NICK" + prefix = &irc.Prefix{ + Name: nick, + } + params = []string{r} + } else if r := strings.TrimPrefix(rem, "was kicked by "); r != rem { + args := strings.SplitN(r, " ", 2) + if len(args) != 2 { + return nil, time.Time{}, nil + } + cmd = "KICK" + prefix = &irc.Prefix{ + Name: args[0], + } + reason := strings.TrimSuffix(strings.TrimPrefix(args[1], "("), ")") + params = []string{entity, nick} + if reason != "" { + params = append(params, reason) + } + } else if r := strings.TrimPrefix(rem, "changes topic to "); r != rem { + cmd = "TOPIC" + prefix = &irc.Prefix{ + Name: nick, + } + topic := strings.TrimSuffix(strings.TrimPrefix(r, "'"), "'") + params = []string{entity, topic} + } else if r := strings.TrimPrefix(rem, "sets mode: "); r != rem { + cmd = "MODE" + prefix = &irc.Prefix{ + Name: nick, + } + params = append([]string{entity}, strings.Split(r, " ")...) + } else { + return nil, time.Time{}, nil + } + } + } else { + var sender, text string + if strings.HasPrefix(line, "<") { + cmd = "PRIVMSG" + parts := strings.SplitN(line[1:], "> ", 2) + if len(parts) != 2 { + return nil, time.Time{}, nil + } + sender, text = parts[0], parts[1] + } else if strings.HasPrefix(line, "-") { + cmd = "NOTICE" + parts := strings.SplitN(line[1:], "- ", 2) + if len(parts) != 2 { + return nil, time.Time{}, nil + } + sender, text = parts[0], parts[1] + } else if strings.HasPrefix(line, "* ") { + cmd = "PRIVMSG" + parts := strings.SplitN(line[2:], " ", 2) + if len(parts) != 2 { + return nil, time.Time{}, nil + } + sender, text = parts[0], "\x01ACTION "+parts[1]+"\x01" + } else { + return nil, time.Time{}, nil + } + + prefix = &irc.Prefix{Name: sender} + if entity == sender { + // This is a direct message from a user to us. We don't store own + // our nickname in the logs, so grab it from the network settings. + // Not very accurate since this may not match our nick at the time + // the message was received, but we can't do a lot better. + entity = GetNick(ms.user, network) + } + params = []string{entity, text} + } + + year, month, day := ref.Date() + t := time.Date(year, month, day, hour, minute, second, 0, time.Local) + + msg := &irc.Message{ + Tags: map[string]irc.TagValue{ + "time": irc.TagValue(formatServerTime(t)), + }, + Prefix: prefix, + Command: cmd, + Params: params, + } + return msg, t, nil +} + +func (ms *fsMessageStore) parseMessagesBefore(network *Network, entity string, ref time.Time, end time.Time, events bool, limit int, afterOffset int64) ([]*irc.Message, error) { + path := ms.logPath(network, entity, ref) + f, err := os.Open(path) + if err != nil { + if os.IsNotExist(err) { + return nil, nil + } + return nil, fmt.Errorf("failed to parse messages before ref: %v", err) + } + defer f.Close() + + historyRing := make([]*irc.Message, limit) + cur := 0 + + sc := bufio.NewScanner(f) + + if afterOffset >= 0 { + if _, err := f.Seek(afterOffset, io.SeekStart); err != nil { + return nil, nil + } + sc.Scan() // skip till next newline + } + + for sc.Scan() { + msg, t, err := ms.parseMessage(sc.Text(), network, entity, ref, events) + if err != nil { + return nil, err + } else if msg == nil || !t.After(end) { + continue + } else if !t.Before(ref) { + break + } + + historyRing[cur%limit] = msg + cur++ + } + if sc.Err() != nil { + return nil, fmt.Errorf("failed to parse messages before ref: scanner error: %v", sc.Err()) + } + + n := limit + if cur < limit { + n = cur + } + start := (cur - n + limit) % limit + + if start+n <= limit { // ring doesnt wrap + return historyRing[start : start+n], nil + } else { // ring wraps + history := make([]*irc.Message, n) + r := copy(history, historyRing[start:]) + copy(history[r:], historyRing[:n-r]) + return history, nil + } +} + +func (ms *fsMessageStore) parseMessagesAfter(network *Network, entity string, ref time.Time, end time.Time, events bool, limit int) ([]*irc.Message, error) { + path := ms.logPath(network, entity, ref) + f, err := os.Open(path) + if err != nil { + if os.IsNotExist(err) { + return nil, nil + } + return nil, fmt.Errorf("failed to parse messages after ref: %v", err) + } + defer f.Close() + + var history []*irc.Message + sc := bufio.NewScanner(f) + for sc.Scan() && len(history) < limit { + msg, t, err := ms.parseMessage(sc.Text(), network, entity, ref, events) + if err != nil { + return nil, err + } else if msg == nil || !t.After(ref) { + continue + } else if !t.Before(end) { + break + } + + history = append(history, msg) + } + if sc.Err() != nil { + return nil, fmt.Errorf("failed to parse messages after ref: scanner error: %v", sc.Err()) + } + + return history, nil +} + +func (ms *fsMessageStore) LoadBeforeTime(ctx context.Context, network *Network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) { + start = start.In(time.Local) + end = end.In(time.Local) + history := make([]*irc.Message, limit) + remaining := limit + tries := 0 + for remaining > 0 && tries < fsMessageStoreMaxTries && end.Before(start) { + buf, err := ms.parseMessagesBefore(network, entity, start, end, events, remaining, -1) + if err != nil { + return nil, err + } + if len(buf) == 0 { + tries++ + } else { + tries = 0 + } + copy(history[remaining-len(buf):], buf) + remaining -= len(buf) + year, month, day := start.Date() + start = time.Date(year, month, day, 0, 0, 0, 0, start.Location()).Add(-1) + + if err := ctx.Err(); err != nil { + return nil, err + } + } + + return history[remaining:], nil +} + +func (ms *fsMessageStore) LoadAfterTime(ctx context.Context, network *Network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) { + start = start.In(time.Local) + end = end.In(time.Local) + var history []*irc.Message + remaining := limit + tries := 0 + for remaining > 0 && tries < fsMessageStoreMaxTries && start.Before(end) { + buf, err := ms.parseMessagesAfter(network, entity, start, end, events, remaining) + if err != nil { + return nil, err + } + if len(buf) == 0 { + tries++ + } else { + tries = 0 + } + history = append(history, buf...) + remaining -= len(buf) + year, month, day := start.Date() + start = time.Date(year, month, day+1, 0, 0, 0, 0, start.Location()) + + if err := ctx.Err(); err != nil { + return nil, err + } + } + return history, nil +} + +func (ms *fsMessageStore) LoadLatestID(ctx context.Context, network *Network, entity, id string, limit int) ([]*irc.Message, error) { + var afterTime time.Time + var afterOffset int64 + if id != "" { + var idNet int64 + var idEntity string + var err error + idNet, idEntity, afterTime, afterOffset, err = parseFSMsgID(id) + if err != nil { + return nil, err + } + if idNet != network.ID || idEntity != entity { + return nil, fmt.Errorf("cannot find message ID: message ID doesn't match network/entity") + } + } + + history := make([]*irc.Message, limit) + t := time.Now() + remaining := limit + tries := 0 + for remaining > 0 && tries < fsMessageStoreMaxTries && !truncateDay(t).Before(afterTime) { + var offset int64 = -1 + if afterOffset >= 0 && truncateDay(t).Equal(afterTime) { + offset = afterOffset + } + + buf, err := ms.parseMessagesBefore(network, entity, t, time.Time{}, false, remaining, offset) + if err != nil { + return nil, err + } + if len(buf) == 0 { + tries++ + } else { + tries = 0 + } + copy(history[remaining-len(buf):], buf) + remaining -= len(buf) + year, month, day := t.Date() + t = time.Date(year, month, day, 0, 0, 0, 0, t.Location()).Add(-1) + + if err := ctx.Err(); err != nil { + return nil, err + } + } + + return history[remaining:], nil +} + +func (ms *fsMessageStore) ListTargets(ctx context.Context, network *Network, start, end time.Time, limit int, events bool) ([]chatHistoryTarget, error) { + start = start.In(time.Local) + end = end.In(time.Local) + rootPath := filepath.Join(ms.root, escapeFilename(network.GetName())) + root, err := os.Open(rootPath) + if os.IsNotExist(err) { + return nil, nil + } else if err != nil { + return nil, err + } + + // The returned targets are escaped, and there is no way to un-escape + // TODO: switch to ReadDir (Go 1.16+) + targetNames, err := root.Readdirnames(0) + root.Close() + if err != nil { + return nil, err + } + + var targets []chatHistoryTarget + for _, target := range targetNames { + // target is already escaped here + targetPath := filepath.Join(rootPath, target) + targetDir, err := os.Open(targetPath) + if err != nil { + return nil, err + } + + entries, err := targetDir.Readdir(0) + targetDir.Close() + if err != nil { + return nil, err + } + + // We use mtime here, which may give imprecise or incorrect results + var t time.Time + for _, entry := range entries { + if entry.ModTime().After(t) { + t = entry.ModTime() + } + } + + // The timestamps we get from logs have second granularity + t = truncateSecond(t) + + // Filter out targets that don't fullfil the time bounds + if !isTimeBetween(t, start, end) { + continue + } + + targets = append(targets, chatHistoryTarget{ + Name: target, + LatestMessage: t, + }) + + if err := ctx.Err(); err != nil { + return nil, err + } + } + + // Sort targets by latest message time, backwards or forwards depending on + // the order of the time bounds + sort.Slice(targets, func(i, j int) bool { + t1, t2 := targets[i].LatestMessage, targets[j].LatestMessage + if start.Before(end) { + return t1.Before(t2) + } else { + return !t1.Before(t2) + } + }) + + // Truncate the result if necessary + if len(targets) > limit { + targets = targets[:limit] + } + + return targets, nil +} + +func (ms *fsMessageStore) RenameNetwork(oldNet, newNet *Network) error { + oldDir := filepath.Join(ms.root, escapeFilename(oldNet.GetName())) + newDir := filepath.Join(ms.root, escapeFilename(newNet.GetName())) + // Avoid loosing data by overwriting an existing directory + if _, err := os.Stat(newDir); err == nil { + return fmt.Errorf("destination %q already exists", newDir) + } + return os.Rename(oldDir, newDir) +} + +func truncateDay(t time.Time) time.Time { + year, month, day := t.Date() + return time.Date(year, month, day, 0, 0, 0, 0, t.Location()) +} + +func truncateSecond(t time.Time) time.Time { + year, month, day := t.Date() + return time.Date(year, month, day, t.Hour(), t.Minute(), t.Second(), 0, t.Location()) +} + +func isTimeBetween(t, start, end time.Time) bool { + if end.Before(start) { + end, start = start, end + } + return start.Before(t) && t.Before(end) +} diff --git a/branches/origin/msgstore_memory.go b/branches/origin/msgstore_memory.go new file mode 100644 index 0000000..25ab953 --- /dev/null +++ b/branches/origin/msgstore_memory.go @@ -0,0 +1,160 @@ +package suika + +import ( + "context" + "fmt" + "time" + + "git.sr.ht/~sircmpwn/go-bare" + "gopkg.in/irc.v3" +) + +const messageRingBufferCap = 4096 + +type memoryMsgID struct { + Seq bare.Uint +} + +func (memoryMsgID) msgIDType() msgIDType { + return msgIDMemory +} + +func parseMemoryMsgID(s string) (netID int64, entity string, seq uint64, err error) { + var id memoryMsgID + netID, entity, err = parseMsgID(s, &id) + if err != nil { + return 0, "", 0, err + } + return netID, entity, uint64(id.Seq), nil +} + +func formatMemoryMsgID(netID int64, entity string, seq uint64) string { + id := memoryMsgID{bare.Uint(seq)} + return formatMsgID(netID, entity, &id) +} + +type ringBufferKey struct { + networkID int64 + entity string +} + +type memoryMessageStore struct { + buffers map[ringBufferKey]*messageRingBuffer +} + +var _ messageStore = (*memoryMessageStore)(nil) + +func newMemoryMessageStore() *memoryMessageStore { + return &memoryMessageStore{ + buffers: make(map[ringBufferKey]*messageRingBuffer), + } +} + +func (ms *memoryMessageStore) Close() error { + ms.buffers = nil + return nil +} + +func (ms *memoryMessageStore) get(network *Network, entity string) *messageRingBuffer { + k := ringBufferKey{networkID: network.ID, entity: entity} + if rb, ok := ms.buffers[k]; ok { + return rb + } + rb := newMessageRingBuffer(messageRingBufferCap) + ms.buffers[k] = rb + return rb +} + +func (ms *memoryMessageStore) LastMsgID(network *Network, entity string, t time.Time) (string, error) { + var seq uint64 + k := ringBufferKey{networkID: network.ID, entity: entity} + if rb, ok := ms.buffers[k]; ok { + seq = rb.cur + } + return formatMemoryMsgID(network.ID, entity, seq), nil +} + +func (ms *memoryMessageStore) Append(network *Network, entity string, msg *irc.Message) (string, error) { + switch msg.Command { + case "PRIVMSG", "NOTICE": + // Only append these messages, because LoadLatestID shouldn't return + // other kinds of message. + default: + return "", nil + } + + k := ringBufferKey{networkID: network.ID, entity: entity} + rb, ok := ms.buffers[k] + if !ok { + rb = newMessageRingBuffer(messageRingBufferCap) + ms.buffers[k] = rb + } + + seq := rb.Append(msg) + return formatMemoryMsgID(network.ID, entity, seq), nil +} + +func (ms *memoryMessageStore) LoadLatestID(ctx context.Context, network *Network, entity, id string, limit int) ([]*irc.Message, error) { + _, _, seq, err := parseMemoryMsgID(id) + if err != nil { + return nil, err + } + + k := ringBufferKey{networkID: network.ID, entity: entity} + rb, ok := ms.buffers[k] + if !ok { + return nil, nil + } + + return rb.LoadLatestSeq(seq, limit) +} + +type messageRingBuffer struct { + buf []*irc.Message + cur uint64 +} + +func newMessageRingBuffer(capacity int) *messageRingBuffer { + return &messageRingBuffer{ + buf: make([]*irc.Message, capacity), + cur: 1, + } +} + +func (rb *messageRingBuffer) cap() uint64 { + return uint64(len(rb.buf)) +} + +func (rb *messageRingBuffer) Append(msg *irc.Message) uint64 { + seq := rb.cur + i := int(seq % rb.cap()) + rb.buf[i] = msg + rb.cur++ + return seq +} + +func (rb *messageRingBuffer) LoadLatestSeq(seq uint64, limit int) ([]*irc.Message, error) { + if seq > rb.cur { + return nil, fmt.Errorf("loading messages from sequence number (%v) greater than current (%v)", seq, rb.cur) + } else if seq == rb.cur { + return nil, nil + } + + // The query excludes the message with the sequence number seq + diff := rb.cur - seq - 1 + if diff > rb.cap() { + // We dropped diff - cap entries + diff = rb.cap() + } + if int(diff) > limit { + diff = uint64(limit) + } + + l := make([]*irc.Message, int(diff)) + for i := 0; i < int(diff); i++ { + j := int((rb.cur - diff + uint64(i)) % rb.cap()) + l[i] = rb.buf[j] + } + + return l, nil +} diff --git a/branches/origin/net_go113.go b/branches/origin/net_go113.go new file mode 100644 index 0000000..24bde7f --- /dev/null +++ b/branches/origin/net_go113.go @@ -0,0 +1,12 @@ +//go:build !go1.16 +// +build !go1.16 + +package suika + +import ( + "strings" +) + +func isErrClosed(err error) bool { + return err != nil && strings.Contains(err.Error(), "use of closed network connection") +} diff --git a/branches/origin/net_go116.go b/branches/origin/net_go116.go new file mode 100644 index 0000000..ef1256a --- /dev/null +++ b/branches/origin/net_go116.go @@ -0,0 +1,13 @@ +//go:build go1.16 +// +build go1.16 + +package suika + +import ( + "errors" + "net" +) + +func isErrClosed(err error) bool { + return errors.Is(err, net.ErrClosed) +} diff --git a/branches/origin/rate.go b/branches/origin/rate.go new file mode 100644 index 0000000..a14bce2 --- /dev/null +++ b/branches/origin/rate.go @@ -0,0 +1,40 @@ +package suika + +import ( + "math/rand" + "time" +) + +// backoffer implements a simple exponential backoff. +type backoffer struct { + min, max, jitter time.Duration + n int64 +} + +func newBackoffer(min, max, jitter time.Duration) *backoffer { + return &backoffer{min: min, max: max, jitter: jitter} +} + +func (b *backoffer) Reset() { + b.n = 0 +} + +func (b *backoffer) Next() time.Duration { + if b.n == 0 { + b.n = 1 + return 0 + } + + d := time.Duration(b.n) * b.min + if d > b.max { + d = b.max + } else { + b.n *= 2 + } + + if b.jitter != 0 { + d += time.Duration(rand.Int63n(int64(b.jitter))) + } + + return d +} diff --git a/branches/origin/rc.d/freebsd-rc.d b/branches/origin/rc.d/freebsd-rc.d new file mode 100755 index 0000000..0fed078 --- /dev/null +++ b/branches/origin/rc.d/freebsd-rc.d @@ -0,0 +1,29 @@ +#!/bin/sh +# $TheSupernovaDuo$ +# vim: ft=sh + +# PROVIDE: suika +# REQUIRE: DAEMON +# BEFORE: LOGIN +# KEYWORD: shutdown + +. /etc/rc.subr + +name="suika" +desc="A drunk IRC bouncer" +rcvar="suika_enable" + +: ${suika_user="ircd"} + +command="%%PREFIX%%/bin/suika" +pidfile="/var/run/suika.pid" +required_files="%%PREFIX%%/etc/suika/config" + +start_cmd="suika_start" + +suika_start() { + /usr/sbin/daemon -f -p ${pidfile} -u ${suika_user} -l daemon ${command} --config ${required_files} +} + +load_rc_config "$name" +run_rc_command "$1" diff --git a/branches/origin/rc.d/immortal.yml b/branches/origin/rc.d/immortal.yml new file mode 100644 index 0000000..c7a8090 --- /dev/null +++ b/branches/origin/rc.d/immortal.yml @@ -0,0 +1,3 @@ +# $TheSupernovaDuo$ +cmd: %%PREFIX%%/bin/suika --config %%PREFIX%%/etc/suika/config +user: ircd diff --git a/branches/origin/rc.d/netbsd-rc.d b/branches/origin/rc.d/netbsd-rc.d new file mode 100644 index 0000000..3fef470 --- /dev/null +++ b/branches/origin/rc.d/netbsd-rc.d @@ -0,0 +1,28 @@ +#!/bin/sh +# $TheSupernovaDuo$ +# vim: ft=sh + +# PROVIDE: suika +# REQUIRE: DAEMON +# BEFORE: LOGIN +# KEYWORD: shutdown + +. /etc/rc.subr + +name="suika" +rcvar="${name}" +command="%%PREFIX/bin/${name}" +command_args="--config %%PREFIX%%/etc/suika/config" +pidfile="/var/run/${name}.pid" +start_cmd="${name}_start" + +suika_start() { + printf "Starting %s..." "${name}" + ${command} ${command_args} + pgrep -n ${name} > ${pidfile} +} + +load_rc_config ${name} +run_rc_command "$1" + + diff --git a/branches/origin/rc.d/openbsd-rc.d b/branches/origin/rc.d/openbsd-rc.d new file mode 100644 index 0000000..f0bf77c --- /dev/null +++ b/branches/origin/rc.d/openbsd-rc.d @@ -0,0 +1,12 @@ +#!/bin/ksh +# $TheSupernovaDuo$ +# vim: ft=sh + +daemon="%%PREFIX%%/bin/suika" +daemon_args="--config %%PREFIX%%/etc/suika/config" + +. /etc/rc.d/rc.subr + +rc_bg=YES + +rc_cmd "$1" diff --git a/branches/origin/rc.d/suika.service b/branches/origin/rc.d/suika.service new file mode 100644 index 0000000..d8d53b3 --- /dev/null +++ b/branches/origin/rc.d/suika.service @@ -0,0 +1,16 @@ +# $TheSupernovaDuo$ +# vim: ft=confini +[Unit] +Description=A drunk IRC bouncer +After=network.target +Wants=network.target +StartLimitBurst=5 +StartLimitIntervalSec=1 +[Service] +Type=simple +Restart=on-abnormal +RestartSec=1 +User=suika +ExecStart=%%PREFIX%%/bin/suika --config %%PREFIX%%/etc/suika/config +[Install] +WantedBy=multi-user.target diff --git a/branches/origin/server.go b/branches/origin/server.go new file mode 100644 index 0000000..05c55fd --- /dev/null +++ b/branches/origin/server.go @@ -0,0 +1,330 @@ +package suika + +import ( + "context" + "errors" + "fmt" + "io" + "log" + "net" + "runtime/debug" + "sync" + "sync/atomic" + "time" + + "gopkg.in/irc.v3" +) + +// TODO: make configurable +var ( + retryConnectMinDelay = time.Minute + retryConnectMaxDelay = 10 * time.Minute + retryConnectJitter = time.Minute + connectTimeout = 15 * time.Second + writeTimeout = 10 * time.Second + upstreamMessageDelay = 2 * time.Second + upstreamMessageBurst = 10 + backlogTimeout = 10 * time.Second + handleDownstreamMessageTimeout = 10 * time.Second + downstreamRegisterTimeout = 30 * time.Second + chatHistoryLimit = 1000 + backlogLimit = 4000 +) + +type Logger interface { + Printf(format string, v ...interface{}) + Debugf(format string, v ...interface{}) +} + +type logger struct { + *log.Logger + debug bool +} + +func (l logger) Debugf(format string, v ...interface{}) { + if !l.debug { + return + } + l.Logger.Printf(format, v...) +} + +func NewLogger(out io.Writer, debug bool) Logger { + return logger{ + Logger: log.New(log.Writer(), "", log.LstdFlags), + debug: debug, + } +} + +type prefixLogger struct { + logger Logger + prefix string +} + +var _ Logger = (*prefixLogger)(nil) + +func (l *prefixLogger) Printf(format string, v ...interface{}) { + v = append([]interface{}{l.prefix}, v...) + l.logger.Printf("%v"+format, v...) +} + +func (l *prefixLogger) Debugf(format string, v ...interface{}) { + v = append([]interface{}{l.prefix}, v...) + l.logger.Debugf("%v"+format, v...) +} + +type int64Gauge struct { + v int64 // atomic +} + +func (g *int64Gauge) Add(delta int64) { + atomic.AddInt64(&g.v, delta) +} + +func (g *int64Gauge) Value() int64 { + return atomic.LoadInt64(&g.v) +} + +func (g *int64Gauge) Float64() float64 { + return float64(g.Value()) +} + +type retryListener struct { + net.Listener + Logger Logger + + delay time.Duration +} + +func (ln *retryListener) Accept() (net.Conn, error) { + for { + conn, err := ln.Listener.Accept() + if ne, ok := err.(net.Error); ok && ne.Temporary() { + if ln.delay == 0 { + ln.delay = 5 * time.Millisecond + } else { + ln.delay *= 2 + } + if max := 1 * time.Second; ln.delay > max { + ln.delay = max + } + if ln.Logger != nil { + ln.Logger.Printf("accept error (retrying in %v): %v", ln.delay, err) + } + time.Sleep(ln.delay) + } else { + ln.delay = 0 + return conn, err + } + } +} + +type Config struct { + Hostname string + Title string + LogPath string + MaxUserNetworks int + MultiUpstream bool + MOTD string + UpstreamUserIPs []*net.IPNet +} + +type Server struct { + Logger Logger + + config atomic.Value // *Config + db Database + stopWG sync.WaitGroup + + lock sync.Mutex + listeners map[net.Listener]struct{} + users map[string]*user +} + +func NewServer(db Database) *Server { + srv := &Server{ + Logger: NewLogger(log.Writer(), true), + db: db, + listeners: make(map[net.Listener]struct{}), + users: make(map[string]*user), + } + srv.config.Store(&Config{ + Hostname: "localhost", + MaxUserNetworks: -1, + MultiUpstream: true, + }) + return srv +} + +func (s *Server) prefix() *irc.Prefix { + return &irc.Prefix{Name: s.Config().Hostname} +} + +func (s *Server) Config() *Config { + return s.config.Load().(*Config) +} + +func (s *Server) SetConfig(cfg *Config) { + s.config.Store(cfg) +} + +func (s *Server) Start() error { + users, err := s.db.ListUsers(context.TODO()) + if err != nil { + return err + } + + s.lock.Lock() + for i := range users { + s.addUserLocked(&users[i]) + } + s.lock.Unlock() + + return nil +} + +func (s *Server) Shutdown() { + s.lock.Lock() + for ln := range s.listeners { + if err := ln.Close(); err != nil { + s.Logger.Printf("failed to stop listener: %v", err) + } + } + for _, u := range s.users { + u.events <- eventStop{} + } + s.lock.Unlock() + + s.stopWG.Wait() + + if err := s.db.Close(); err != nil { + s.Logger.Printf("failed to close DB: %v", err) + } +} + +func (s *Server) createUser(ctx context.Context, user *User) (*user, error) { + s.lock.Lock() + defer s.lock.Unlock() + + if _, ok := s.users[user.Username]; ok { + return nil, fmt.Errorf("user %q already exists", user.Username) + } + + err := s.db.StoreUser(ctx, user) + if err != nil { + return nil, fmt.Errorf("could not create user in db: %v", err) + } + + return s.addUserLocked(user), nil +} + +func (s *Server) forEachUser(f func(*user)) { + s.lock.Lock() + for _, u := range s.users { + f(u) + } + s.lock.Unlock() +} + +func (s *Server) getUser(name string) *user { + s.lock.Lock() + u := s.users[name] + s.lock.Unlock() + return u +} + +func (s *Server) addUserLocked(user *User) *user { + s.Logger.Printf("starting bouncer for user %q", user.Username) + u := newUser(s, user) + s.users[u.Username] = u + + s.stopWG.Add(1) + + go func() { + defer func() { + if err := recover(); err != nil { + s.Logger.Printf("panic serving user %q: %v\n%v", user.Username, err, debug.Stack()) + } + + s.lock.Lock() + delete(s.users, u.Username) + s.lock.Unlock() + + s.stopWG.Done() + }() + + u.run() + }() + + return u +} + +var lastDownstreamID uint64 = 0 + +func (s *Server) handle(ic ircConn) { + defer func() { + if err := recover(); err != nil { + s.Logger.Printf("panic serving downstream %q: %v\n%v", ic.RemoteAddr(), err, debug.Stack()) + } + }() + + id := atomic.AddUint64(&lastDownstreamID, 1) + dc := newDownstreamConn(s, ic, id) + if err := dc.runUntilRegistered(); err != nil { + if !errors.Is(err, io.EOF) { + dc.logger.Printf("%v", err) + } + } else { + dc.user.events <- eventDownstreamConnected{dc} + if err := dc.readMessages(dc.user.events); err != nil { + dc.logger.Printf("%v", err) + } + dc.user.events <- eventDownstreamDisconnected{dc} + } + dc.Close() +} + +func (s *Server) Serve(ln net.Listener) error { + ln = &retryListener{ + Listener: ln, + Logger: &prefixLogger{logger: s.Logger, prefix: fmt.Sprintf("listener %v: ", ln.Addr())}, + } + + s.lock.Lock() + s.listeners[ln] = struct{}{} + s.lock.Unlock() + + s.stopWG.Add(1) + + defer func() { + s.lock.Lock() + delete(s.listeners, ln) + s.lock.Unlock() + + s.stopWG.Done() + }() + + for { + conn, err := ln.Accept() + if isErrClosed(err) { + return nil + } else if err != nil { + return fmt.Errorf("failed to accept connection: %v", err) + } + + go s.handle(newNetIRCConn(conn)) + } +} + +type ServerStats struct { + Users int + Downstreams int64 + Upstreams int64 +} + +func (s *Server) Stats() *ServerStats { + var stats ServerStats + s.lock.Lock() + stats.Users = len(s.users) + s.lock.Unlock() + return &stats +} diff --git a/branches/origin/server_test.go b/branches/origin/server_test.go new file mode 100644 index 0000000..cf0adca --- /dev/null +++ b/branches/origin/server_test.go @@ -0,0 +1,207 @@ +package suika + +import ( + "context" + "net" + "testing" + + "golang.org/x/crypto/bcrypt" + "gopkg.in/irc.v3" +) + +var testServerPrefix = &irc.Prefix{Name: "suika-test-server"} + +const ( + testUsername = "suika-test-user" + testPassword = testUsername +) + +func createTempSqliteDB(t *testing.T) Database { + db, err := OpenDB("sqlite3", ":memory:") + if err != nil { + t.Fatalf("failed to create temporary SQLite database: %v", err) + } + // :memory: will open a separate database for each new connection. Make + // sure the sql package only uses a single connection. An alternative + // solution is to use "file::memory:?cache=shared". + db.(*SqliteDB).db.SetMaxOpenConns(1) + return db +} + +func createTempPostgresDB(t *testing.T) Database { + db := &PostgresDB{db: openTempPostgresDB(t)} + if err := db.upgrade(); err != nil { + t.Fatalf("failed to upgrade PostgreSQL database: %v", err) + } + + return db +} + +func createTestUser(t *testing.T, db Database) *User { + hashed, err := bcrypt.GenerateFromPassword([]byte(testPassword), bcrypt.DefaultCost) + if err != nil { + t.Fatalf("failed to generate bcrypt hash: %v", err) + } + + record := &User{Username: testUsername, Password: string(hashed)} + if err := db.StoreUser(context.Background(), record); err != nil { + t.Fatalf("failed to store test user: %v", err) + } + + return record +} + +func createTestDownstream(t *testing.T, srv *Server) ircConn { + c1, c2 := net.Pipe() + go srv.handle(newNetIRCConn(c1)) + return newNetIRCConn(c2) +} + +func createTestUpstream(t *testing.T, db Database, user *User) (*Network, net.Listener) { + ln, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("failed to create TCP listener: %v", err) + } + + network := &Network{ + Name: "testnet", + Addr: "irc://" + ln.Addr().String(), + Nick: user.Username, + Enabled: true, + } + if err := db.StoreNetwork(context.Background(), user.ID, network); err != nil { + t.Fatalf("failed to store test network: %v", err) + } + + return network, ln +} + +func mustAccept(t *testing.T, ln net.Listener) ircConn { + c, err := ln.Accept() + if err != nil { + t.Fatalf("failed accepting connection: %v", err) + } + return newNetIRCConn(c) +} + +func expectMessage(t *testing.T, c ircConn, cmd string) *irc.Message { + msg, err := c.ReadMessage() + if err != nil { + t.Fatalf("failed to read IRC message (want %q): %v", cmd, err) + } + if msg.Command != cmd { + t.Fatalf("invalid message received: want %q, got: %v", cmd, msg) + } + return msg +} + +func registerDownstreamConn(t *testing.T, c ircConn, network *Network) { + c.WriteMessage(&irc.Message{ + Command: "PASS", + Params: []string{testPassword}, + }) + c.WriteMessage(&irc.Message{ + Command: "NICK", + Params: []string{testUsername}, + }) + c.WriteMessage(&irc.Message{ + Command: "USER", + Params: []string{testUsername + "/" + network.Name, "0", "*", testUsername}, + }) + + expectMessage(t, c, irc.RPL_WELCOME) +} + +func registerUpstreamConn(t *testing.T, c ircConn) { + msg := expectMessage(t, c, "CAP") + if msg.Params[0] != "LS" { + t.Fatalf("invalid CAP LS: got: %v", msg) + } + msg = expectMessage(t, c, "NICK") + nick := msg.Params[0] + if nick != testUsername { + t.Fatalf("invalid NICK: want %q, got: %v", testUsername, msg) + } + expectMessage(t, c, "USER") + + c.WriteMessage(&irc.Message{ + Prefix: testServerPrefix, + Command: irc.RPL_WELCOME, + Params: []string{nick, "Welcome!"}, + }) + c.WriteMessage(&irc.Message{ + Prefix: testServerPrefix, + Command: irc.RPL_YOURHOST, + Params: []string{nick, "Your host is suika-test-server"}, + }) + c.WriteMessage(&irc.Message{ + Prefix: testServerPrefix, + Command: irc.RPL_CREATED, + Params: []string{nick, "Who cares when the server was created?"}, + }) + c.WriteMessage(&irc.Message{ + Prefix: testServerPrefix, + Command: irc.RPL_MYINFO, + Params: []string{nick, testServerPrefix.Name, "suika", "aiwroO", "OovaimnqpsrtklbeI"}, + }) + c.WriteMessage(&irc.Message{ + Prefix: testServerPrefix, + Command: irc.ERR_NOMOTD, + Params: []string{nick, "No MOTD"}, + }) +} + +func testServer(t *testing.T, db Database) { + user := createTestUser(t, db) + network, upstream := createTestUpstream(t, db, user) + defer upstream.Close() + + srv := NewServer(db) + if err := srv.Start(); err != nil { + t.Fatalf("failed to start server: %v", err) + } + defer srv.Shutdown() + + uc := mustAccept(t, upstream) + defer uc.Close() + registerUpstreamConn(t, uc) + + dc := createTestDownstream(t, srv) + defer dc.Close() + registerDownstreamConn(t, dc, network) + + noticeText := "This is a very important server notice." + uc.WriteMessage(&irc.Message{ + Prefix: testServerPrefix, + Command: "NOTICE", + Params: []string{testUsername, noticeText}, + }) + + var msg *irc.Message + for { + var err error + msg, err = dc.ReadMessage() + if err != nil { + t.Fatalf("failed to read IRC message: %v", err) + } + if msg.Command == "NOTICE" { + break + } + } + + if msg.Params[1] != noticeText { + t.Fatalf("invalid NOTICE text: want %q, got: %v", noticeText, msg) + } +} + +func TestServer(t *testing.T) { + t.Run("sqlite", func(t *testing.T) { + db := createTempSqliteDB(t) + testServer(t, db) + }) + + t.Run("postgres", func(t *testing.T) { + db := createTempPostgresDB(t) + testServer(t, db) + }) +} diff --git a/branches/origin/service.go b/branches/origin/service.go new file mode 100644 index 0000000..07226ef --- /dev/null +++ b/branches/origin/service.go @@ -0,0 +1,1150 @@ +package suika + +import ( + "context" + "crypto/sha1" + "crypto/sha256" + "crypto/sha512" + "encoding/hex" + "flag" + "fmt" + "io/ioutil" + "sort" + "strconv" + "strings" + "time" + "unicode" + + "golang.org/x/crypto/bcrypt" + "gopkg.in/irc.v3" +) + +const ( + serviceNick = "BouncerServ" + serviceNickCM = "bouncerserv" + serviceRealname = "suika bouncer service" +) + +// maxRSABits is the maximum number of RSA key bits used when generating a new +// private key. +const maxRSABits = 8192 + +var servicePrefix = &irc.Prefix{ + Name: serviceNick, + User: serviceNick, + Host: serviceNick, +} + +type serviceCommandSet map[string]*serviceCommand + +type serviceCommand struct { + usage string + desc string + handle func(ctx context.Context, dc *downstreamConn, params []string) error + children serviceCommandSet + admin bool +} + +func sendServiceNOTICE(dc *downstreamConn, text string) { + dc.SendMessage(&irc.Message{ + Prefix: servicePrefix, + Command: "NOTICE", + Params: []string{dc.nick, text}, + }) +} + +func sendServicePRIVMSG(dc *downstreamConn, text string) { + dc.SendMessage(&irc.Message{ + Prefix: servicePrefix, + Command: "PRIVMSG", + Params: []string{dc.nick, text}, + }) +} + +func splitWords(s string) ([]string, error) { + var words []string + var lastWord strings.Builder + escape := false + prev := ' ' + wordDelim := ' ' + + for _, r := range s { + if escape { + // last char was a backslash, write the byte as-is. + lastWord.WriteRune(r) + escape = false + } else if r == '\\' { + escape = true + } else if wordDelim == ' ' && unicode.IsSpace(r) { + // end of last word + if !unicode.IsSpace(prev) { + words = append(words, lastWord.String()) + lastWord.Reset() + } + } else if r == wordDelim { + // wordDelim is either " or ', switch back to + // space-delimited words. + wordDelim = ' ' + } else if r == '"' || r == '\'' { + if wordDelim == ' ' { + // start of (double-)quoted word + wordDelim = r + } else { + // either wordDelim is " and r is ' or vice-versa + lastWord.WriteRune(r) + } + } else { + lastWord.WriteRune(r) + } + + prev = r + } + + if !unicode.IsSpace(prev) { + words = append(words, lastWord.String()) + } + + if wordDelim != ' ' { + return nil, fmt.Errorf("unterminated quoted string") + } + if escape { + return nil, fmt.Errorf("unterminated backslash sequence") + } + + return words, nil +} + +func handleServicePRIVMSG(ctx context.Context, dc *downstreamConn, text string) { + words, err := splitWords(text) + if err != nil { + sendServicePRIVMSG(dc, fmt.Sprintf(`error: failed to parse command: %v`, err)) + return + } + + cmd, params, err := serviceCommands.Get(words) + if err != nil { + sendServicePRIVMSG(dc, fmt.Sprintf(`error: %v (type "help" for a list of commands)`, err)) + return + } + if cmd.admin && !dc.user.Admin { + sendServicePRIVMSG(dc, "error: you must be an admin to use this command") + return + } + + if cmd.handle == nil { + if len(cmd.children) > 0 { + var l []string + appendServiceCommandSetHelp(cmd.children, words, dc.user.Admin, &l) + sendServicePRIVMSG(dc, "available commands: "+strings.Join(l, ", ")) + } else { + // Pretend the command does not exist if it has neither children nor handler. + // This is obviously a bug but it is better to not die anyway. + dc.logger.Printf("command without handler and subcommands invoked:", words[0]) + sendServicePRIVMSG(dc, fmt.Sprintf("command %q not found", words[0])) + } + return + } + + if err := cmd.handle(ctx, dc, params); err != nil { + sendServicePRIVMSG(dc, fmt.Sprintf("error: %v", err)) + } +} + +func (cmds serviceCommandSet) Get(params []string) (*serviceCommand, []string, error) { + if len(params) == 0 { + return nil, nil, fmt.Errorf("no command specified") + } + + name := params[0] + params = params[1:] + + cmd, ok := cmds[name] + if !ok { + for k := range cmds { + if !strings.HasPrefix(k, name) { + continue + } + if cmd != nil { + return nil, params, fmt.Errorf("command %q is ambiguous", name) + } + cmd = cmds[k] + } + } + if cmd == nil { + return nil, params, fmt.Errorf("command %q not found", name) + } + + if len(params) == 0 || len(cmd.children) == 0 { + return cmd, params, nil + } + return cmd.children.Get(params) +} + +func (cmds serviceCommandSet) Names() []string { + l := make([]string, 0, len(cmds)) + for name := range cmds { + l = append(l, name) + } + sort.Strings(l) + return l +} + +var serviceCommands serviceCommandSet + +func init() { + serviceCommands = serviceCommandSet{ + "help": { + usage: "[command]", + desc: "print help message", + handle: handleServiceHelp, + }, + "network": { + children: serviceCommandSet{ + "create": { + usage: "-addr [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick] [-enabled enabled] [-connect-command command]...", + desc: "add a new network", + handle: handleServiceNetworkCreate, + }, + "status": { + desc: "show a list of saved networks and their current status", + handle: handleServiceNetworkStatus, + }, + "update": { + usage: "[name] [-addr addr] [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick] [-enabled enabled] [-connect-command command]...", + desc: "update a network", + handle: handleServiceNetworkUpdate, + }, + "delete": { + usage: "[name]", + desc: "delete a network", + handle: handleServiceNetworkDelete, + }, + "quote": { + usage: "[name] ", + desc: "send a raw line to a network", + handle: handleServiceNetworkQuote, + }, + }, + }, + "certfp": { + children: serviceCommandSet{ + "generate": { + usage: "[-key-type rsa|ecdsa|ed25519] [-bits N] [-network name]", + desc: "generate a new self-signed certificate, defaults to using RSA-3072 key", + handle: handleServiceCertFPGenerate, + }, + "fingerprint": { + usage: "[-network name]", + desc: "show fingerprints of certificate", + handle: handleServiceCertFPFingerprints, + }, + }, + }, + "sasl": { + children: serviceCommandSet{ + "status": { + usage: "[-network name]", + desc: "show SASL status", + handle: handleServiceSASLStatus, + }, + "set-plain": { + usage: "[-network name] ", + desc: "set SASL PLAIN credentials", + handle: handleServiceSASLSetPlain, + }, + "reset": { + usage: "[-network name]", + desc: "disable SASL authentication and remove stored credentials", + handle: handleServiceSASLReset, + }, + }, + }, + "user": { + children: serviceCommandSet{ + "create": { + usage: "-username -password [-realname ] [-admin]", + desc: "create a new suika user", + handle: handleUserCreate, + admin: true, + }, + "update": { + usage: "[-password ] [-realname ]", + desc: "update the current user", + handle: handleUserUpdate, + }, + "delete": { + usage: "", + desc: "delete a user", + handle: handleUserDelete, + admin: true, + }, + }, + }, + "channel": { + children: serviceCommandSet{ + "status": { + usage: "[-network name]", + desc: "show a list of saved channels and their current status", + handle: handleServiceChannelStatus, + }, + "update": { + usage: " [-relay-detached ] [-reattach-on ] [-detach-after ] [-detach-on ]", + desc: "update a channel", + handle: handleServiceChannelUpdate, + }, + }, + }, + "server": { + children: serviceCommandSet{ + "status": { + desc: "show server statistics", + handle: handleServiceServerStatus, + admin: true, + }, + "notice": { + desc: "broadcast a notice to all connected bouncer users", + handle: handleServiceServerNotice, + admin: true, + }, + }, + admin: true, + }, + } +} + +func appendServiceCommandSetHelp(cmds serviceCommandSet, prefix []string, admin bool, l *[]string) { + for _, name := range cmds.Names() { + cmd := cmds[name] + if cmd.admin && !admin { + continue + } + words := append(prefix, name) + if len(cmd.children) == 0 { + s := strings.Join(words, " ") + *l = append(*l, s) + } else { + appendServiceCommandSetHelp(cmd.children, words, admin, l) + } + } +} + +func handleServiceHelp(ctx context.Context, dc *downstreamConn, params []string) error { + if len(params) > 0 { + cmd, rest, err := serviceCommands.Get(params) + if err != nil { + return err + } + words := params[:len(params)-len(rest)] + + if len(cmd.children) > 0 { + var l []string + appendServiceCommandSetHelp(cmd.children, words, dc.user.Admin, &l) + sendServicePRIVMSG(dc, "available commands: "+strings.Join(l, ", ")) + } else { + text := strings.Join(words, " ") + if cmd.usage != "" { + text += " " + cmd.usage + } + text += ": " + cmd.desc + + sendServicePRIVMSG(dc, text) + } + } else { + var l []string + appendServiceCommandSetHelp(serviceCommands, nil, dc.user.Admin, &l) + sendServicePRIVMSG(dc, "available commands: "+strings.Join(l, ", ")) + } + return nil +} + +func newFlagSet() *flag.FlagSet { + fs := flag.NewFlagSet("", flag.ContinueOnError) + fs.SetOutput(ioutil.Discard) + return fs +} + +type stringSliceFlag []string + +func (v *stringSliceFlag) String() string { + return fmt.Sprint([]string(*v)) +} + +func (v *stringSliceFlag) Set(s string) error { + *v = append(*v, s) + return nil +} + +// stringPtrFlag is a flag value populating a string pointer. This allows to +// disambiguate between a flag that hasn't been set and a flag that has been +// set to an empty string. +type stringPtrFlag struct { + ptr **string +} + +func (f stringPtrFlag) String() string { + if f.ptr == nil || *f.ptr == nil { + return "" + } + return **f.ptr +} + +func (f stringPtrFlag) Set(s string) error { + *f.ptr = &s + return nil +} + +type boolPtrFlag struct { + ptr **bool +} + +func (f boolPtrFlag) String() string { + if f.ptr == nil || *f.ptr == nil { + return "" + } + return strconv.FormatBool(**f.ptr) +} + +func (f boolPtrFlag) Set(s string) error { + v, err := strconv.ParseBool(s) + if err != nil { + return err + } + *f.ptr = &v + return nil +} + +func getNetworkFromArg(dc *downstreamConn, params []string) (*network, []string, error) { + name, params := popArg(params) + if name == "" { + if dc.network == nil { + return nil, params, fmt.Errorf("no network selected, a name argument is required") + } + return dc.network, params, nil + } else { + net := dc.user.getNetwork(name) + if net == nil { + return nil, params, fmt.Errorf("unknown network %q", name) + } + return net, params, nil + } +} + +type networkFlagSet struct { + *flag.FlagSet + Addr, Name, Nick, Username, Pass, Realname *string + Enabled *bool + ConnectCommands []string +} + +func newNetworkFlagSet() *networkFlagSet { + fs := &networkFlagSet{FlagSet: newFlagSet()} + fs.Var(stringPtrFlag{&fs.Addr}, "addr", "") + fs.Var(stringPtrFlag{&fs.Name}, "name", "") + fs.Var(stringPtrFlag{&fs.Nick}, "nick", "") + fs.Var(stringPtrFlag{&fs.Username}, "username", "") + fs.Var(stringPtrFlag{&fs.Pass}, "pass", "") + fs.Var(stringPtrFlag{&fs.Realname}, "realname", "") + fs.Var(boolPtrFlag{&fs.Enabled}, "enabled", "") + fs.Var((*stringSliceFlag)(&fs.ConnectCommands), "connect-command", "") + return fs +} + +func (fs *networkFlagSet) update(network *Network) error { + if fs.Addr != nil { + if addrParts := strings.SplitN(*fs.Addr, "://", 2); len(addrParts) == 2 { + scheme := addrParts[0] + switch scheme { + case "ircs", "irc", "unix": + default: + return fmt.Errorf("unknown scheme %q (supported schemes: ircs, irc, unix)", scheme) + } + } + network.Addr = *fs.Addr + } + if fs.Name != nil { + network.Name = *fs.Name + } + if fs.Nick != nil { + network.Nick = *fs.Nick + } + if fs.Username != nil { + network.Username = *fs.Username + } + if fs.Pass != nil { + network.Pass = *fs.Pass + } + if fs.Realname != nil { + network.Realname = *fs.Realname + } + if fs.Enabled != nil { + network.Enabled = *fs.Enabled + } + if fs.ConnectCommands != nil { + if len(fs.ConnectCommands) == 1 && fs.ConnectCommands[0] == "" { + network.ConnectCommands = nil + } else { + for _, command := range fs.ConnectCommands { + _, err := irc.ParseMessage(command) + if err != nil { + return fmt.Errorf("flag -connect-command must be a valid raw irc command string: %q: %v", command, err) + } + } + network.ConnectCommands = fs.ConnectCommands + } + } + return nil +} + +func handleServiceNetworkCreate(ctx context.Context, dc *downstreamConn, params []string) error { + fs := newNetworkFlagSet() + if err := fs.Parse(params); err != nil { + return err + } + if fs.Addr == nil { + return fmt.Errorf("flag -addr is required") + } + + record := &Network{ + Addr: *fs.Addr, + Enabled: true, + } + if err := fs.update(record); err != nil { + return err + } + + network, err := dc.user.createNetwork(ctx, record) + if err != nil { + return fmt.Errorf("could not create network: %v", err) + } + + sendServicePRIVMSG(dc, fmt.Sprintf("created network %q", network.GetName())) + return nil +} + +func handleServiceNetworkStatus(ctx context.Context, dc *downstreamConn, params []string) error { + n := 0 + for _, net := range dc.user.networks { + var statuses []string + var details string + if uc := net.conn; uc != nil { + if dc.nick != uc.nick { + statuses = append(statuses, "connected as "+uc.nick) + } else { + statuses = append(statuses, "connected") + } + details = fmt.Sprintf("%v channels", uc.channels.Len()) + } else if !net.Enabled { + statuses = append(statuses, "disabled") + } else { + statuses = append(statuses, "disconnected") + if net.lastError != nil { + details = net.lastError.Error() + } + } + + if net == dc.network { + statuses = append(statuses, "current") + } + + name := net.GetName() + if name != net.Addr { + name = fmt.Sprintf("%v (%v)", name, net.Addr) + } + + s := fmt.Sprintf("%v [%v]", name, strings.Join(statuses, ", ")) + if details != "" { + s += ": " + details + } + sendServicePRIVMSG(dc, s) + + n++ + } + + if n == 0 { + sendServicePRIVMSG(dc, `No network configured, add one with "network create".`) + } + + return nil +} + +func handleServiceNetworkUpdate(ctx context.Context, dc *downstreamConn, params []string) error { + net, params, err := getNetworkFromArg(dc, params) + if err != nil { + return err + } + + fs := newNetworkFlagSet() + if err := fs.Parse(params); err != nil { + return err + } + + record := net.Network // copy network record because we'll mutate it + if err := fs.update(&record); err != nil { + return err + } + + network, err := dc.user.updateNetwork(ctx, &record) + if err != nil { + return fmt.Errorf("could not update network: %v", err) + } + + sendServicePRIVMSG(dc, fmt.Sprintf("updated network %q", network.GetName())) + return nil +} + +func handleServiceNetworkDelete(ctx context.Context, dc *downstreamConn, params []string) error { + net, params, err := getNetworkFromArg(dc, params) + if err != nil { + return err + } + + if err := dc.user.deleteNetwork(ctx, net.ID); err != nil { + return err + } + + sendServicePRIVMSG(dc, fmt.Sprintf("deleted network %q", net.GetName())) + return nil +} + +func handleServiceNetworkQuote(ctx context.Context, dc *downstreamConn, params []string) error { + if len(params) != 1 && len(params) != 2 { + return fmt.Errorf("expected one or two arguments") + } + + raw := params[len(params)-1] + params = params[:len(params)-1] + + net, params, err := getNetworkFromArg(dc, params) + if err != nil { + return err + } + + uc := net.conn + if uc == nil { + return fmt.Errorf("network %q is not currently connected", net.GetName()) + } + + m, err := irc.ParseMessage(raw) + if err != nil { + return fmt.Errorf("failed to parse command %q: %v", raw, err) + } + uc.SendMessage(ctx, m) + + sendServicePRIVMSG(dc, fmt.Sprintf("sent command to %q", net.GetName())) + return nil +} + +func sendCertfpFingerprints(dc *downstreamConn, cert []byte) { + sha1Sum := sha1.Sum(cert) + sendServicePRIVMSG(dc, "SHA-1 fingerprint: "+hex.EncodeToString(sha1Sum[:])) + sha256Sum := sha256.Sum256(cert) + sendServicePRIVMSG(dc, "SHA-256 fingerprint: "+hex.EncodeToString(sha256Sum[:])) + sha512Sum := sha512.Sum512(cert) + sendServicePRIVMSG(dc, "SHA-512 fingerprint: "+hex.EncodeToString(sha512Sum[:])) +} + +func getNetworkFromFlag(dc *downstreamConn, name string) (*network, error) { + if name == "" { + if dc.network == nil { + return nil, fmt.Errorf("no network selected, -network is required") + } + return dc.network, nil + } else { + net := dc.user.getNetwork(name) + if net == nil { + return nil, fmt.Errorf("unknown network %q", name) + } + return net, nil + } +} + +func handleServiceCertFPGenerate(ctx context.Context, dc *downstreamConn, params []string) error { + fs := newFlagSet() + netName := fs.String("network", "", "select a network") + keyType := fs.String("key-type", "rsa", "key type to generate (rsa, ecdsa, ed25519)") + bits := fs.Int("bits", 3072, "size of key to generate, meaningful only for RSA") + + if err := fs.Parse(params); err != nil { + return err + } + + if *bits <= 0 || *bits > maxRSABits { + return fmt.Errorf("invalid value for -bits") + } + + net, err := getNetworkFromFlag(dc, *netName) + if err != nil { + return err + } + + privKey, cert, err := generateCertFP(*keyType, *bits) + if err != nil { + return err + } + + net.SASL.External.CertBlob = cert + net.SASL.External.PrivKeyBlob = privKey + net.SASL.Mechanism = "EXTERNAL" + + if err := dc.srv.db.StoreNetwork(ctx, dc.user.ID, &net.Network); err != nil { + return err + } + + sendServicePRIVMSG(dc, "certificate generated") + sendCertfpFingerprints(dc, cert) + return nil +} + +func handleServiceCertFPFingerprints(ctx context.Context, dc *downstreamConn, params []string) error { + fs := newFlagSet() + netName := fs.String("network", "", "select a network") + + if err := fs.Parse(params); err != nil { + return err + } + + net, err := getNetworkFromFlag(dc, *netName) + if err != nil { + return err + } + + if net.SASL.Mechanism != "EXTERNAL" { + return fmt.Errorf("CertFP not set up") + } + + sendCertfpFingerprints(dc, net.SASL.External.CertBlob) + return nil +} + +func handleServiceSASLStatus(ctx context.Context, dc *downstreamConn, params []string) error { + fs := newFlagSet() + netName := fs.String("network", "", "select a network") + + if err := fs.Parse(params); err != nil { + return err + } + + net, err := getNetworkFromFlag(dc, *netName) + if err != nil { + return err + } + + switch net.SASL.Mechanism { + case "PLAIN": + sendServicePRIVMSG(dc, fmt.Sprintf("SASL PLAIN enabled with username %q", net.SASL.Plain.Username)) + case "EXTERNAL": + sendServicePRIVMSG(dc, "SASL EXTERNAL (CertFP) enabled") + case "": + sendServicePRIVMSG(dc, "SASL is disabled") + } + + if uc := net.conn; uc != nil { + if uc.account != "" { + sendServicePRIVMSG(dc, fmt.Sprintf("Authenticated on upstream network with account %q", uc.account)) + } else { + sendServicePRIVMSG(dc, "Unauthenticated on upstream network") + } + } else { + sendServicePRIVMSG(dc, "Disconnected from upstream network") + } + + return nil +} + +func handleServiceSASLSetPlain(ctx context.Context, dc *downstreamConn, params []string) error { + fs := newFlagSet() + netName := fs.String("network", "", "select a network") + + if err := fs.Parse(params); err != nil { + return err + } + + if len(fs.Args()) != 2 { + return fmt.Errorf("expected exactly 2 arguments") + } + + net, err := getNetworkFromFlag(dc, *netName) + if err != nil { + return err + } + + net.SASL.Plain.Username = fs.Arg(0) + net.SASL.Plain.Password = fs.Arg(1) + net.SASL.Mechanism = "PLAIN" + + if err := dc.srv.db.StoreNetwork(ctx, dc.user.ID, &net.Network); err != nil { + return err + } + + sendServicePRIVMSG(dc, "credentials saved") + return nil +} + +func handleServiceSASLReset(ctx context.Context, dc *downstreamConn, params []string) error { + fs := newFlagSet() + netName := fs.String("network", "", "select a network") + + if err := fs.Parse(params); err != nil { + return err + } + + net, err := getNetworkFromFlag(dc, *netName) + if err != nil { + return err + } + + net.SASL.Plain.Username = "" + net.SASL.Plain.Password = "" + net.SASL.External.CertBlob = nil + net.SASL.External.PrivKeyBlob = nil + net.SASL.Mechanism = "" + + if err := dc.srv.db.StoreNetwork(ctx, dc.user.ID, &net.Network); err != nil { + return err + } + + sendServicePRIVMSG(dc, "credentials reset") + return nil +} + +func handleUserCreate(ctx context.Context, dc *downstreamConn, params []string) error { + fs := newFlagSet() + username := fs.String("username", "", "") + password := fs.String("password", "", "") + realname := fs.String("realname", "", "") + admin := fs.Bool("admin", false, "") + + if err := fs.Parse(params); err != nil { + return err + } + if *username == "" { + return fmt.Errorf("flag -username is required") + } + if *password == "" { + return fmt.Errorf("flag -password is required") + } + + hashed, err := bcrypt.GenerateFromPassword([]byte(*password), bcrypt.DefaultCost) + if err != nil { + return fmt.Errorf("failed to hash password: %v", err) + } + + user := &User{ + Username: *username, + Password: string(hashed), + Realname: *realname, + Admin: *admin, + } + if _, err := dc.srv.createUser(ctx, user); err != nil { + return fmt.Errorf("could not create user: %v", err) + } + + sendServicePRIVMSG(dc, fmt.Sprintf("created user %q", *username)) + return nil +} + +func popArg(params []string) (string, []string) { + if len(params) > 0 && !strings.HasPrefix(params[0], "-") { + return params[0], params[1:] + } + return "", params +} + +func handleUserUpdate(ctx context.Context, dc *downstreamConn, params []string) error { + var password, realname *string + var admin *bool + fs := newFlagSet() + fs.Var(stringPtrFlag{&password}, "password", "") + fs.Var(stringPtrFlag{&realname}, "realname", "") + fs.Var(boolPtrFlag{&admin}, "admin", "") + + username, params := popArg(params) + if err := fs.Parse(params); err != nil { + return err + } + if len(fs.Args()) > 0 { + return fmt.Errorf("unexpected argument") + } + + var hashed *string + if password != nil { + hashedBytes, err := bcrypt.GenerateFromPassword([]byte(*password), bcrypt.DefaultCost) + if err != nil { + return fmt.Errorf("failed to hash password: %v", err) + } + hashedStr := string(hashedBytes) + hashed = &hashedStr + } + + if username != "" && username != dc.user.Username { + if !dc.user.Admin { + return fmt.Errorf("you must be an admin to update other users") + } + if realname != nil { + return fmt.Errorf("cannot update -realname of other user") + } + + u := dc.srv.getUser(username) + if u == nil { + return fmt.Errorf("unknown username %q", username) + } + + done := make(chan error, 1) + event := eventUserUpdate{ + password: hashed, + admin: admin, + done: done, + } + select { + case <-ctx.Done(): + return ctx.Err() + case u.events <- event: + } + // TODO: send context to the other side + if err := <-done; err != nil { + return err + } + + sendServicePRIVMSG(dc, fmt.Sprintf("updated user %q", username)) + } else { + // copy the user record because we'll mutate it + record := dc.user.User + + if hashed != nil { + record.Password = *hashed + } + if realname != nil { + record.Realname = *realname + } + if admin != nil { + return fmt.Errorf("cannot update -admin of own user") + } + + if err := dc.user.updateUser(ctx, &record); err != nil { + return err + } + + sendServicePRIVMSG(dc, fmt.Sprintf("updated user %q", dc.user.Username)) + } + + return nil +} + +func handleUserDelete(ctx context.Context, dc *downstreamConn, params []string) error { + if len(params) != 1 { + return fmt.Errorf("expected exactly one argument") + } + username := params[0] + + u := dc.srv.getUser(username) + if u == nil { + return fmt.Errorf("unknown username %q", username) + } + + u.stop() + + if err := dc.srv.db.DeleteUser(ctx, u.ID); err != nil { + return fmt.Errorf("failed to delete user: %v", err) + } + + sendServicePRIVMSG(dc, fmt.Sprintf("deleted user %q", username)) + return nil +} + +func handleServiceChannelStatus(ctx context.Context, dc *downstreamConn, params []string) error { + var defaultNetworkName string + if dc.network != nil { + defaultNetworkName = dc.network.GetName() + } + + fs := newFlagSet() + networkName := fs.String("network", defaultNetworkName, "") + + if err := fs.Parse(params); err != nil { + return err + } + + n := 0 + + sendNetwork := func(net *network) { + var channels []*Channel + for _, entry := range net.channels.innerMap { + channels = append(channels, entry.value.(*Channel)) + } + + sort.Slice(channels, func(i, j int) bool { + return strings.ReplaceAll(channels[i].Name, "#", "") < + strings.ReplaceAll(channels[j].Name, "#", "") + }) + + for _, ch := range channels { + var uch *upstreamChannel + if net.conn != nil { + uch = net.conn.channels.Value(ch.Name) + } + + name := ch.Name + if *networkName == "" { + name += "/" + net.GetName() + } + + var status string + if uch != nil { + status = "joined" + } else if net.conn != nil { + status = "parted" + } else { + status = "disconnected" + } + + if ch.Detached { + status += ", detached" + } + + s := fmt.Sprintf("%v [%v]", name, status) + sendServicePRIVMSG(dc, s) + + n++ + } + } + + if *networkName == "" { + for _, net := range dc.user.networks { + sendNetwork(net) + } + } else { + net := dc.user.getNetwork(*networkName) + if net == nil { + return fmt.Errorf("unknown network %q", *networkName) + } + sendNetwork(net) + } + + if n == 0 { + sendServicePRIVMSG(dc, "No channel configured.") + } + + return nil +} + +type channelFlagSet struct { + *flag.FlagSet + RelayDetached, ReattachOn, DetachAfter, DetachOn *string +} + +func newChannelFlagSet() *channelFlagSet { + fs := &channelFlagSet{FlagSet: newFlagSet()} + fs.Var(stringPtrFlag{&fs.RelayDetached}, "relay-detached", "") + fs.Var(stringPtrFlag{&fs.ReattachOn}, "reattach-on", "") + fs.Var(stringPtrFlag{&fs.DetachAfter}, "detach-after", "") + fs.Var(stringPtrFlag{&fs.DetachOn}, "detach-on", "") + return fs +} + +func (fs *channelFlagSet) update(channel *Channel) error { + if fs.RelayDetached != nil { + filter, err := parseFilter(*fs.RelayDetached) + if err != nil { + return err + } + channel.RelayDetached = filter + } + if fs.ReattachOn != nil { + filter, err := parseFilter(*fs.ReattachOn) + if err != nil { + return err + } + channel.ReattachOn = filter + } + if fs.DetachAfter != nil { + dur, err := time.ParseDuration(*fs.DetachAfter) + if err != nil || dur < 0 { + return fmt.Errorf("unknown duration for -detach-after %q (duration format: 0, 300s, 22h30m, ...)", *fs.DetachAfter) + } + channel.DetachAfter = dur + } + if fs.DetachOn != nil { + filter, err := parseFilter(*fs.DetachOn) + if err != nil { + return err + } + channel.DetachOn = filter + } + return nil +} + +func handleServiceChannelUpdate(ctx context.Context, dc *downstreamConn, params []string) error { + if len(params) < 1 { + return fmt.Errorf("expected at least one argument") + } + name := params[0] + + fs := newChannelFlagSet() + if err := fs.Parse(params[1:]); err != nil { + return err + } + + uc, upstreamName, err := dc.unmarshalEntity(name) + if err != nil { + return fmt.Errorf("unknown channel %q", name) + } + + ch := uc.network.channels.Value(upstreamName) + if ch == nil { + return fmt.Errorf("unknown channel %q", name) + } + + if err := fs.update(ch); err != nil { + return err + } + + uc.updateChannelAutoDetach(upstreamName) + + if err := dc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil { + return fmt.Errorf("failed to update channel: %v", err) + } + + sendServicePRIVMSG(dc, fmt.Sprintf("updated channel %q", name)) + return nil +} +func handleServiceServerStatus(ctx context.Context, dc *downstreamConn, params []string) error { + dbStats, err := dc.user.srv.db.Stats(ctx) + if err != nil { + return err + } + serverStats := dc.user.srv.Stats() + sendServicePRIVMSG(dc, fmt.Sprintf("%v/%v users, %v downstreams, %v upstreams, %v networks, %v channels", serverStats.Users, dbStats.Users, serverStats.Downstreams, serverStats.Upstreams, dbStats.Networks, dbStats.Channels)) + return nil +} + +func handleServiceServerNotice(ctx context.Context, dc *downstreamConn, params []string) error { + if len(params) != 1 { + return fmt.Errorf("expected exactly one argument") + } + text := params[0] + + dc.logger.Printf("broadcasting bouncer-wide NOTICE: %v", text) + + broadcastMsg := &irc.Message{ + Prefix: servicePrefix, + Command: "NOTICE", + Params: []string{"$" + dc.srv.Config().Hostname, text}, + } + var err error + sent := 0 + total := 0 + dc.srv.forEachUser(func(u *user) { + total++ + select { + case <-ctx.Done(): + err = ctx.Err() + case u.events <- eventBroadcast{broadcastMsg}: + sent++ + } + }) + + dc.logger.Printf("broadcast bouncer-wide NOTICE to %v/%v downstreams", sent, total) + sendServicePRIVMSG(dc, fmt.Sprintf("sent to %v/%v downstream connections", sent, total)) + + return err +} diff --git a/branches/origin/service_test.go b/branches/origin/service_test.go new file mode 100644 index 0000000..1f3fe6b --- /dev/null +++ b/branches/origin/service_test.go @@ -0,0 +1,54 @@ +package suika + +import ( + "testing" +) + +func assertSplit(t *testing.T, input string, expected []string) { + actual, err := splitWords(input) + if err != nil { + t.Errorf("%q: %v", input, err) + return + } + if len(actual) != len(expected) { + t.Errorf("%q: expected %d words, got %d\nexpected: %v\ngot: %v", input, len(expected), len(actual), expected, actual) + return + } + for i := 0; i < len(actual); i++ { + if actual[i] != expected[i] { + t.Errorf("%q: expected word #%d to be %q, got %q\nexpected: %v\ngot: %v", input, i, expected[i], actual[i], expected, actual) + } + } +} + +func TestSplit(t *testing.T) { + assertSplit(t, " ch 'up' #suika 'relay'-det\"ache\"d message ", []string{ + "ch", + "up", + "#suika", + "relay-detached", + "message", + }) + assertSplit(t, "net update \\\"free\\\"node -pass 'political \"stance\" desu!' -realname '' -nick lee", []string{ + "net", + "update", + "\"free\"node", + "-pass", + "political \"stance\" desu!", + "-realname", + "", + "-nick", + "lee", + }) + assertSplit(t, "Omedeto,\\ Yui! ''", []string{ + "Omedeto, Yui!", + "", + }) + + if _, err := splitWords("end of 'file"); err == nil { + t.Errorf("expected error on unterminated single quote") + } + if _, err := splitWords("end of backquote \\"); err == nil { + t.Errorf("expected error on unterminated backquote sequence") + } +} diff --git a/branches/origin/suika_psql_schema.sql b/branches/origin/suika_psql_schema.sql new file mode 100644 index 0000000..d148b4d --- /dev/null +++ b/branches/origin/suika_psql_schema.sql @@ -0,0 +1,58 @@ +CREATE TABLE IF NOT EXISTS "User" ( + id SERIAL PRIMARY KEY, + username VARCHAR(255) NOT NULL UNIQUE, + password VARCHAR(255), + admin BOOLEAN NOT NULL DEFAULT FALSE, + realname VARCHAR(255) +); + +CREATE TYPE sasl_mechanism AS ENUM ('PLAIN', 'EXTERNAL'); + +CREATE TABLE IF NOT EXISTS "Network" ( + id SERIAL PRIMARY KEY, + name VARCHAR(255), + "user" INTEGER NOT NULL REFERENCES "User"(id) ON DELETE CASCADE, + addr VARCHAR(255) NOT NULL, + nick VARCHAR(255), + username VARCHAR(255), + realname VARCHAR(255), + pass VARCHAR(255), + connect_commands VARCHAR(1023), + sasl_mechanism sasl_mechanism, + sasl_plain_username VARCHAR(255), + sasl_plain_password VARCHAR(255), + sasl_external_cert BYTEA, + sasl_external_key BYTEA, + enabled BOOLEAN NOT NULL DEFAULT TRUE, + UNIQUE("user", addr, nick), + UNIQUE("user", name) +); +CREATE TABLE IF NOT EXISTS "Channel" ( + id SERIAL PRIMARY KEY, + network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE, + name VARCHAR(255) NOT NULL, + key VARCHAR(255), + detached BOOLEAN NOT NULL DEFAULT FALSE, + detached_internal_msgid VARCHAR(255), + relay_detached INTEGER NOT NULL DEFAULT 0, + reattach_on INTEGER NOT NULL DEFAULT 0, + detach_after INTEGER NOT NULL DEFAULT 0, + detach_on INTEGER NOT NULL DEFAULT 0, + UNIQUE(network, name) +); +CREATE TABLE IF NOT EXISTS "DeliveryReceipt" ( + id SERIAL PRIMARY KEY, + network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE, + target VARCHAR(255) NOT NULL, + client VARCHAR(255) NOT NULL DEFAULT '', + internal_msgid VARCHAR(255) NOT NULL, + UNIQUE(network, target, client) +); +CREATE TABLE IF NOT EXISTS "ReadReceipt" ( + id SERIAL PRIMARY KEY, + network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE, + target VARCHAR(255) NOT NULL, + timestamp TIMESTAMP WITH TIME ZONE NOT NULL, + UNIQUE(network, target) +); + diff --git a/branches/origin/suika_sqlite_schema.sql b/branches/origin/suika_sqlite_schema.sql new file mode 100644 index 0000000..517d06f --- /dev/null +++ b/branches/origin/suika_sqlite_schema.sql @@ -0,0 +1,61 @@ +CREATE TABLE IF NOT EXISTS User ( + id INTEGER PRIMARY KEY, + username TEXT NOT NULL UNIQUE, + password TEXT, + admin INTEGER NOT NULL DEFAULT 0, + realname TEXT +); +CREATE TABLE IF NOT EXISTS Network ( + id INTEGER PRIMARY KEY, + name TEXT, + user INTEGER NOT NULL, + addr TEXT NOT NULL, + nick TEXT, + username TEXT, + realname TEXT, + pass TEXT, + connect_commands TEXT, + sasl_mechanism TEXT, + sasl_plain_username TEXT, + sasl_plain_password TEXT, + sasl_external_cert BLOB, + sasl_external_key BLOB, + enabled INTEGER NOT NULL DEFAULT 1, + FOREIGN KEY(user) REFERENCES User(id), + UNIQUE(user, addr, nick), + UNIQUE(user, name) +); +CREATE TABLE IF NOT EXISTS Channel ( + id INTEGER PRIMARY KEY, + network INTEGER NOT NULL, + name TEXT NOT NULL, + key TEXT, + detached INTEGER NOT NULL DEFAULT 0, + detached_internal_msgid TEXT, + relay_detached INTEGER NOT NULL DEFAULT 0, + reattach_on INTEGER NOT NULL DEFAULT 0, + detach_after INTEGER NOT NULL DEFAULT 0, + detach_on INTEGER NOT NULL DEFAULT 0, + FOREIGN KEY(network) REFERENCES Network(id), + UNIQUE(network, name) +); + +CREATE TABLE IF NOT EXISTS DeliveryReceipt ( + id INTEGER PRIMARY KEY, + network INTEGER NOT NULL, + target TEXT NOT NULL, + client TEXT, + internal_msgid TEXT NOT NULL, + FOREIGN KEY(network) REFERENCES Network(id), + UNIQUE(network, target, client) +); + +CREATE TABLE IF NOT EXISTS ReadReceipt ( + id INTEGER PRIMARY KEY, + network INTEGER NOT NULL, + target TEXT NOT NULL, + timestamp TEXT NOT NULL, + FOREIGN KEY(network) REFERENCES Network(id), + UNIQUE(network, target) +); + diff --git a/branches/origin/upstream.go b/branches/origin/upstream.go new file mode 100644 index 0000000..285dca4 --- /dev/null +++ b/branches/origin/upstream.go @@ -0,0 +1,2182 @@ +package suika + +import ( + "context" + "crypto" + "crypto/sha256" + "crypto/tls" + "crypto/x509" + "encoding/base64" + "errors" + "fmt" + "io" + "net" + "strconv" + "strings" + "time" + + "github.com/emersion/go-sasl" + "gopkg.in/irc.v3" +) + +// permanentUpstreamCaps is the static list of upstream capabilities always +// requested when supported. +var permanentUpstreamCaps = map[string]bool{ + "account-notify": true, + "account-tag": true, + "away-notify": true, + "batch": true, + "extended-join": true, + "invite-notify": true, + "labeled-response": true, + "message-tags": true, + "multi-prefix": true, + "sasl": true, + "server-time": true, + "setname": true, + + "draft/account-registration": true, + "draft/extended-monitor": true, +} + +type registrationError struct { + *irc.Message +} + +func (err registrationError) Error() string { + return fmt.Sprintf("registration error (%v): %v", err.Command, err.Reason()) +} + +func (err registrationError) Reason() string { + if len(err.Params) > 0 { + return err.Params[len(err.Params)-1] + } + return err.Command +} + +func (err registrationError) Temporary() bool { + // Only return false if we're 100% sure that fixing the error requires a + // network configuration change + switch err.Command { + case irc.ERR_PASSWDMISMATCH, irc.ERR_ERRONEUSNICKNAME: + return false + case "FAIL": + return err.Params[1] != "ACCOUNT_REQUIRED" + default: + return true + } +} + +type upstreamChannel struct { + Name string + conn *upstreamConn + Topic string + TopicWho *irc.Prefix + TopicTime time.Time + Status channelStatus + modes channelModes + creationTime string + Members membershipsCasemapMap + complete bool + detachTimer *time.Timer +} + +func (uc *upstreamChannel) updateAutoDetach(dur time.Duration) { + if uc.detachTimer != nil { + uc.detachTimer.Stop() + uc.detachTimer = nil + } + + if dur == 0 { + return + } + + uc.detachTimer = time.AfterFunc(dur, func() { + uc.conn.network.user.events <- eventChannelDetach{ + uc: uc.conn, + name: uc.Name, + } + }) +} + +type pendingUpstreamCommand struct { + downstreamID uint64 + msg *irc.Message +} + +type upstreamConn struct { + conn + + network *network + user *user + + serverName string + availableUserModes string + availableChannelModes map[byte]channelModeType + availableChannelTypes string + availableMemberships []membership + isupport map[string]*string + + registered bool + nick string + nickCM string + username string + realname string + modes userModes + channels upstreamChannelCasemapMap + supportedCaps map[string]string + caps map[string]bool + batches map[string]batch + away bool + account string + nextLabelID uint64 + monitored monitorCasemapMap + + saslClient sasl.Client + saslStarted bool + + casemapIsSet bool + + // Queue of commands in progress, indexed by type. The first entry has been + // sent to the server and is awaiting reply. The following entries have not + // been sent yet. + pendingCmds map[string][]pendingUpstreamCommand + + gotMotd bool +} + +func connectToUpstream(ctx context.Context, network *network) (*upstreamConn, error) { + logger := &prefixLogger{network.user.logger, fmt.Sprintf("upstream %q: ", network.GetName())} + + dialer := net.Dialer{Timeout: connectTimeout} + + u, err := network.URL() + if err != nil { + return nil, err + } + + var netConn net.Conn + switch u.Scheme { + case "ircs": + addr := u.Host + host, _, err := net.SplitHostPort(u.Host) + if err != nil { + host = u.Host + addr = u.Host + ":6697" + } + + dialer.LocalAddr, err = network.user.localTCPAddrForHost(ctx, host) + if err != nil { + return nil, fmt.Errorf("failed to pick local IP for remote host %q: %v", host, err) + } + + logger.Printf("connecting to TLS server at address %q", addr) + + tlsConfig := &tls.Config{ServerName: host, NextProtos: []string{"irc"}} + if network.SASL.Mechanism == "EXTERNAL" { + if network.SASL.External.CertBlob == nil { + return nil, fmt.Errorf("missing certificate for authentication") + } + if network.SASL.External.PrivKeyBlob == nil { + return nil, fmt.Errorf("missing private key for authentication") + } + key, err := x509.ParsePKCS8PrivateKey(network.SASL.External.PrivKeyBlob) + if err != nil { + return nil, fmt.Errorf("failed to parse private key: %v", err) + } + tlsConfig.Certificates = []tls.Certificate{ + { + Certificate: [][]byte{network.SASL.External.CertBlob}, + PrivateKey: key.(crypto.PrivateKey), + }, + } + logger.Printf("using TLS client certificate %x", sha256.Sum256(network.SASL.External.CertBlob)) + } + + netConn, err = dialer.DialContext(ctx, "tcp", addr) + if err != nil { + return nil, fmt.Errorf("failed to dial %q: %v", addr, err) + } + + // Don't do the TLS handshake immediately, because we need to register + // the new connection with identd ASAP. + netConn = tls.Client(netConn, tlsConfig) + case "irc": + addr := u.Host + host, _, err := net.SplitHostPort(addr) + if err != nil { + host = u.Host + addr = u.Host + ":6667" + } + + dialer.LocalAddr, err = network.user.localTCPAddrForHost(ctx, host) + if err != nil { + return nil, fmt.Errorf("failed to pick local IP for remote host %q: %v", host, err) + } + + logger.Printf("connecting to plain-text server at address %q", addr) + netConn, err = dialer.DialContext(ctx, "tcp", addr) + if err != nil { + return nil, fmt.Errorf("failed to dial %q: %v", addr, err) + } + case "irc+unix", "unix": + logger.Printf("connecting to Unix socket at path %q", u.Path) + netConn, err = dialer.DialContext(ctx, "unix", u.Path) + if err != nil { + return nil, fmt.Errorf("failed to connect to Unix socket %q: %v", u.Path, err) + } + default: + return nil, fmt.Errorf("failed to dial %q: unknown scheme: %v", network.Addr, u.Scheme) + } + + options := connOptions{ + Logger: logger, + RateLimitDelay: upstreamMessageDelay, + RateLimitBurst: upstreamMessageBurst, + } + + uc := &upstreamConn{ + conn: *newConn(network.user.srv, newNetIRCConn(netConn), &options), + network: network, + user: network.user, + channels: upstreamChannelCasemapMap{newCasemapMap(0)}, + supportedCaps: make(map[string]string), + caps: make(map[string]bool), + batches: make(map[string]batch), + availableChannelTypes: stdChannelTypes, + availableChannelModes: stdChannelModes, + availableMemberships: stdMemberships, + isupport: make(map[string]*string), + pendingCmds: make(map[string][]pendingUpstreamCommand), + monitored: monitorCasemapMap{newCasemapMap(0)}, + } + return uc, nil +} + +func (uc *upstreamConn) forEachDownstream(f func(*downstreamConn)) { + uc.network.forEachDownstream(f) +} + +func (uc *upstreamConn) forEachDownstreamByID(id uint64, f func(*downstreamConn)) { + uc.forEachDownstream(func(dc *downstreamConn) { + if id != 0 && id != dc.id { + return + } + f(dc) + }) +} + +func (uc *upstreamConn) downstreamByID(id uint64) *downstreamConn { + for _, dc := range uc.user.downstreamConns { + if dc.id == id { + return dc + } + } + return nil +} + +func (uc *upstreamConn) getChannel(name string) (*upstreamChannel, error) { + ch := uc.channels.Value(name) + if ch == nil { + return nil, fmt.Errorf("unknown channel %q", name) + } + return ch, nil +} + +func (uc *upstreamConn) isChannel(entity string) bool { + return strings.ContainsRune(uc.availableChannelTypes, rune(entity[0])) +} + +func (uc *upstreamConn) isOurNick(nick string) bool { + return uc.nickCM == uc.network.casemap(nick) +} + +func (uc *upstreamConn) abortPendingCommands() { + for _, l := range uc.pendingCmds { + for _, pendingCmd := range l { + dc := uc.downstreamByID(pendingCmd.downstreamID) + if dc == nil { + continue + } + + switch pendingCmd.msg.Command { + case "LIST": + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_LISTEND, + Params: []string{dc.nick, "Command aborted"}, + }) + case "WHO": + mask := "*" + if len(pendingCmd.msg.Params) > 0 { + mask = pendingCmd.msg.Params[0] + } + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_ENDOFWHO, + Params: []string{dc.nick, mask, "Command aborted"}, + }) + case "AUTHENTICATE": + dc.endSASL(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_SASLABORTED, + Params: []string{dc.nick, "SASL authentication aborted"}, + }) + case "REGISTER", "VERIFY": + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "FAIL", + Params: []string{pendingCmd.msg.Command, "TEMPORARILY_UNAVAILABLE", pendingCmd.msg.Params[0], "Command aborted"}, + }) + default: + panic(fmt.Errorf("Unsupported pending command %q", pendingCmd.msg.Command)) + } + } + } + + uc.pendingCmds = make(map[string][]pendingUpstreamCommand) +} + +func (uc *upstreamConn) sendNextPendingCommand(cmd string) { + if len(uc.pendingCmds[cmd]) == 0 { + return + } + uc.SendMessage(context.TODO(), uc.pendingCmds[cmd][0].msg) +} + +func (uc *upstreamConn) enqueueCommand(dc *downstreamConn, msg *irc.Message) { + switch msg.Command { + case "LIST", "WHO", "AUTHENTICATE", "REGISTER", "VERIFY": + // Supported + default: + panic(fmt.Errorf("Unsupported pending command %q", msg.Command)) + } + + uc.pendingCmds[msg.Command] = append(uc.pendingCmds[msg.Command], pendingUpstreamCommand{ + downstreamID: dc.id, + msg: msg, + }) + + if len(uc.pendingCmds[msg.Command]) == 1 { + uc.sendNextPendingCommand(msg.Command) + } +} + +func (uc *upstreamConn) currentPendingCommand(cmd string) (*downstreamConn, *irc.Message) { + if len(uc.pendingCmds[cmd]) == 0 { + return nil, nil + } + + pendingCmd := uc.pendingCmds[cmd][0] + return uc.downstreamByID(pendingCmd.downstreamID), pendingCmd.msg +} + +func (uc *upstreamConn) dequeueCommand(cmd string) (*downstreamConn, *irc.Message) { + dc, msg := uc.currentPendingCommand(cmd) + + if len(uc.pendingCmds[cmd]) > 0 { + copy(uc.pendingCmds[cmd], uc.pendingCmds[cmd][1:]) + uc.pendingCmds[cmd] = uc.pendingCmds[cmd][:len(uc.pendingCmds[cmd])-1] + } + + uc.sendNextPendingCommand(cmd) + + return dc, msg +} + +func (uc *upstreamConn) cancelPendingCommandsByDownstreamID(downstreamID uint64) { + for cmd := range uc.pendingCmds { + // We can't cancel the currently running command stored in + // uc.pendingCmds[cmd][0] + for i := len(uc.pendingCmds[cmd]) - 1; i >= 1; i-- { + if uc.pendingCmds[cmd][i].downstreamID == downstreamID { + uc.pendingCmds[cmd] = append(uc.pendingCmds[cmd][:i], uc.pendingCmds[cmd][i+1:]...) + } + } + } +} + +func (uc *upstreamConn) parseMembershipPrefix(s string) (ms *memberships, nick string) { + memberships := make(memberships, 0, 4) + i := 0 + for _, m := range uc.availableMemberships { + if i >= len(s) { + break + } + if s[i] == m.Prefix { + memberships = append(memberships, m) + i++ + } + } + return &memberships, s[i:] +} + +func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) error { + var label string + if l, ok := msg.GetTag("label"); ok { + label = l + delete(msg.Tags, "label") + } + + var msgBatch *batch + if batchName, ok := msg.GetTag("batch"); ok { + b, ok := uc.batches[batchName] + if !ok { + return fmt.Errorf("unexpected batch reference: batch was not defined: %q", batchName) + } + msgBatch = &b + if label == "" { + label = msgBatch.Label + } + delete(msg.Tags, "batch") + } + + var downstreamID uint64 = 0 + if label != "" { + var labelOffset uint64 + n, err := fmt.Sscanf(label, "sd-%d-%d", &downstreamID, &labelOffset) + if err == nil && n < 2 { + err = errors.New("not enough arguments") + } + if err != nil { + return fmt.Errorf("unexpected message label: invalid downstream reference for label %q: %v", label, err) + } + } + + if _, ok := msg.Tags["time"]; !ok { + msg.Tags["time"] = irc.TagValue(formatServerTime(time.Now())) + } + + switch msg.Command { + case "PING": + uc.SendMessage(ctx, &irc.Message{ + Command: "PONG", + Params: msg.Params, + }) + return nil + case "NOTICE", "PRIVMSG", "TAGMSG": + if msg.Prefix == nil { + return fmt.Errorf("expected a prefix") + } + + var entity, text string + if msg.Command != "TAGMSG" { + if err := parseMessageParams(msg, &entity, &text); err != nil { + return err + } + } else { + if err := parseMessageParams(msg, &entity); err != nil { + return err + } + } + + if msg.Prefix.Name == serviceNick { + uc.logger.Printf("skipping %v from suika's service: %v", msg.Command, msg) + break + } + if entity == serviceNick { + uc.logger.Printf("skipping %v to suika's service: %v", msg.Command, msg) + break + } + + if msg.Prefix.User == "" && msg.Prefix.Host == "" { // server message + uc.produce("", msg, nil) + } else { // regular user message + target := entity + if uc.isOurNick(target) { + target = msg.Prefix.Name + } + + ch := uc.network.channels.Value(target) + if ch != nil && msg.Command != "TAGMSG" { + if ch.Detached { + uc.handleDetachedMessage(ctx, ch, msg) + } + + highlight := uc.network.isHighlight(msg) + if ch.DetachOn == FilterMessage || ch.DetachOn == FilterDefault || (ch.DetachOn == FilterHighlight && highlight) { + uc.updateChannelAutoDetach(target) + } + } + + uc.produce(target, msg, nil) + } + case "CAP": + var subCmd string + if err := parseMessageParams(msg, nil, &subCmd); err != nil { + return err + } + subCmd = strings.ToUpper(subCmd) + subParams := msg.Params[2:] + switch subCmd { + case "LS": + if len(subParams) < 1 { + return newNeedMoreParamsError(msg.Command) + } + caps := subParams[len(subParams)-1] + more := len(subParams) >= 2 && msg.Params[len(subParams)-2] == "*" + + uc.handleSupportedCaps(caps) + + if more { + break // wait to receive all capabilities + } + + uc.requestCaps() + + if uc.requestSASL() { + break // we'll send CAP END after authentication is completed + } + + uc.SendMessage(ctx, &irc.Message{ + Command: "CAP", + Params: []string{"END"}, + }) + case "ACK", "NAK": + if len(subParams) < 1 { + return newNeedMoreParamsError(msg.Command) + } + caps := strings.Fields(subParams[0]) + + for _, name := range caps { + if err := uc.handleCapAck(ctx, strings.ToLower(name), subCmd == "ACK"); err != nil { + return err + } + } + + if uc.registered { + uc.forEachDownstream(func(dc *downstreamConn) { + dc.updateSupportedCaps() + }) + } + case "NEW": + if len(subParams) < 1 { + return newNeedMoreParamsError(msg.Command) + } + uc.handleSupportedCaps(subParams[0]) + uc.requestCaps() + case "DEL": + if len(subParams) < 1 { + return newNeedMoreParamsError(msg.Command) + } + caps := strings.Fields(subParams[0]) + + for _, c := range caps { + delete(uc.supportedCaps, c) + delete(uc.caps, c) + } + + if uc.registered { + uc.forEachDownstream(func(dc *downstreamConn) { + dc.updateSupportedCaps() + }) + } + default: + uc.logger.Printf("unhandled message: %v", msg) + } + case "AUTHENTICATE": + if uc.saslClient == nil { + return fmt.Errorf("received unexpected AUTHENTICATE message") + } + + // TODO: if a challenge is 400 bytes long, buffer it + var challengeStr string + if err := parseMessageParams(msg, &challengeStr); err != nil { + uc.SendMessage(ctx, &irc.Message{ + Command: "AUTHENTICATE", + Params: []string{"*"}, + }) + return err + } + + var challenge []byte + if challengeStr != "+" { + var err error + challenge, err = base64.StdEncoding.DecodeString(challengeStr) + if err != nil { + uc.SendMessage(ctx, &irc.Message{ + Command: "AUTHENTICATE", + Params: []string{"*"}, + }) + return err + } + } + + var resp []byte + var err error + if !uc.saslStarted { + _, resp, err = uc.saslClient.Start() + uc.saslStarted = true + } else { + resp, err = uc.saslClient.Next(challenge) + } + if err != nil { + uc.SendMessage(ctx, &irc.Message{ + Command: "AUTHENTICATE", + Params: []string{"*"}, + }) + return err + } + + // <= instead of < because we need to send a final empty response if + // the last chunk is exactly 400 bytes long + for i := 0; i <= len(resp); i += maxSASLLength { + j := i + maxSASLLength + if j > len(resp) { + j = len(resp) + } + + chunk := resp[i:j] + + var respStr = "+" + if len(chunk) != 0 { + respStr = base64.StdEncoding.EncodeToString(chunk) + } + + uc.SendMessage(ctx, &irc.Message{ + Command: "AUTHENTICATE", + Params: []string{respStr}, + }) + } + case irc.RPL_LOGGEDIN: + if err := parseMessageParams(msg, nil, nil, &uc.account); err != nil { + return err + } + uc.logger.Printf("logged in with account %q", uc.account) + uc.forEachDownstream(func(dc *downstreamConn) { + dc.updateAccount() + }) + case irc.RPL_LOGGEDOUT: + uc.account = "" + uc.logger.Printf("logged out") + uc.forEachDownstream(func(dc *downstreamConn) { + dc.updateAccount() + }) + case irc.ERR_NICKLOCKED, irc.RPL_SASLSUCCESS, irc.ERR_SASLFAIL, irc.ERR_SASLTOOLONG, irc.ERR_SASLABORTED: + var info string + if err := parseMessageParams(msg, nil, &info); err != nil { + return err + } + switch msg.Command { + case irc.ERR_NICKLOCKED: + uc.logger.Printf("invalid nick used with SASL authentication: %v", info) + case irc.ERR_SASLFAIL: + uc.logger.Printf("SASL authentication failed: %v", info) + case irc.ERR_SASLTOOLONG: + uc.logger.Printf("SASL message too long: %v", info) + } + + uc.saslClient = nil + uc.saslStarted = false + + if dc, _ := uc.dequeueCommand("AUTHENTICATE"); dc != nil && dc.sasl != nil { + if msg.Command == irc.RPL_SASLSUCCESS { + uc.network.autoSaveSASLPlain(ctx, dc.sasl.plainUsername, dc.sasl.plainPassword) + } + + dc.endSASL(msg) + } + + if !uc.registered { + uc.SendMessage(ctx, &irc.Message{ + Command: "CAP", + Params: []string{"END"}, + }) + } + case "REGISTER", "VERIFY": + if dc, cmd := uc.dequeueCommand(msg.Command); dc != nil { + if msg.Command == "REGISTER" { + var account, password string + if err := parseMessageParams(msg, nil, &account); err != nil { + return err + } + if err := parseMessageParams(cmd, nil, nil, &password); err != nil { + return err + } + uc.network.autoSaveSASLPlain(ctx, account, password) + } + + dc.SendMessage(msg) + } + case irc.RPL_WELCOME: + if err := parseMessageParams(msg, &uc.nick); err != nil { + return err + } + + uc.registered = true + uc.nickCM = uc.network.casemap(uc.nick) + uc.logger.Printf("connection registered with nick %q", uc.nick) + + if uc.network.channels.Len() > 0 { + var channels, keys []string + for _, entry := range uc.network.channels.innerMap { + ch := entry.value.(*Channel) + channels = append(channels, ch.Name) + keys = append(keys, ch.Key) + } + + for _, msg := range join(channels, keys) { + uc.SendMessage(ctx, msg) + } + } + case irc.RPL_MYINFO: + if err := parseMessageParams(msg, nil, &uc.serverName, nil, &uc.availableUserModes, nil); err != nil { + return err + } + case irc.RPL_ISUPPORT: + if err := parseMessageParams(msg, nil, nil); err != nil { + return err + } + + var downstreamIsupport []string + for _, token := range msg.Params[1 : len(msg.Params)-1] { + parameter := token + var negate, hasValue bool + var value string + if strings.HasPrefix(token, "-") { + negate = true + token = token[1:] + } else if i := strings.IndexByte(token, '='); i >= 0 { + parameter = token[:i] + value = token[i+1:] + hasValue = true + } + + if hasValue { + uc.isupport[parameter] = &value + } else if !negate { + uc.isupport[parameter] = nil + } else { + delete(uc.isupport, parameter) + } + + var err error + switch parameter { + case "CASEMAPPING": + casemap, ok := parseCasemappingToken(value) + if !ok { + casemap = casemapRFC1459 + } + uc.network.updateCasemapping(casemap) + uc.nickCM = uc.network.casemap(uc.nick) + uc.casemapIsSet = true + case "CHANMODES": + if !negate { + err = uc.handleChanModes(value) + } else { + uc.availableChannelModes = stdChannelModes + } + case "CHANTYPES": + if !negate { + uc.availableChannelTypes = value + } else { + uc.availableChannelTypes = stdChannelTypes + } + case "PREFIX": + if !negate { + err = uc.handleMemberships(value) + } else { + uc.availableMemberships = stdMemberships + } + } + if err != nil { + return err + } + + if passthroughIsupport[parameter] { + downstreamIsupport = append(downstreamIsupport, token) + } + } + + uc.updateMonitor() + + uc.forEachDownstream(func(dc *downstreamConn) { + if dc.network == nil { + return + } + msgs := generateIsupport(dc.srv.prefix(), dc.nick, downstreamIsupport) + for _, msg := range msgs { + dc.SendMessage(msg) + } + }) + case irc.ERR_NOMOTD, irc.RPL_ENDOFMOTD: + if !uc.casemapIsSet { + // upstream did not send any CASEMAPPING token, thus + // we assume it implements the old RFCs with rfc1459. + uc.casemapIsSet = true + uc.network.updateCasemapping(casemapRFC1459) + uc.nickCM = uc.network.casemap(uc.nick) + } + + if !uc.gotMotd { + // Ignore the initial MOTD upon connection, but forward + // subsequent MOTD messages downstream + uc.gotMotd = true + return nil + } + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + dc.SendMessage(&irc.Message{ + Prefix: uc.srv.prefix(), + Command: msg.Command, + Params: msg.Params, + }) + }) + case "BATCH": + var tag string + if err := parseMessageParams(msg, &tag); err != nil { + return err + } + + if strings.HasPrefix(tag, "+") { + tag = tag[1:] + if _, ok := uc.batches[tag]; ok { + return fmt.Errorf("unexpected BATCH reference tag: batch was already defined: %q", tag) + } + var batchType string + if err := parseMessageParams(msg, nil, &batchType); err != nil { + return err + } + label := label + if label == "" && msgBatch != nil { + label = msgBatch.Label + } + uc.batches[tag] = batch{ + Type: batchType, + Params: msg.Params[2:], + Outer: msgBatch, + Label: label, + } + } else if strings.HasPrefix(tag, "-") { + tag = tag[1:] + if _, ok := uc.batches[tag]; !ok { + return fmt.Errorf("unknown BATCH reference tag: %q", tag) + } + delete(uc.batches, tag) + } else { + return fmt.Errorf("unexpected BATCH reference tag: missing +/- prefix: %q", tag) + } + case "NICK": + if msg.Prefix == nil { + return fmt.Errorf("expected a prefix") + } + + var newNick string + if err := parseMessageParams(msg, &newNick); err != nil { + return err + } + + me := false + if uc.isOurNick(msg.Prefix.Name) { + uc.logger.Printf("changed nick from %q to %q", uc.nick, newNick) + me = true + uc.nick = newNick + uc.nickCM = uc.network.casemap(uc.nick) + } + + for _, entry := range uc.channels.innerMap { + ch := entry.value.(*upstreamChannel) + memberships := ch.Members.Value(msg.Prefix.Name) + if memberships != nil { + ch.Members.Delete(msg.Prefix.Name) + ch.Members.SetValue(newNick, memberships) + uc.appendLog(ch.Name, msg) + } + } + + if !me { + uc.forEachDownstream(func(dc *downstreamConn) { + dc.SendMessage(dc.marshalMessage(msg, uc.network)) + }) + } else { + uc.forEachDownstream(func(dc *downstreamConn) { + dc.updateNick() + }) + uc.updateMonitor() + } + case "SETNAME": + if msg.Prefix == nil { + return fmt.Errorf("expected a prefix") + } + + var newRealname string + if err := parseMessageParams(msg, &newRealname); err != nil { + return err + } + + // TODO: consider appending this message to logs + + if uc.isOurNick(msg.Prefix.Name) { + uc.logger.Printf("changed realname from %q to %q", uc.realname, newRealname) + uc.realname = newRealname + + uc.forEachDownstream(func(dc *downstreamConn) { + dc.updateRealname() + }) + } else { + uc.forEachDownstream(func(dc *downstreamConn) { + dc.SendMessage(dc.marshalMessage(msg, uc.network)) + }) + } + case "JOIN": + if msg.Prefix == nil { + return fmt.Errorf("expected a prefix") + } + + var channels string + if err := parseMessageParams(msg, &channels); err != nil { + return err + } + + for _, ch := range strings.Split(channels, ",") { + if uc.isOurNick(msg.Prefix.Name) { + uc.logger.Printf("joined channel %q", ch) + members := membershipsCasemapMap{newCasemapMap(0)} + members.casemap = uc.network.casemap + uc.channels.SetValue(ch, &upstreamChannel{ + Name: ch, + conn: uc, + Members: members, + }) + uc.updateChannelAutoDetach(ch) + + uc.SendMessage(ctx, &irc.Message{ + Command: "MODE", + Params: []string{ch}, + }) + } else { + ch, err := uc.getChannel(ch) + if err != nil { + return err + } + ch.Members.SetValue(msg.Prefix.Name, &memberships{}) + } + + chMsg := msg.Copy() + chMsg.Params[0] = ch + uc.produce(ch, chMsg, nil) + } + case "PART": + if msg.Prefix == nil { + return fmt.Errorf("expected a prefix") + } + + var channels string + if err := parseMessageParams(msg, &channels); err != nil { + return err + } + + for _, ch := range strings.Split(channels, ",") { + if uc.isOurNick(msg.Prefix.Name) { + uc.logger.Printf("parted channel %q", ch) + uch := uc.channels.Value(ch) + if uch != nil { + uc.channels.Delete(ch) + uch.updateAutoDetach(0) + } + } else { + ch, err := uc.getChannel(ch) + if err != nil { + return err + } + ch.Members.Delete(msg.Prefix.Name) + } + + chMsg := msg.Copy() + chMsg.Params[0] = ch + uc.produce(ch, chMsg, nil) + } + case "KICK": + if msg.Prefix == nil { + return fmt.Errorf("expected a prefix") + } + + var channel, user string + if err := parseMessageParams(msg, &channel, &user); err != nil { + return err + } + + if uc.isOurNick(user) { + uc.logger.Printf("kicked from channel %q by %s", channel, msg.Prefix.Name) + uc.channels.Delete(channel) + } else { + ch, err := uc.getChannel(channel) + if err != nil { + return err + } + ch.Members.Delete(user) + } + + uc.produce(channel, msg, nil) + case "QUIT": + if msg.Prefix == nil { + return fmt.Errorf("expected a prefix") + } + + if uc.isOurNick(msg.Prefix.Name) { + uc.logger.Printf("quit") + } + + for _, entry := range uc.channels.innerMap { + ch := entry.value.(*upstreamChannel) + if ch.Members.Has(msg.Prefix.Name) { + ch.Members.Delete(msg.Prefix.Name) + + uc.appendLog(ch.Name, msg) + } + } + + if msg.Prefix.Name != uc.nick { + uc.forEachDownstream(func(dc *downstreamConn) { + dc.SendMessage(dc.marshalMessage(msg, uc.network)) + }) + } + case irc.RPL_TOPIC, irc.RPL_NOTOPIC: + var name, topic string + if err := parseMessageParams(msg, nil, &name, &topic); err != nil { + return err + } + ch, err := uc.getChannel(name) + if err != nil { + return err + } + if msg.Command == irc.RPL_TOPIC { + ch.Topic = topic + } else { + ch.Topic = "" + } + case "TOPIC": + if msg.Prefix == nil { + return fmt.Errorf("expected a prefix") + } + + var name string + if err := parseMessageParams(msg, &name); err != nil { + return err + } + ch, err := uc.getChannel(name) + if err != nil { + return err + } + if len(msg.Params) > 1 { + ch.Topic = msg.Params[1] + ch.TopicWho = msg.Prefix.Copy() + ch.TopicTime = time.Now() // TODO use msg.Tags["time"] + } else { + ch.Topic = "" + } + uc.produce(ch.Name, msg, nil) + case "MODE": + var name, modeStr string + if err := parseMessageParams(msg, &name, &modeStr); err != nil { + return err + } + + if !uc.isChannel(name) { // user mode change + if name != uc.nick { + return fmt.Errorf("received MODE message for unknown nick %q", name) + } + + if err := uc.modes.Apply(modeStr); err != nil { + return err + } + + uc.forEachDownstream(func(dc *downstreamConn) { + if dc.upstream() == nil { + return + } + + dc.SendMessage(msg) + }) + } else { // channel mode change + ch, err := uc.getChannel(name) + if err != nil { + return err + } + + needMarshaling, err := applyChannelModes(ch, modeStr, msg.Params[2:]) + if err != nil { + return err + } + + uc.appendLog(ch.Name, msg) + + c := uc.network.channels.Value(name) + if c == nil || !c.Detached { + uc.forEachDownstream(func(dc *downstreamConn) { + params := make([]string, len(msg.Params)) + params[0] = dc.marshalEntity(uc.network, name) + params[1] = modeStr + + copy(params[2:], msg.Params[2:]) + for i, modeParam := range params[2:] { + if _, ok := needMarshaling[i]; ok { + params[2+i] = dc.marshalEntity(uc.network, modeParam) + } + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.marshalUserPrefix(uc.network, msg.Prefix), + Command: "MODE", + Params: params, + }) + }) + } + } + case irc.RPL_UMODEIS: + if err := parseMessageParams(msg, nil); err != nil { + return err + } + modeStr := "" + if len(msg.Params) > 1 { + modeStr = msg.Params[1] + } + + uc.modes = "" + if err := uc.modes.Apply(modeStr); err != nil { + return err + } + + uc.forEachDownstream(func(dc *downstreamConn) { + if dc.upstream() == nil { + return + } + + dc.SendMessage(msg) + }) + case irc.RPL_CHANNELMODEIS: + var channel string + if err := parseMessageParams(msg, nil, &channel); err != nil { + return err + } + modeStr := "" + if len(msg.Params) > 2 { + modeStr = msg.Params[2] + } + + ch, err := uc.getChannel(channel) + if err != nil { + return err + } + + firstMode := ch.modes == nil + ch.modes = make(map[byte]string) + if _, err := applyChannelModes(ch, modeStr, msg.Params[3:]); err != nil { + return err + } + + c := uc.network.channels.Value(channel) + if firstMode && (c == nil || !c.Detached) { + modeStr, modeParams := ch.modes.Format() + + uc.forEachDownstream(func(dc *downstreamConn) { + params := []string{dc.nick, dc.marshalEntity(uc.network, channel), modeStr} + params = append(params, modeParams...) + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_CHANNELMODEIS, + Params: params, + }) + }) + } + case rpl_creationtime: + var channel, creationTime string + if err := parseMessageParams(msg, nil, &channel, &creationTime); err != nil { + return err + } + + ch, err := uc.getChannel(channel) + if err != nil { + return err + } + + firstCreationTime := ch.creationTime == "" + ch.creationTime = creationTime + + c := uc.network.channels.Value(channel) + if firstCreationTime && (c == nil || !c.Detached) { + uc.forEachDownstream(func(dc *downstreamConn) { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: rpl_creationtime, + Params: []string{dc.nick, dc.marshalEntity(uc.network, ch.Name), creationTime}, + }) + }) + } + case rpl_topicwhotime: + var channel, who, timeStr string + if err := parseMessageParams(msg, nil, &channel, &who, &timeStr); err != nil { + return err + } + + ch, err := uc.getChannel(channel) + if err != nil { + return err + } + + firstTopicWhoTime := ch.TopicWho == nil + ch.TopicWho = irc.ParsePrefix(who) + sec, err := strconv.ParseInt(timeStr, 10, 64) + if err != nil { + return fmt.Errorf("failed to parse topic time: %v", err) + } + ch.TopicTime = time.Unix(sec, 0) + + c := uc.network.channels.Value(channel) + if firstTopicWhoTime && (c == nil || !c.Detached) { + uc.forEachDownstream(func(dc *downstreamConn) { + topicWho := dc.marshalUserPrefix(uc.network, ch.TopicWho) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: rpl_topicwhotime, + Params: []string{ + dc.nick, + dc.marshalEntity(uc.network, ch.Name), + topicWho.String(), + timeStr, + }, + }) + }) + } + case irc.RPL_LIST: + var channel, clients, topic string + if err := parseMessageParams(msg, nil, &channel, &clients, &topic); err != nil { + return err + } + + dc, cmd := uc.currentPendingCommand("LIST") + if cmd == nil { + return fmt.Errorf("unexpected RPL_LIST: no matching pending LIST") + } else if dc == nil { + return nil + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_LIST, + Params: []string{dc.nick, dc.marshalEntity(uc.network, channel), clients, topic}, + }) + case irc.RPL_LISTEND: + dc, cmd := uc.dequeueCommand("LIST") + if cmd == nil { + return fmt.Errorf("unexpected RPL_LISTEND: no matching pending LIST") + } else if dc == nil { + return nil + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_LISTEND, + Params: []string{dc.nick, "End of /LIST"}, + }) + case irc.RPL_NAMREPLY: + var name, statusStr, members string + if err := parseMessageParams(msg, nil, &statusStr, &name, &members); err != nil { + return err + } + + ch := uc.channels.Value(name) + if ch == nil { + // NAMES on a channel we have not joined, forward to downstream + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + channel := dc.marshalEntity(uc.network, name) + members := splitSpace(members) + for i, member := range members { + memberships, nick := uc.parseMembershipPrefix(member) + members[i] = memberships.Format(dc) + dc.marshalEntity(uc.network, nick) + } + memberStr := strings.Join(members, " ") + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_NAMREPLY, + Params: []string{dc.nick, statusStr, channel, memberStr}, + }) + }) + return nil + } + + status, err := parseChannelStatus(statusStr) + if err != nil { + return err + } + ch.Status = status + + for _, s := range splitSpace(members) { + memberships, nick := uc.parseMembershipPrefix(s) + ch.Members.SetValue(nick, memberships) + } + case irc.RPL_ENDOFNAMES: + var name string + if err := parseMessageParams(msg, nil, &name); err != nil { + return err + } + + ch := uc.channels.Value(name) + if ch == nil { + // NAMES on a channel we have not joined, forward to downstream + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + channel := dc.marshalEntity(uc.network, name) + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_ENDOFNAMES, + Params: []string{dc.nick, channel, "End of /NAMES list"}, + }) + }) + return nil + } + + if ch.complete { + return fmt.Errorf("received unexpected RPL_ENDOFNAMES") + } + ch.complete = true + + c := uc.network.channels.Value(name) + if c == nil || !c.Detached { + uc.forEachDownstream(func(dc *downstreamConn) { + forwardChannel(ctx, dc, ch) + }) + } + case irc.RPL_WHOREPLY: + var channel, username, host, server, nick, flags, trailing string + if err := parseMessageParams(msg, nil, &channel, &username, &host, &server, &nick, &flags, &trailing); err != nil { + return err + } + + dc, cmd := uc.currentPendingCommand("WHO") + if cmd == nil { + return fmt.Errorf("unexpected RPL_WHOREPLY: no matching pending WHO") + } else if dc == nil { + return nil + } + + if channel != "*" { + channel = dc.marshalEntity(uc.network, channel) + } + nick = dc.marshalEntity(uc.network, nick) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_WHOREPLY, + Params: []string{dc.nick, channel, username, host, server, nick, flags, trailing}, + }) + case rpl_whospcrpl: + dc, cmd := uc.currentPendingCommand("WHO") + if cmd == nil { + return fmt.Errorf("unexpected RPL_WHOSPCRPL: no matching pending WHO") + } else if dc == nil { + return nil + } + + // Only supported in single-upstream mode, so forward as-is + dc.SendMessage(msg) + case irc.RPL_ENDOFWHO: + var name string + if err := parseMessageParams(msg, nil, &name); err != nil { + return err + } + + dc, cmd := uc.dequeueCommand("WHO") + if cmd == nil { + return fmt.Errorf("unexpected RPL_ENDOFWHO: no matching pending WHO") + } else if dc == nil { + return nil + } + + mask := "*" + if len(cmd.Params) > 0 { + mask = cmd.Params[0] + } + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_ENDOFWHO, + Params: []string{dc.nick, mask, "End of /WHO list"}, + }) + case irc.RPL_WHOISUSER: + var nick, username, host, realname string + if err := parseMessageParams(msg, nil, &nick, &username, &host, nil, &realname); err != nil { + return err + } + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + nick := dc.marshalEntity(uc.network, nick) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_WHOISUSER, + Params: []string{dc.nick, nick, username, host, "*", realname}, + }) + }) + case irc.RPL_WHOISSERVER: + var nick, server, serverInfo string + if err := parseMessageParams(msg, nil, &nick, &server, &serverInfo); err != nil { + return err + } + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + nick := dc.marshalEntity(uc.network, nick) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_WHOISSERVER, + Params: []string{dc.nick, nick, server, serverInfo}, + }) + }) + case irc.RPL_WHOISOPERATOR: + var nick string + if err := parseMessageParams(msg, nil, &nick); err != nil { + return err + } + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + nick := dc.marshalEntity(uc.network, nick) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_WHOISOPERATOR, + Params: []string{dc.nick, nick, "is an IRC operator"}, + }) + }) + case irc.RPL_WHOISIDLE: + var nick string + if err := parseMessageParams(msg, nil, &nick, nil); err != nil { + return err + } + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + nick := dc.marshalEntity(uc.network, nick) + params := []string{dc.nick, nick} + params = append(params, msg.Params[2:]...) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_WHOISIDLE, + Params: params, + }) + }) + case irc.RPL_WHOISCHANNELS: + var nick, channelList string + if err := parseMessageParams(msg, nil, &nick, &channelList); err != nil { + return err + } + channels := splitSpace(channelList) + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + nick := dc.marshalEntity(uc.network, nick) + channelList := make([]string, len(channels)) + for i, channel := range channels { + prefix, channel := uc.parseMembershipPrefix(channel) + channel = dc.marshalEntity(uc.network, channel) + channelList[i] = prefix.Format(dc) + channel + } + channels := strings.Join(channelList, " ") + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_WHOISCHANNELS, + Params: []string{dc.nick, nick, channels}, + }) + }) + case irc.RPL_ENDOFWHOIS: + var nick string + if err := parseMessageParams(msg, nil, &nick); err != nil { + return err + } + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + nick := dc.marshalEntity(uc.network, nick) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_ENDOFWHOIS, + Params: []string{dc.nick, nick, "End of /WHOIS list"}, + }) + }) + case "INVITE": + var nick, channel string + if err := parseMessageParams(msg, &nick, &channel); err != nil { + return err + } + + weAreInvited := uc.isOurNick(nick) + + uc.forEachDownstream(func(dc *downstreamConn) { + if !weAreInvited && !dc.caps["invite-notify"] { + return + } + dc.SendMessage(&irc.Message{ + Prefix: dc.marshalUserPrefix(uc.network, msg.Prefix), + Command: "INVITE", + Params: []string{dc.marshalEntity(uc.network, nick), dc.marshalEntity(uc.network, channel)}, + }) + }) + case irc.RPL_INVITING: + var nick, channel string + if err := parseMessageParams(msg, nil, &nick, &channel); err != nil { + return err + } + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_INVITING, + Params: []string{dc.nick, dc.marshalEntity(uc.network, nick), dc.marshalEntity(uc.network, channel)}, + }) + }) + case irc.RPL_MONONLINE, irc.RPL_MONOFFLINE: + var targetsStr string + if err := parseMessageParams(msg, nil, &targetsStr); err != nil { + return err + } + targets := strings.Split(targetsStr, ",") + + online := msg.Command == irc.RPL_MONONLINE + for _, target := range targets { + prefix := irc.ParsePrefix(target) + uc.monitored.SetValue(prefix.Name, online) + } + + // Check if the nick we want is now free + wantNick := GetNick(&uc.user.User, &uc.network.Network) + wantNickCM := uc.network.casemap(wantNick) + if !online && uc.nickCM != wantNickCM { + found := false + for _, target := range targets { + prefix := irc.ParsePrefix(target) + if uc.network.casemap(prefix.Name) == wantNickCM { + found = true + break + } + } + if found { + uc.logger.Printf("desired nick %q is now available", wantNick) + uc.SendMessage(ctx, &irc.Message{ + Command: "NICK", + Params: []string{wantNick}, + }) + } + } + + uc.forEachDownstream(func(dc *downstreamConn) { + for _, target := range targets { + prefix := irc.ParsePrefix(target) + if dc.monitored.Has(prefix.Name) { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: msg.Command, + Params: []string{dc.nick, target}, + }) + } + } + }) + case irc.ERR_MONLISTFULL: + var limit, targetsStr string + if err := parseMessageParams(msg, nil, &limit, &targetsStr); err != nil { + return err + } + + targets := strings.Split(targetsStr, ",") + uc.forEachDownstream(func(dc *downstreamConn) { + for _, target := range targets { + if dc.monitored.Has(target) { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: msg.Command, + Params: []string{dc.nick, limit, target}, + }) + } + } + }) + case irc.RPL_AWAY: + var nick, reason string + if err := parseMessageParams(msg, nil, &nick, &reason); err != nil { + return err + } + + uc.forEachDownstream(func(dc *downstreamConn) { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_AWAY, + Params: []string{dc.nick, dc.marshalEntity(uc.network, nick), reason}, + }) + }) + case "AWAY", "ACCOUNT": + if msg.Prefix == nil { + return fmt.Errorf("expected a prefix") + } + + uc.forEachDownstream(func(dc *downstreamConn) { + dc.SendMessage(&irc.Message{ + Prefix: dc.marshalUserPrefix(uc.network, msg.Prefix), + Command: msg.Command, + Params: msg.Params, + }) + }) + case irc.RPL_BANLIST, irc.RPL_INVITELIST, irc.RPL_EXCEPTLIST: + var channel, mask string + if err := parseMessageParams(msg, nil, &channel, &mask); err != nil { + return err + } + var addNick, addTime string + if len(msg.Params) >= 5 { + addNick = msg.Params[3] + addTime = msg.Params[4] + } + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + channel := dc.marshalEntity(uc.network, channel) + + var params []string + if addNick != "" && addTime != "" { + addNick := dc.marshalEntity(uc.network, addNick) + params = []string{dc.nick, channel, mask, addNick, addTime} + } else { + params = []string{dc.nick, channel, mask} + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: msg.Command, + Params: params, + }) + }) + case irc.RPL_ENDOFBANLIST, irc.RPL_ENDOFINVITELIST, irc.RPL_ENDOFEXCEPTLIST: + var channel, trailing string + if err := parseMessageParams(msg, nil, &channel, &trailing); err != nil { + return err + } + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + upstreamChannel := dc.marshalEntity(uc.network, channel) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: msg.Command, + Params: []string{dc.nick, upstreamChannel, trailing}, + }) + }) + case irc.ERR_UNKNOWNCOMMAND, irc.RPL_TRYAGAIN: + var command, reason string + if err := parseMessageParams(msg, nil, &command, &reason); err != nil { + return err + } + + if dc, _ := uc.dequeueCommand(command); dc != nil && downstreamID == 0 { + downstreamID = dc.id + } + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + dc.SendMessage(&irc.Message{ + Prefix: uc.srv.prefix(), + Command: msg.Command, + Params: []string{dc.nick, command, reason}, + }) + }) + case "FAIL": + var command, code string + if err := parseMessageParams(msg, &command, &code); err != nil { + return err + } + + if !uc.registered && command == "*" && code == "ACCOUNT_REQUIRED" { + return registrationError{msg} + } + + if dc, _ := uc.dequeueCommand(command); dc != nil && downstreamID == 0 { + downstreamID = dc.id + } + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + dc.SendMessage(msg) + }) + case "ACK": + // Ignore + case irc.RPL_NOWAWAY, irc.RPL_UNAWAY: + // Ignore + case irc.RPL_YOURHOST, irc.RPL_CREATED: + // Ignore + case irc.RPL_LUSERCLIENT, irc.RPL_LUSEROP, irc.RPL_LUSERUNKNOWN, irc.RPL_LUSERCHANNELS, irc.RPL_LUSERME: + fallthrough + case irc.RPL_STATSVLINE, rpl_statsping, irc.RPL_STATSBLINE, irc.RPL_STATSDLINE: + fallthrough + case rpl_localusers, rpl_globalusers: + fallthrough + case irc.RPL_MOTDSTART, irc.RPL_MOTD: + // Ignore these messages if they're part of the initial registration + // message burst. Forward them if the user explicitly asked for them. + if !uc.gotMotd { + return nil + } + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + dc.SendMessage(&irc.Message{ + Prefix: uc.srv.prefix(), + Command: msg.Command, + Params: msg.Params, + }) + }) + case irc.RPL_LISTSTART: + // Ignore + case "ERROR": + var text string + if err := parseMessageParams(msg, &text); err != nil { + return err + } + return fmt.Errorf("fatal server error: %v", text) + case irc.ERR_NICKNAMEINUSE: + // At this point, we haven't received ISUPPORT so we don't know the + // maximum nickname length or whether the server supports MONITOR. Many + // servers have NICKLEN=30 so let's just use that. + if !uc.registered && len(uc.nick)+1 < 30 { + uc.nick = uc.nick + "_" + uc.nickCM = uc.network.casemap(uc.nick) + uc.logger.Printf("desired nick is not available, falling back to %q", uc.nick) + uc.SendMessage(ctx, &irc.Message{ + Command: "NICK", + Params: []string{uc.nick}, + }) + return nil + } + fallthrough + case irc.ERR_PASSWDMISMATCH, irc.ERR_ERRONEUSNICKNAME, irc.ERR_NICKCOLLISION, irc.ERR_UNAVAILRESOURCE, irc.ERR_NOPERMFORHOST, irc.ERR_YOUREBANNEDCREEP: + if !uc.registered { + return registrationError{msg} + } + fallthrough + default: + uc.logger.Printf("unhandled message: %v", msg) + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + // best effort marshaling for unknown messages, replies and errors: + // most numerics start with the user nick, marshal it if that's the case + // otherwise, conservately keep the params without marshaling + params := msg.Params + if _, err := strconv.Atoi(msg.Command); err == nil { // numeric + if len(msg.Params) > 0 && isOurNick(uc.network, msg.Params[0]) { + params[0] = dc.nick + } + } + dc.SendMessage(&irc.Message{ + Prefix: uc.srv.prefix(), + Command: msg.Command, + Params: params, + }) + }) + } + return nil +} + +func (uc *upstreamConn) handleDetachedMessage(ctx context.Context, ch *Channel, msg *irc.Message) { + if uc.network.detachedMessageNeedsRelay(ch, msg) { + uc.forEachDownstream(func(dc *downstreamConn) { + dc.relayDetachedMessage(uc.network, msg) + }) + } + if ch.ReattachOn == FilterMessage || (ch.ReattachOn == FilterHighlight && uc.network.isHighlight(msg)) { + uc.network.attach(ctx, ch) + if err := uc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil { + uc.logger.Printf("failed to update channel %q: %v", ch.Name, err) + } + } +} + +func (uc *upstreamConn) handleChanModes(s string) error { + parts := strings.SplitN(s, ",", 5) + if len(parts) < 4 { + return fmt.Errorf("malformed ISUPPORT CHANMODES value: %v", s) + } + modes := make(map[byte]channelModeType) + for i, mt := range []channelModeType{modeTypeA, modeTypeB, modeTypeC, modeTypeD} { + for j := 0; j < len(parts[i]); j++ { + mode := parts[i][j] + modes[mode] = mt + } + } + uc.availableChannelModes = modes + return nil +} + +func (uc *upstreamConn) handleMemberships(s string) error { + if s == "" { + uc.availableMemberships = nil + return nil + } + + if s[0] != '(' { + return fmt.Errorf("malformed ISUPPORT PREFIX value: %v", s) + } + sep := strings.IndexByte(s, ')') + if sep < 0 || len(s) != sep*2 { + return fmt.Errorf("malformed ISUPPORT PREFIX value: %v", s) + } + memberships := make([]membership, len(s)/2-1) + for i := range memberships { + memberships[i] = membership{ + Mode: s[i+1], + Prefix: s[sep+i+1], + } + } + uc.availableMemberships = memberships + return nil +} + +func (uc *upstreamConn) handleSupportedCaps(capsStr string) { + caps := strings.Fields(capsStr) + for _, s := range caps { + kv := strings.SplitN(s, "=", 2) + k := strings.ToLower(kv[0]) + var v string + if len(kv) == 2 { + v = kv[1] + } + uc.supportedCaps[k] = v + } +} + +func (uc *upstreamConn) requestCaps() { + var requestCaps []string + for c := range permanentUpstreamCaps { + if _, ok := uc.supportedCaps[c]; ok && !uc.caps[c] { + requestCaps = append(requestCaps, c) + } + } + + if len(requestCaps) == 0 { + return + } + + uc.SendMessage(context.TODO(), &irc.Message{ + Command: "CAP", + Params: []string{"REQ", strings.Join(requestCaps, " ")}, + }) +} + +func (uc *upstreamConn) supportsSASL(mech string) bool { + v, ok := uc.supportedCaps["sasl"] + if !ok { + return false + } + + if v == "" { + return true + } + + mechanisms := strings.Split(v, ",") + for _, mech := range mechanisms { + if strings.EqualFold(mech, mech) { + return true + } + } + return false +} + +func (uc *upstreamConn) requestSASL() bool { + if uc.network.SASL.Mechanism == "" { + return false + } + return uc.supportsSASL(uc.network.SASL.Mechanism) +} + +func (uc *upstreamConn) handleCapAck(ctx context.Context, name string, ok bool) error { + uc.caps[name] = ok + + switch name { + case "sasl": + if !uc.requestSASL() { + return nil + } + if !ok { + uc.logger.Printf("server refused to acknowledge the SASL capability") + return nil + } + + auth := &uc.network.SASL + switch auth.Mechanism { + case "PLAIN": + uc.logger.Printf("starting SASL PLAIN authentication with username %q", auth.Plain.Username) + uc.saslClient = sasl.NewPlainClient("", auth.Plain.Username, auth.Plain.Password) + case "EXTERNAL": + uc.logger.Printf("starting SASL EXTERNAL authentication") + uc.saslClient = sasl.NewExternalClient("") + default: + return fmt.Errorf("unsupported SASL mechanism %q", name) + } + + uc.SendMessage(ctx, &irc.Message{ + Command: "AUTHENTICATE", + Params: []string{auth.Mechanism}, + }) + default: + if permanentUpstreamCaps[name] { + break + } + uc.logger.Printf("received CAP ACK/NAK for a cap we don't support: %v", name) + } + return nil +} + +func splitSpace(s string) []string { + return strings.FieldsFunc(s, func(r rune) bool { + return r == ' ' + }) +} + +func (uc *upstreamConn) register(ctx context.Context) { + uc.nick = GetNick(&uc.user.User, &uc.network.Network) + uc.nickCM = uc.network.casemap(uc.nick) + uc.username = GetUsername(&uc.user.User, &uc.network.Network) + uc.realname = GetRealname(&uc.user.User, &uc.network.Network) + + uc.SendMessage(ctx, &irc.Message{ + Command: "CAP", + Params: []string{"LS", "302"}, + }) + + if uc.network.Pass != "" { + uc.SendMessage(ctx, &irc.Message{ + Command: "PASS", + Params: []string{uc.network.Pass}, + }) + } + + uc.SendMessage(ctx, &irc.Message{ + Command: "NICK", + Params: []string{uc.nick}, + }) + uc.SendMessage(ctx, &irc.Message{ + Command: "USER", + Params: []string{uc.username, "0", "*", uc.realname}, + }) +} + +func (uc *upstreamConn) ReadMessage() (*irc.Message, error) { + msg, err := uc.conn.ReadMessage() + if err != nil { + return nil, err + } + return msg, nil +} + +func (uc *upstreamConn) runUntilRegistered(ctx context.Context) error { + for !uc.registered { + msg, err := uc.ReadMessage() + if err != nil { + return fmt.Errorf("failed to read message: %v", err) + } + + if err := uc.handleMessage(ctx, msg); err != nil { + if _, ok := err.(registrationError); ok { + return err + } else { + msg.Tags = nil // prevent message tags from cluttering logs + return fmt.Errorf("failed to handle message %q: %v", msg, err) + } + } + } + + for _, command := range uc.network.ConnectCommands { + m, err := irc.ParseMessage(command) + if err != nil { + uc.logger.Printf("failed to parse connect command %q: %v", command, err) + } else { + uc.SendMessage(ctx, m) + } + } + + return nil +} + +func (uc *upstreamConn) readMessages(ch chan<- event) error { + for { + msg, err := uc.ReadMessage() + if errors.Is(err, io.EOF) { + break + } else if err != nil { + return fmt.Errorf("failed to read IRC command: %v", err) + } + + ch <- eventUpstreamMessage{msg, uc} + } + + return nil +} + +func (uc *upstreamConn) SendMessage(ctx context.Context, msg *irc.Message) { + if !uc.caps["message-tags"] { + msg = msg.Copy() + msg.Tags = nil + } + + uc.conn.SendMessage(ctx, msg) +} + +func (uc *upstreamConn) SendMessageLabeled(ctx context.Context, downstreamID uint64, msg *irc.Message) { + if uc.caps["labeled-response"] { + if msg.Tags == nil { + msg.Tags = make(map[string]irc.TagValue) + } + msg.Tags["label"] = irc.TagValue(fmt.Sprintf("sd-%d-%d", downstreamID, uc.nextLabelID)) + uc.nextLabelID++ + } + uc.SendMessage(ctx, msg) +} + +// appendLog appends a message to the log file. +// +// The internal message ID is returned. If the message isn't recorded in the +// log file, an empty string is returned. +func (uc *upstreamConn) appendLog(entity string, msg *irc.Message) (msgID string) { + if uc.user.msgStore == nil { + return "" + } + + // Don't store messages with a server mask target + if strings.HasPrefix(entity, "$") { + return "" + } + + entityCM := uc.network.casemap(entity) + if entityCM == "nickserv" { + // The messages sent/received from NickServ may contain + // security-related information (like passwords). Don't store these. + return "" + } + + if !uc.network.delivered.HasTarget(entity) { + // This is the first message we receive from this target. Save the last + // message ID in delivery receipts, so that we can send the new message + // in the backlog if an offline client reconnects. + lastID, err := uc.user.msgStore.LastMsgID(&uc.network.Network, entityCM, time.Now()) + if err != nil { + uc.logger.Printf("failed to log message: failed to get last message ID: %v", err) + return "" + } + + uc.network.delivered.ForEachClient(func(clientName string) { + uc.network.delivered.StoreID(entity, clientName, lastID) + }) + } + + msgID, err := uc.user.msgStore.Append(&uc.network.Network, entityCM, msg) + if err != nil { + uc.logger.Printf("failed to append message to store: %v", err) + return "" + } + + return msgID +} + +// produce appends a message to the logs and forwards it to connected downstream +// connections. +// +// If origin is not nil and origin doesn't support echo-message, the message is +// forwarded to all connections except origin. +func (uc *upstreamConn) produce(target string, msg *irc.Message, origin *downstreamConn) { + var msgID string + if target != "" { + msgID = uc.appendLog(target, msg) + } + + // Don't forward messages if it's a detached channel + ch := uc.network.channels.Value(target) + detached := ch != nil && ch.Detached + + uc.forEachDownstream(func(dc *downstreamConn) { + if !detached && (dc != origin || dc.caps["echo-message"]) { + dc.sendMessageWithID(dc.marshalMessage(msg, uc.network), msgID) + } else { + dc.advanceMessageWithID(msg, msgID) + } + }) +} + +func (uc *upstreamConn) updateAway() { + ctx := context.TODO() + + away := true + uc.forEachDownstream(func(*downstreamConn) { + away = false + }) + if away == uc.away { + return + } + if away { + uc.SendMessage(ctx, &irc.Message{ + Command: "AWAY", + Params: []string{"Auto away"}, + }) + } else { + uc.SendMessage(ctx, &irc.Message{ + Command: "AWAY", + }) + } + uc.away = away +} + +func (uc *upstreamConn) updateChannelAutoDetach(name string) { + uch := uc.channels.Value(name) + if uch == nil { + return + } + ch := uc.network.channels.Value(name) + if ch == nil || ch.Detached { + return + } + uch.updateAutoDetach(ch.DetachAfter) +} + +func (uc *upstreamConn) updateMonitor() { + if _, ok := uc.isupport["MONITOR"]; !ok { + return + } + + ctx := context.TODO() + + add := make(map[string]struct{}) + var addList []string + seen := make(map[string]struct{}) + uc.forEachDownstream(func(dc *downstreamConn) { + for targetCM := range dc.monitored.innerMap { + if !uc.monitored.Has(targetCM) { + if _, ok := add[targetCM]; !ok { + addList = append(addList, targetCM) + add[targetCM] = struct{}{} + } + } else { + seen[targetCM] = struct{}{} + } + } + }) + + wantNick := GetNick(&uc.user.User, &uc.network.Network) + wantNickCM := uc.network.casemap(wantNick) + if _, ok := add[wantNickCM]; !ok && !uc.monitored.Has(wantNick) && !uc.isOurNick(wantNick) { + addList = append(addList, wantNickCM) + add[wantNickCM] = struct{}{} + } + + removeAll := true + var removeList []string + for targetCM, entry := range uc.monitored.innerMap { + if _, ok := seen[targetCM]; ok { + removeAll = false + } else { + removeList = append(removeList, entry.originalKey) + } + } + + // TODO: better handle the case where len(uc.monitored) + len(addList) + // exceeds the limit, probably by immediately sending ERR_MONLISTFULL? + + if removeAll && len(addList) == 0 && len(removeList) > 0 { + // Optimization when the last MONITOR-aware downstream disconnects + uc.SendMessage(ctx, &irc.Message{ + Command: "MONITOR", + Params: []string{"C"}, + }) + } else { + msgs := generateMonitor("-", removeList) + msgs = append(msgs, generateMonitor("+", addList)...) + for _, msg := range msgs { + uc.SendMessage(ctx, msg) + } + } + + for _, target := range removeList { + uc.monitored.Delete(target) + } +} diff --git a/branches/origin/user.go b/branches/origin/user.go new file mode 100644 index 0000000..0b523ce --- /dev/null +++ b/branches/origin/user.go @@ -0,0 +1,1068 @@ +package suika + +import ( + "context" + "crypto/sha256" + "encoding/binary" + "encoding/hex" + "fmt" + "math/big" + "net" + "sort" + "strings" + "time" + + "gopkg.in/irc.v3" +) + +type event interface{} + +type eventUpstreamMessage struct { + msg *irc.Message + uc *upstreamConn +} + +type eventUpstreamConnectionError struct { + net *network + err error +} + +type eventUpstreamConnected struct { + uc *upstreamConn +} + +type eventUpstreamDisconnected struct { + uc *upstreamConn +} + +type eventUpstreamError struct { + uc *upstreamConn + err error +} + +type eventDownstreamMessage struct { + msg *irc.Message + dc *downstreamConn +} + +type eventDownstreamConnected struct { + dc *downstreamConn +} + +type eventDownstreamDisconnected struct { + dc *downstreamConn +} + +type eventChannelDetach struct { + uc *upstreamConn + name string +} + +type eventBroadcast struct { + msg *irc.Message +} + +type eventStop struct{} + +type eventUserUpdate struct { + password *string + admin *bool + done chan error +} + +type deliveredClientMap map[string]string // client name -> msg ID + +type deliveredStore struct { + m deliveredCasemapMap +} + +func newDeliveredStore() deliveredStore { + return deliveredStore{deliveredCasemapMap{newCasemapMap(0)}} +} + +func (ds deliveredStore) HasTarget(target string) bool { + return ds.m.Value(target) != nil +} + +func (ds deliveredStore) LoadID(target, clientName string) string { + clients := ds.m.Value(target) + if clients == nil { + return "" + } + return clients[clientName] +} + +func (ds deliveredStore) StoreID(target, clientName, msgID string) { + clients := ds.m.Value(target) + if clients == nil { + clients = make(deliveredClientMap) + ds.m.SetValue(target, clients) + } + clients[clientName] = msgID +} + +func (ds deliveredStore) ForEachTarget(f func(target string)) { + for _, entry := range ds.m.innerMap { + f(entry.originalKey) + } +} + +func (ds deliveredStore) ForEachClient(f func(clientName string)) { + clients := make(map[string]struct{}) + for _, entry := range ds.m.innerMap { + delivered := entry.value.(deliveredClientMap) + for clientName := range delivered { + clients[clientName] = struct{}{} + } + } + + for clientName := range clients { + f(clientName) + } +} + +type network struct { + Network + user *user + logger Logger + stopped chan struct{} + + conn *upstreamConn + channels channelCasemapMap + delivered deliveredStore + lastError error + casemap casemapping +} + +func newNetwork(user *user, record *Network, channels []Channel) *network { + logger := &prefixLogger{user.logger, fmt.Sprintf("network %q: ", record.GetName())} + + m := channelCasemapMap{newCasemapMap(0)} + for _, ch := range channels { + ch := ch + m.SetValue(ch.Name, &ch) + } + + return &network{ + Network: *record, + user: user, + logger: logger, + stopped: make(chan struct{}), + channels: m, + delivered: newDeliveredStore(), + casemap: casemapRFC1459, + } +} + +func (net *network) forEachDownstream(f func(*downstreamConn)) { + net.user.forEachDownstream(func(dc *downstreamConn) { + if dc.network == nil && !dc.isMultiUpstream { + return + } + if dc.network != nil && dc.network != net { + return + } + f(dc) + }) +} + +func (net *network) isStopped() bool { + select { + case <-net.stopped: + return true + default: + return false + } +} + +func userIdent(u *User) string { + // The ident is a string we will send to upstream servers in clear-text. + // For privacy reasons, make sure it doesn't expose any meaningful user + // metadata. We just use the base64-encoded hashed ID, so that people don't + // start relying on the string being an integer or following a pattern. + var b [64]byte + binary.LittleEndian.PutUint64(b[:], uint64(u.ID)) + h := sha256.Sum256(b[:]) + return hex.EncodeToString(h[:16]) +} + +func (net *network) run() { + if !net.Enabled { + return + } + + var lastTry time.Time + backoff := newBackoffer(retryConnectMinDelay, retryConnectMaxDelay, retryConnectJitter) + for { + if net.isStopped() { + return + } + + delay := backoff.Next() - time.Now().Sub(lastTry) + if delay > 0 { + net.logger.Printf("waiting %v before trying to reconnect to %q", delay.Truncate(time.Second), net.Addr) + time.Sleep(delay) + } + lastTry = time.Now() + + + uc, err := connectToUpstream(context.TODO(), net) + if err != nil { + net.logger.Printf("failed to connect to upstream server %q: %v", net.Addr, err) + net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("failed to connect: %v", err)} + continue + } + + uc.register(context.TODO()) + if err := uc.runUntilRegistered(context.TODO()); err != nil { + text := err.Error() + temp := true + if regErr, ok := err.(registrationError); ok { + text = regErr.Reason() + temp = regErr.Temporary() + } + uc.logger.Printf("failed to register: %v", text) + net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("failed to register: %v", text)} + uc.Close() + if !temp { + return + } + continue + } + + // TODO: this is racy with net.stopped. If the network is stopped + // before the user goroutine receives eventUpstreamConnected, the + // connection won't be closed. + net.user.events <- eventUpstreamConnected{uc} + if err := uc.readMessages(net.user.events); err != nil { + uc.logger.Printf("failed to handle messages: %v", err) + net.user.events <- eventUpstreamError{uc, fmt.Errorf("failed to handle messages: %v", err)} + } + uc.Close() + net.user.events <- eventUpstreamDisconnected{uc} + + backoff.Reset() + } +} + +func (net *network) stop() { + if !net.isStopped() { + close(net.stopped) + } + + if net.conn != nil { + net.conn.Close() + } +} + +func (net *network) detach(ch *Channel) { + if ch.Detached { + return + } + + net.logger.Printf("detaching channel %q", ch.Name) + + ch.Detached = true + + if net.user.msgStore != nil { + nameCM := net.casemap(ch.Name) + lastID, err := net.user.msgStore.LastMsgID(&net.Network, nameCM, time.Now()) + if err != nil { + net.logger.Printf("failed to get last message ID for channel %q: %v", ch.Name, err) + } + ch.DetachedInternalMsgID = lastID + } + + if net.conn != nil { + uch := net.conn.channels.Value(ch.Name) + if uch != nil { + uch.updateAutoDetach(0) + } + } + + net.forEachDownstream(func(dc *downstreamConn) { + dc.SendMessage(&irc.Message{ + Prefix: dc.prefix(), + Command: "PART", + Params: []string{dc.marshalEntity(net, ch.Name), "Detach"}, + }) + }) +} + +func (net *network) attach(ctx context.Context, ch *Channel) { + if !ch.Detached { + return + } + + net.logger.Printf("attaching channel %q", ch.Name) + + detachedMsgID := ch.DetachedInternalMsgID + ch.Detached = false + ch.DetachedInternalMsgID = "" + + var uch *upstreamChannel + if net.conn != nil { + uch = net.conn.channels.Value(ch.Name) + + net.conn.updateChannelAutoDetach(ch.Name) + } + + net.forEachDownstream(func(dc *downstreamConn) { + dc.SendMessage(&irc.Message{ + Prefix: dc.prefix(), + Command: "JOIN", + Params: []string{dc.marshalEntity(net, ch.Name)}, + }) + + if uch != nil { + forwardChannel(ctx, dc, uch) + } + + if detachedMsgID != "" { + dc.sendTargetBacklog(ctx, net, ch.Name, detachedMsgID) + } + }) +} + +func (net *network) deleteChannel(ctx context.Context, name string) error { + ch := net.channels.Value(name) + if ch == nil { + return fmt.Errorf("unknown channel %q", name) + } + if net.conn != nil { + uch := net.conn.channels.Value(ch.Name) + if uch != nil { + uch.updateAutoDetach(0) + } + } + + if err := net.user.srv.db.DeleteChannel(ctx, ch.ID); err != nil { + return err + } + net.channels.Delete(name) + return nil +} + +func (net *network) updateCasemapping(newCasemap casemapping) { + net.casemap = newCasemap + net.channels.SetCasemapping(newCasemap) + net.delivered.m.SetCasemapping(newCasemap) + if uc := net.conn; uc != nil { + uc.channels.SetCasemapping(newCasemap) + for _, entry := range uc.channels.innerMap { + uch := entry.value.(*upstreamChannel) + uch.Members.SetCasemapping(newCasemap) + } + uc.monitored.SetCasemapping(newCasemap) + } + net.forEachDownstream(func(dc *downstreamConn) { + dc.monitored.SetCasemapping(newCasemap) + }) +} + +func (net *network) storeClientDeliveryReceipts(ctx context.Context, clientName string) { + if !net.user.hasPersistentMsgStore() { + return + } + + var receipts []DeliveryReceipt + net.delivered.ForEachTarget(func(target string) { + msgID := net.delivered.LoadID(target, clientName) + if msgID == "" { + return + } + receipts = append(receipts, DeliveryReceipt{ + Target: target, + InternalMsgID: msgID, + }) + }) + + if err := net.user.srv.db.StoreClientDeliveryReceipts(ctx, net.ID, clientName, receipts); err != nil { + net.logger.Printf("failed to store delivery receipts for client %q: %v", clientName, err) + } +} + +func (net *network) isHighlight(msg *irc.Message) bool { + if msg.Command != "PRIVMSG" && msg.Command != "NOTICE" { + return false + } + + text := msg.Params[1] + + nick := net.Nick + if net.conn != nil { + nick = net.conn.nick + } + + // TODO: use case-mapping aware comparison here + return msg.Prefix.Name != nick && isHighlight(text, nick) +} + +func (net *network) detachedMessageNeedsRelay(ch *Channel, msg *irc.Message) bool { + highlight := net.isHighlight(msg) + return ch.RelayDetached == FilterMessage || ((ch.RelayDetached == FilterHighlight || ch.RelayDetached == FilterDefault) && highlight) +} + +func (net *network) autoSaveSASLPlain(ctx context.Context, username, password string) { + // User may have e.g. EXTERNAL mechanism configured. We do not want to + // automatically erase the key pair or any other credentials. + if net.SASL.Mechanism != "" && net.SASL.Mechanism != "PLAIN" { + return + } + + net.logger.Printf("auto-saving SASL PLAIN credentials with username %q", username) + net.SASL.Mechanism = "PLAIN" + net.SASL.Plain.Username = username + net.SASL.Plain.Password = password + if err := net.user.srv.db.StoreNetwork(ctx, net.user.ID, &net.Network); err != nil { + net.logger.Printf("failed to save SASL PLAIN credentials: %v", err) + } +} + +type user struct { + User + srv *Server + logger Logger + + events chan event + done chan struct{} + + networks []*network + downstreamConns []*downstreamConn + msgStore messageStore +} + +func newUser(srv *Server, record *User) *user { + logger := &prefixLogger{srv.Logger, fmt.Sprintf("user %q: ", record.Username)} + + var msgStore messageStore + if logPath := srv.Config().LogPath; logPath != "" { + msgStore = newFSMessageStore(logPath, record) + } else { + msgStore = newMemoryMessageStore() + } + + return &user{ + User: *record, + srv: srv, + logger: logger, + events: make(chan event, 64), + done: make(chan struct{}), + msgStore: msgStore, + } +} + +func (u *user) forEachUpstream(f func(uc *upstreamConn)) { + for _, network := range u.networks { + if network.conn == nil { + continue + } + f(network.conn) + } +} + +func (u *user) forEachDownstream(f func(dc *downstreamConn)) { + for _, dc := range u.downstreamConns { + f(dc) + } +} + +func (u *user) getNetwork(name string) *network { + for _, network := range u.networks { + if network.Addr == name { + return network + } + if network.Name != "" && network.Name == name { + return network + } + } + return nil +} + +func (u *user) getNetworkByID(id int64) *network { + for _, net := range u.networks { + if net.ID == id { + return net + } + } + return nil +} + +func (u *user) run() { + defer func() { + if u.msgStore != nil { + if err := u.msgStore.Close(); err != nil { + u.logger.Printf("failed to close message store for user %q: %v", u.Username, err) + } + } + close(u.done) + }() + + networks, err := u.srv.db.ListNetworks(context.TODO(), u.ID) + if err != nil { + u.logger.Printf("failed to list networks for user %q: %v", u.Username, err) + return + } + + sort.Slice(networks, func(i, j int) bool { + return networks[i].ID < networks[j].ID + }) + + for _, record := range networks { + record := record + channels, err := u.srv.db.ListChannels(context.TODO(), record.ID) + if err != nil { + u.logger.Printf("failed to list channels for user %q, network %q: %v", u.Username, record.GetName(), err) + continue + } + + network := newNetwork(u, &record, channels) + u.networks = append(u.networks, network) + + if u.hasPersistentMsgStore() { + receipts, err := u.srv.db.ListDeliveryReceipts(context.TODO(), record.ID) + if err != nil { + u.logger.Printf("failed to load delivery receipts for user %q, network %q: %v", u.Username, network.GetName(), err) + return + } + + for _, rcpt := range receipts { + network.delivered.StoreID(rcpt.Target, rcpt.Client, rcpt.InternalMsgID) + } + } + + go network.run() + } + + for e := range u.events { + switch e := e.(type) { + case eventUpstreamConnected: + uc := e.uc + + uc.network.conn = uc + + uc.updateAway() + uc.updateMonitor() + + netIDStr := fmt.Sprintf("%v", uc.network.ID) + uc.forEachDownstream(func(dc *downstreamConn) { + dc.updateSupportedCaps() + + if !dc.caps["soju.im/bouncer-networks"] { + sendServiceNOTICE(dc, fmt.Sprintf("connected to %s", uc.network.GetName())) + } + + dc.updateNick() + dc.updateRealname() + dc.updateAccount() + }) + u.forEachDownstream(func(dc *downstreamConn) { + if dc.caps["soju.im/bouncer-networks-notify"] { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "BOUNCER", + Params: []string{"NETWORK", netIDStr, "state=connected"}, + }) + } + }) + uc.network.lastError = nil + case eventUpstreamDisconnected: + u.handleUpstreamDisconnected(e.uc) + case eventUpstreamConnectionError: + net := e.net + + stopped := false + select { + case <-net.stopped: + stopped = true + default: + } + + if !stopped && (net.lastError == nil || net.lastError.Error() != e.err.Error()) { + net.forEachDownstream(func(dc *downstreamConn) { + sendServiceNOTICE(dc, fmt.Sprintf("failed connecting/registering to %s: %v", net.GetName(), e.err)) + }) + } + net.lastError = e.err + case eventUpstreamError: + uc := e.uc + + uc.forEachDownstream(func(dc *downstreamConn) { + sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s: %v", uc.network.GetName(), e.err)) + }) + uc.network.lastError = e.err + case eventUpstreamMessage: + msg, uc := e.msg, e.uc + if uc.isClosed() { + uc.logger.Printf("ignoring message on closed connection: %v", msg) + break + } + if err := uc.handleMessage(context.TODO(), msg); err != nil { + uc.logger.Printf("failed to handle message %q: %v", msg, err) + } + case eventChannelDetach: + uc, name := e.uc, e.name + c := uc.network.channels.Value(name) + if c == nil || c.Detached { + continue + } + uc.network.detach(c) + if err := uc.srv.db.StoreChannel(context.TODO(), uc.network.ID, c); err != nil { + u.logger.Printf("failed to store updated detached channel %q: %v", c.Name, err) + } + case eventDownstreamConnected: + dc := e.dc + + if dc.network != nil { + dc.monitored.SetCasemapping(dc.network.casemap) + } + + if err := dc.welcome(context.TODO()); err != nil { + dc.logger.Printf("failed to handle new registered connection: %v", err) + break + } + + u.downstreamConns = append(u.downstreamConns, dc) + + dc.forEachNetwork(func(network *network) { + if network.lastError != nil { + sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s: %v", network.GetName(), network.lastError)) + } + }) + + u.forEachUpstream(func(uc *upstreamConn) { + uc.updateAway() + }) + case eventDownstreamDisconnected: + dc := e.dc + + for i := range u.downstreamConns { + if u.downstreamConns[i] == dc { + u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...) + break + } + } + + dc.forEachNetwork(func(net *network) { + net.storeClientDeliveryReceipts(context.TODO(), dc.clientName) + }) + + u.forEachUpstream(func(uc *upstreamConn) { + uc.cancelPendingCommandsByDownstreamID(dc.id) + uc.updateAway() + uc.updateMonitor() + }) + case eventDownstreamMessage: + msg, dc := e.msg, e.dc + if dc.isClosed() { + dc.logger.Printf("ignoring message on closed connection: %v", msg) + break + } + err := dc.handleMessage(context.TODO(), msg) + if ircErr, ok := err.(ircError); ok { + ircErr.Message.Prefix = dc.srv.prefix() + dc.SendMessage(ircErr.Message) + } else if err != nil { + dc.logger.Printf("failed to handle message %q: %v", msg, err) + dc.Close() + } + case eventBroadcast: + msg := e.msg + u.forEachDownstream(func(dc *downstreamConn) { + dc.SendMessage(msg) + }) + case eventUserUpdate: + // copy the user record because we'll mutate it + record := u.User + + if e.password != nil { + record.Password = *e.password + } + if e.admin != nil { + record.Admin = *e.admin + } + + e.done <- u.updateUser(context.TODO(), &record) + + // If the password was updated, kill all downstream connections to + // force them to re-authenticate with the new credentials. + if e.password != nil { + u.forEachDownstream(func(dc *downstreamConn) { + dc.Close() + }) + } + case eventStop: + u.forEachDownstream(func(dc *downstreamConn) { + dc.Close() + }) + for _, n := range u.networks { + n.stop() + + n.delivered.ForEachClient(func(clientName string) { + n.storeClientDeliveryReceipts(context.TODO(), clientName) + }) + } + return + default: + panic(fmt.Sprintf("received unknown event type: %T", e)) + } + } +} + +func (u *user) handleUpstreamDisconnected(uc *upstreamConn) { + uc.network.conn = nil + + uc.abortPendingCommands() + + for _, entry := range uc.channels.innerMap { + uch := entry.value.(*upstreamChannel) + uch.updateAutoDetach(0) + } + + netIDStr := fmt.Sprintf("%v", uc.network.ID) + uc.forEachDownstream(func(dc *downstreamConn) { + dc.updateSupportedCaps() + }) + + // If the network has been removed, don't send a state change notification + found := false + for _, net := range u.networks { + if net == uc.network { + found = true + break + } + } + if !found { + return + } + + u.forEachDownstream(func(dc *downstreamConn) { + if dc.caps["soju.im/bouncer-networks-notify"] { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "BOUNCER", + Params: []string{"NETWORK", netIDStr, "state=disconnected"}, + }) + } + }) + + if uc.network.lastError == nil { + uc.forEachDownstream(func(dc *downstreamConn) { + if !dc.caps["soju.im/bouncer-networks"] { + sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s", uc.network.GetName())) + } + }) + } +} + +func (u *user) addNetwork(network *network) { + u.networks = append(u.networks, network) + + sort.Slice(u.networks, func(i, j int) bool { + return u.networks[i].ID < u.networks[j].ID + }) + + go network.run() +} + +func (u *user) removeNetwork(network *network) { + network.stop() + + u.forEachDownstream(func(dc *downstreamConn) { + if dc.network != nil && dc.network == network { + dc.Close() + } + }) + + for i, net := range u.networks { + if net == network { + u.networks = append(u.networks[:i], u.networks[i+1:]...) + return + } + } + + panic("tried to remove a non-existing network") +} + +func (u *user) checkNetwork(record *Network) error { + url, err := record.URL() + if err != nil { + return err + } + if url.User != nil { + return fmt.Errorf("%v:// URL must not have username and password information", url.Scheme) + } + if url.RawQuery != "" { + return fmt.Errorf("%v:// URL must not have query values", url.Scheme) + } + if url.Fragment != "" { + return fmt.Errorf("%v:// URL must not have a fragment", url.Scheme) + } + switch url.Scheme { + case "ircs", "irc": + if url.Host == "" { + return fmt.Errorf("%v:// URL must have a host", url.Scheme) + } + if url.Path != "" { + return fmt.Errorf("%v:// URL must not have a path", url.Scheme) + } + case "irc+unix", "unix": + if url.Host != "" { + return fmt.Errorf("%v:// URL must not have a host", url.Scheme) + } + if url.Path == "" { + return fmt.Errorf("%v:// URL must have a path", url.Scheme) + } + default: + return fmt.Errorf("unknown URL scheme %q", url.Scheme) + } + + if record.GetName() == "" { + return fmt.Errorf("network name cannot be empty") + } + if strings.HasPrefix(record.GetName(), "-") { + // Can be mixed up with flags when sending commands to the service + return fmt.Errorf("network name cannot start with a dash character") + } + + for _, net := range u.networks { + if net.GetName() == record.GetName() && net.ID != record.ID { + return fmt.Errorf("a network with the name %q already exists", record.GetName()) + } + } + + return nil +} + +func (u *user) createNetwork(ctx context.Context, record *Network) (*network, error) { + if record.ID != 0 { + panic("tried creating an already-existing network") + } + + if err := u.checkNetwork(record); err != nil { + return nil, err + } + + if max := u.srv.Config().MaxUserNetworks; max >= 0 && len(u.networks) >= max { + return nil, fmt.Errorf("maximum number of networks reached") + } + + network := newNetwork(u, record, nil) + err := u.srv.db.StoreNetwork(ctx, u.ID, &network.Network) + if err != nil { + return nil, err + } + + u.addNetwork(network) + + idStr := fmt.Sprintf("%v", network.ID) + attrs := getNetworkAttrs(network) + u.forEachDownstream(func(dc *downstreamConn) { + if dc.caps["soju.im/bouncer-networks-notify"] { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "BOUNCER", + Params: []string{"NETWORK", idStr, attrs.String()}, + }) + } + }) + + return network, nil +} + +func (u *user) updateNetwork(ctx context.Context, record *Network) (*network, error) { + if record.ID == 0 { + panic("tried updating a new network") + } + + // If the realname is reset to the default, just wipe the per-network + // setting + if record.Realname == u.Realname { + record.Realname = "" + } + + if err := u.checkNetwork(record); err != nil { + return nil, err + } + + network := u.getNetworkByID(record.ID) + if network == nil { + panic("tried updating a non-existing network") + } + + if err := u.srv.db.StoreNetwork(ctx, u.ID, record); err != nil { + return nil, err + } + + // Most network changes require us to re-connect to the upstream server + + channels := make([]Channel, 0, network.channels.Len()) + for _, entry := range network.channels.innerMap { + ch := entry.value.(*Channel) + channels = append(channels, *ch) + } + + updatedNetwork := newNetwork(u, record, channels) + + // If we're currently connected, disconnect and perform the necessary + // bookkeeping + if network.conn != nil { + network.stop() + // Note: this will set network.conn to nil + u.handleUpstreamDisconnected(network.conn) + } + + // Patch downstream connections to use our fresh updated network + u.forEachDownstream(func(dc *downstreamConn) { + if dc.network != nil && dc.network == network { + dc.network = updatedNetwork + } + }) + + // We need to remove the network after patching downstream connections, + // otherwise they'll get closed + u.removeNetwork(network) + + // The filesystem message store needs to be notified whenever the network + // is renamed + fsMsgStore, isFS := u.msgStore.(*fsMessageStore) + if isFS && updatedNetwork.GetName() != network.GetName() { + if err := fsMsgStore.RenameNetwork(&network.Network, &updatedNetwork.Network); err != nil { + network.logger.Printf("failed to update FS message store network name to %q: %v", updatedNetwork.GetName(), err) + } + } + + // This will re-connect to the upstream server + u.addNetwork(updatedNetwork) + + // TODO: only broadcast attributes that have changed + idStr := fmt.Sprintf("%v", updatedNetwork.ID) + attrs := getNetworkAttrs(updatedNetwork) + u.forEachDownstream(func(dc *downstreamConn) { + if dc.caps["soju.im/bouncer-networks-notify"] { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "BOUNCER", + Params: []string{"NETWORK", idStr, attrs.String()}, + }) + } + }) + + return updatedNetwork, nil +} + +func (u *user) deleteNetwork(ctx context.Context, id int64) error { + network := u.getNetworkByID(id) + if network == nil { + panic("tried deleting a non-existing network") + } + + if err := u.srv.db.DeleteNetwork(ctx, network.ID); err != nil { + return err + } + + u.removeNetwork(network) + + idStr := fmt.Sprintf("%v", network.ID) + u.forEachDownstream(func(dc *downstreamConn) { + if dc.caps["soju.im/bouncer-networks-notify"] { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "BOUNCER", + Params: []string{"NETWORK", idStr, "*"}, + }) + } + }) + + return nil +} + +func (u *user) updateUser(ctx context.Context, record *User) error { + if u.ID != record.ID { + panic("ID mismatch when updating user") + } + + realnameUpdated := u.Realname != record.Realname + if err := u.srv.db.StoreUser(ctx, record); err != nil { + return fmt.Errorf("failed to update user %q: %v", u.Username, err) + } + u.User = *record + + if realnameUpdated { + // Re-connect to networks which use the default realname + var needUpdate []Network + for _, net := range u.networks { + if net.Realname == "" { + needUpdate = append(needUpdate, net.Network) + } + } + + var netErr error + for _, net := range needUpdate { + if _, err := u.updateNetwork(ctx, &net); err != nil { + netErr = err + } + } + if netErr != nil { + return netErr + } + } + + return nil +} + +func (u *user) stop() { + u.events <- eventStop{} + <-u.done +} + +func (u *user) hasPersistentMsgStore() bool { + if u.msgStore == nil { + return false + } + _, isMem := u.msgStore.(*memoryMessageStore) + return !isMem +} + +// localAddrForHost returns the local address to use when connecting to host. +// A nil address is returned when the OS should automatically pick one. +func (u *user) localTCPAddrForHost(ctx context.Context, host string) (*net.TCPAddr, error) { + upstreamUserIPs := u.srv.Config().UpstreamUserIPs + if len(upstreamUserIPs) == 0 { + return nil, nil + } + + ips, err := net.DefaultResolver.LookupIP(ctx, "ip", host) + if err != nil { + return nil, err + } + + wantIPv6 := false + for _, ip := range ips { + if ip.To4() == nil { + wantIPv6 = true + break + } + } + + var ipNet *net.IPNet + for _, in := range upstreamUserIPs { + if wantIPv6 == (in.IP.To4() == nil) { + ipNet = in + break + } + } + if ipNet == nil { + return nil, nil + } + + var ipInt big.Int + ipInt.SetBytes(ipNet.IP) + ipInt.Add(&ipInt, big.NewInt(u.ID+1)) + ip := net.IP(ipInt.Bytes()) + if !ipNet.Contains(ip) { + return nil, fmt.Errorf("IP network %v too small", ipNet) + } + + return &net.TCPAddr{IP: ip}, nil +} diff --git a/branches/origin/version.go b/branches/origin/version.go new file mode 100644 index 0000000..0e5d7a6 --- /dev/null +++ b/branches/origin/version.go @@ -0,0 +1,50 @@ +package suika + +import ( + "fmt" + "runtime/debug" + "strings" +) + +const ( + defaultVersion = "0.0.0" + defaultCommit = "HEAD" + defaultBuild = "0000-01-01:00:00+00:00" +) + +var ( + // Version is the tagged release version in the form .. + // following semantic versioning and is overwritten by the build system. + Version = defaultVersion + + // Commit is the commit sha of the build (normally from Git) and is overwritten + // by the build system. + Commit = defaultCommit + + // Build is the date and time of the build as an RFC3339 formatted string + // and is overwritten by the build system. + Build = defaultBuild +) + +// FullVersion display the full version and build +func FullVersion() string { + var sb strings.Builder + + isDefault := Version == defaultVersion && Commit == defaultCommit && Build == defaultBuild + + if !isDefault { + sb.WriteString(fmt.Sprintf("%s@%s %s", Version, Commit, Build)) + } + + if info, ok := debug.ReadBuildInfo(); ok { + if isDefault { + sb.WriteString(fmt.Sprintf(" %s", info.Main.Version)) + } + sb.WriteString(fmt.Sprintf(" %s", info.GoVersion)) + if info.Main.Sum != "" { + sb.WriteString(fmt.Sprintf(" %s", info.Main.Sum)) + } + } + + return sb.String() +} diff --git a/trunk/.gitignore b/trunk/.gitignore new file mode 100644 index 0000000..4aa22f5 --- /dev/null +++ b/trunk/.gitignore @@ -0,0 +1,5 @@ +vendor +/suika +/suikadb +/suika-znc-import +/suika.db diff --git a/trunk/LICENSE b/trunk/LICENSE new file mode 100644 index 0000000..0ad25db --- /dev/null +++ b/trunk/LICENSE @@ -0,0 +1,661 @@ + GNU AFFERO GENERAL PUBLIC LICENSE + Version 3, 19 November 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU Affero General Public License is a free, copyleft license for +software and other kinds of works, specifically designed to ensure +cooperation with the community in the case of network server software. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +our General Public Licenses are intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + Developers that use our General Public Licenses protect your rights +with two steps: (1) assert copyright on the software, and (2) offer +you this License which gives you legal permission to copy, distribute +and/or modify the software. + + A secondary benefit of defending all users' freedom is that +improvements made in alternate versions of the program, if they +receive widespread use, become available for other developers to +incorporate. Many developers of free software are heartened and +encouraged by the resulting cooperation. However, in the case of +software used on network servers, this result may fail to come about. +The GNU General Public License permits making a modified version and +letting the public access it on a server without ever releasing its +source code to the public. + + The GNU Affero General Public License is designed specifically to +ensure that, in such cases, the modified source code becomes available +to the community. It requires the operator of a network server to +provide the source code of the modified version running there to the +users of that server. Therefore, public use of a modified version, on +a publicly accessible server, gives the public access to the source +code of the modified version. + + An older license, called the Affero General Public License and +published by Affero, was designed to accomplish similar goals. This is +a different license, not a version of the Affero GPL, but Affero has +released a new version of the Affero GPL which permits relicensing under +this license. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU Affero General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. + + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Remote Network Interaction; Use with the GNU General Public License. + + Notwithstanding any other provision of this License, if you modify the +Program, your modified version must prominently offer all users +interacting with it remotely through a computer network (if your version +supports such interaction) an opportunity to receive the Corresponding +Source of your version by providing access to the Corresponding Source +from a network server at no charge, through some standard or customary +means of facilitating copying of software. This Corresponding Source +shall include the Corresponding Source for any work covered by version 3 +of the GNU General Public License that is incorporated pursuant to the +following paragraph. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the work with which it is combined will remain governed by version +3 of the GNU General Public License. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU Affero General Public License from time to time. Such new versions +will be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU Affero General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU Affero General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU Affero General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + + Copyright (C) + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published + by the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . + +Also add information on how to contact you by electronic and paper mail. + + If your software can interact with users remotely through a computer +network, you should also make sure that it provides a way for users to +get its source. For example, if your program is a web application, its +interface could display a "Source" link that leads users to an archive +of the code. There are many ways you could offer source, and different +solutions will be better for different programs; see section 13 for the +specific requirements. + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU AGPL, see +. diff --git a/trunk/Makefile b/trunk/Makefile new file mode 100644 index 0000000..a77bc97 --- /dev/null +++ b/trunk/Makefile @@ -0,0 +1,45 @@ +GO ?= go +RM ?= rm +GOFLAGS ?= -v -ldflags "-w -X `go list`.Version=${VERSION} -X `go list`.Commit=${COMMIT} -X `go list`.Build=${BUILD}" -mod=vendor +PREFIX ?= /usr/local +BINDIR ?= bin +MANDIR ?= share/man +MKDIR ?= mkdir +CP ?= cp +SYSCONFDIR ?= /etc +ASCIIDOCTOR ?= asciidoctor + +VERSION = `git describe --abbrev=0 --tags 2>/dev/null || echo "$VERSION"` +COMMIT = `git rev-parse --short HEAD || echo "$COMMIT"` +BRANCH = `git rev-parse --abbrev-ref HEAD` +BUILD = `git show -s --pretty=format:%cI` + +GOARCH ?= amd64 +GOOS ?= linux + +all: build + +build: vendor + ${GO} build ${GOFLAGS} ./cmd/suika + ${GO} build ${GOFLAGS} ./cmd/suikadb + ${GO} build ${GOFLAGS} ./cmd/suika-znc-import +clean: + ${RM} -f suika suikadb suika-znc-import +install: + ${MKDIR} -p ${DESTDIR}${PREFIX}/${BINDIR} + ${MKDIR} -p ${DESTDIR}${PREFIX}/${MANDIR}/man1 + ${MKDIR} -p ${DESTDIR}${PREFIX}/${MANDIR}/man5 + ${MKDIR} -p ${DESTDIR}${PREFIX}/${MANDIR}/man7 + ${MKDIR} -p ${DESTDIR}${SYSCONFDIR}/suika + ${MKDIR} -p ${DESTDIR}/var/lib/suika + ${CP} -f suika suikadb suika-znc-import ${DESTDIR}${PREFIX}/${BINDIR} + ${CP} -f doc/suika.1 ${DESTDIR}${PREFIX}/${MANDIR}/man1 + ${CP} -f doc/suikadb.1 ${DESTDIR}${PREFIX}/${MANDIR}/man1 + ${CP} -f doc/suika-znc-import.1 ${DESTDIR}/${MANDIR}/man1 + ${CP} -f doc/suika-config.5 ${DESTDIR}${PREFIX}/${MANDIR}/man5 + [ -f ${DESTDIR}${SYSCONFDIR}/suika/config ] || ${CP} -f config.in ${DESTDIR}${SYSCONFDIR}/suika/config +test: + go test +vendor: + go mod vendor +.PHONY: build clean install diff --git a/trunk/README.md b/trunk/README.md new file mode 100644 index 0000000..2171ba5 --- /dev/null +++ b/trunk/README.md @@ -0,0 +1,33 @@ +# suika + +[![Go Documentation](https://godocs.io/marisa.chaotic.ninja/suika?status.svg)](https://godocs.io/marisa.chaotic.ninja/suika) + +A user-friendly IRC bouncer. Hard-fork of the 0.3 series of [soju](https://soju.im), named after [Suika Ibuki](https://en.touhouwiki.net/wiki/Suika_Ibuki) from [Touhou 7.5: Immaterial and Missing Power](https://en.touhouwiki.net/wiki/Immaterial_and_Missing_Power) + +- Multi-user +- Support multiple clients for a single user, with proper backlog + synchronization +- Support connecting to multiple upstream servers via a single IRC connection + to the bouncer + +## Building and installing + +Dependencies: + +- Go +- BSD or GNU make + +For end users, a `Makefile` is provided: + + make + doas make install + +For development, you can use `go run ./cmd/suika` as usual. + +## License +AGPLv3, see [LICENSE](LICENSE). + +* Copyright (C) 2020 The soju Contributors +* Copyright (C) 2023-present Izuru Yakumo + +The code for `version.go` is stolen verbatim from one of [@prologic](https://git.mills.io/prologic)'s projects. It's probably under MIT diff --git a/trunk/bridge.go b/trunk/bridge.go new file mode 100644 index 0000000..7bb7aa8 --- /dev/null +++ b/trunk/bridge.go @@ -0,0 +1,116 @@ +package suika + +import ( + "context" + "fmt" + "strconv" + "strings" + + "gopkg.in/irc.v3" +) + +func forwardChannel(ctx context.Context, dc *downstreamConn, ch *upstreamChannel) { + if !ch.complete { + panic("Tried to forward a partial channel") + } + + // RPL_NOTOPIC shouldn't be sent on JOIN + if ch.Topic != "" { + sendTopic(dc, ch) + } + + if dc.caps["soju.im/read"] { + channelCM := ch.conn.network.casemap(ch.Name) + r, err := dc.srv.db.GetReadReceipt(ctx, ch.conn.network.ID, channelCM) + if err != nil { + dc.logger.Printf("failed to get the read receipt for %q: %v", ch.Name, err) + } else { + timestampStr := "*" + if r != nil { + timestampStr = fmt.Sprintf("timestamp=%s", formatServerTime(r.Timestamp)) + } + dc.SendMessage(&irc.Message{ + Prefix: dc.prefix(), + Command: "READ", + Params: []string{dc.marshalEntity(ch.conn.network, ch.Name), timestampStr}, + }) + } + } + + sendNames(dc, ch) +} + +func sendTopic(dc *downstreamConn, ch *upstreamChannel) { + downstreamName := dc.marshalEntity(ch.conn.network, ch.Name) + + if ch.Topic != "" { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_TOPIC, + Params: []string{dc.nick, downstreamName, ch.Topic}, + }) + if ch.TopicWho != nil { + topicWho := dc.marshalUserPrefix(ch.conn.network, ch.TopicWho) + topicTime := strconv.FormatInt(ch.TopicTime.Unix(), 10) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: rpl_topicwhotime, + Params: []string{dc.nick, downstreamName, topicWho.String(), topicTime}, + }) + } + } else { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_NOTOPIC, + Params: []string{dc.nick, downstreamName, "No topic is set"}, + }) + } +} + +func sendNames(dc *downstreamConn, ch *upstreamChannel) { + downstreamName := dc.marshalEntity(ch.conn.network, ch.Name) + + emptyNameReply := &irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_NAMREPLY, + Params: []string{dc.nick, string(ch.Status), downstreamName, ""}, + } + maxLength := maxMessageLength - len(emptyNameReply.String()) + + var buf strings.Builder + for _, entry := range ch.Members.innerMap { + nick := entry.originalKey + memberships := entry.value.(*memberships) + s := memberships.Format(dc) + dc.marshalEntity(ch.conn.network, nick) + + n := buf.Len() + 1 + len(s) + if buf.Len() != 0 && n > maxLength { + // There's not enough space for the next space + nick. + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_NAMREPLY, + Params: []string{dc.nick, string(ch.Status), downstreamName, buf.String()}, + }) + buf.Reset() + } + + if buf.Len() != 0 { + buf.WriteByte(' ') + } + buf.WriteString(s) + } + + if buf.Len() != 0 { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_NAMREPLY, + Params: []string{dc.nick, string(ch.Status), downstreamName, buf.String()}, + }) + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_ENDOFNAMES, + Params: []string{dc.nick, downstreamName, "End of /NAMES list"}, + }) +} diff --git a/trunk/certfp.go b/trunk/certfp.go new file mode 100644 index 0000000..9180f7c --- /dev/null +++ b/trunk/certfp.go @@ -0,0 +1,72 @@ +package suika + +import ( + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "math/big" + "time" +) + +func generateCertFP(keyType string, bits int) (privKeyBytes, certBytes []byte, err error) { + var ( + privKey crypto.PrivateKey + pubKey crypto.PublicKey + ) + switch keyType { + case "rsa": + key, err := rsa.GenerateKey(rand.Reader, bits) + if err != nil { + return nil, nil, err + } + privKey = key + pubKey = key.Public() + case "ecdsa": + key, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + if err != nil { + return nil, nil, err + } + privKey = key + pubKey = key.Public() + case "ed25519": + var err error + pubKey, privKey, err = ed25519.GenerateKey(rand.Reader) + if err != nil { + return nil, nil, err + } + } + + // Using PKCS#8 allows easier extension for new key types. + privKeyBytes, err = x509.MarshalPKCS8PrivateKey(privKey) + if err != nil { + return nil, nil, err + } + + notBefore := time.Now() + // Lets make a fair assumption nobody will use the same cert for more than 20 years... + notAfter := notBefore.Add(24 * time.Hour * 365 * 20) + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + return nil, nil, err + } + cert := &x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{CommonName: "suika auto-generated certificate"}, + NotBefore: notBefore, + NotAfter: notAfter, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + } + certBytes, err = x509.CreateCertificate(rand.Reader, cert, cert, pubKey, privKey) + if err != nil { + return nil, nil, err + } + + return privKeyBytes, certBytes, nil +} diff --git a/trunk/cmd/suika-znc-import/main.go b/trunk/cmd/suika-znc-import/main.go new file mode 100644 index 0000000..78fefc6 --- /dev/null +++ b/trunk/cmd/suika-znc-import/main.go @@ -0,0 +1,474 @@ +package main + +import ( + "bufio" + "context" + "flag" + "fmt" + "io" + "log" + "net/url" + "os" + "strings" + "unicode" + + "marisa.chaotic.ninja/suika" + "marisa.chaotic.ninja/suika/config" +) + +const usage = `usage: suika-znc-import [options...] + +Imports configuration from a ZNC file. Users and networks are merged if they +already exist in the suika database. ZNC settings overwrite existing suika +settings. + +Options: + + -help Show this help message + -config Path to suika config file + -user Limit import to username (may be specified multiple times) + -network Limit import to network (may be specified multiple times) +` + +func init() { + flag.Usage = func() { + fmt.Fprintf(flag.CommandLine.Output(), usage) + } +} + +func main() { + var configPath string + users := make(map[string]bool) + networks := make(map[string]bool) + flag.StringVar(&configPath, "config", "", "path to configuration file") + flag.Var((*stringSetFlag)(&users), "user", "") + flag.Var((*stringSetFlag)(&networks), "network", "") + flag.Parse() + + zncPath := flag.Arg(0) + if zncPath == "" { + flag.Usage() + os.Exit(1) + } + + var cfg *config.Server + if configPath != "" { + var err error + cfg, err = config.Load(configPath) + if err != nil { + log.Fatalf("failed to load config file: %v", err) + } + } else { + cfg = config.Defaults() + } + + ctx := context.Background() + + db, err := suika.OpenDB(cfg.SQLDriver, cfg.SQLSource) + if err != nil { + log.Fatalf("failed to open database: %v", err) + } + defer db.Close() + + f, err := os.Open(zncPath) + if err != nil { + log.Fatalf("failed to open ZNC configuration file: %v", err) + } + defer f.Close() + + zp := zncParser{bufio.NewReader(f), 1} + root, err := zp.sectionBody("", "") + if err != nil { + log.Fatalf("failed to parse %q: line %v: %v", zncPath, zp.line, err) + } + + l, err := db.ListUsers(ctx) + if err != nil { + log.Fatalf("failed to list users in DB: %v", err) + } + existingUsers := make(map[string]*suika.User, len(l)) + for i, u := range l { + existingUsers[u.Username] = &l[i] + } + + usersCreated := 0 + usersImported := 0 + networksImported := 0 + channelsImported := 0 + root.ForEach("User", func(section *zncSection) { + username := section.Name + if len(users) > 0 && !users[username] { + return + } + usersImported++ + + u, ok := existingUsers[username] + if ok { + log.Printf("user %q: updating existing user", username) + } else { + // "!!" is an invalid crypt format, thus disables password auth + u = &suika.User{Username: username, Password: "!!"} + usersCreated++ + log.Printf("user %q: creating new user", username) + } + + u.Admin = section.Values.Get("Admin") == "true" + + if err := db.StoreUser(ctx, u); err != nil { + log.Fatalf("failed to store user %q: %v", username, err) + } + userID := u.ID + + l, err := db.ListNetworks(ctx, userID) + if err != nil { + log.Fatalf("failed to list networks for user %q: %v", username, err) + } + existingNetworks := make(map[string]*suika.Network, len(l)) + for i, n := range l { + existingNetworks[n.GetName()] = &l[i] + } + + nick := section.Values.Get("Nick") + realname := section.Values.Get("RealName") + ident := section.Values.Get("Ident") + + section.ForEach("Network", func(section *zncSection) { + netName := section.Name + if len(networks) > 0 && !networks[netName] { + return + } + networksImported++ + + logPrefix := fmt.Sprintf("user %q: network %q: ", username, netName) + logger := log.New(os.Stderr, logPrefix, log.LstdFlags|log.Lmsgprefix) + + netNick := section.Values.Get("Nick") + if netNick == "" { + netNick = nick + } + netRealname := section.Values.Get("RealName") + if netRealname == "" { + netRealname = realname + } + netIdent := section.Values.Get("Ident") + if netIdent == "" { + netIdent = ident + } + + for _, name := range section.Values["LoadModule"] { + switch name { + case "sasl": + logger.Printf("warning: SASL credentials not imported") + case "nickserv": + logger.Printf("warning: NickServ credentials not imported") + case "perform": + logger.Printf("warning: \"perform\" plugin commands not imported") + } + } + + u, pass, err := importNetworkServer(section.Values.Get("Server")) + if err != nil { + logger.Fatalf("failed to import server %q: %v", section.Values.Get("Server"), err) + } + + n, ok := existingNetworks[netName] + if ok { + logger.Printf("updating existing network") + } else { + n = &suika.Network{Name: netName} + logger.Printf("creating new network") + } + + n.Addr = u.String() + n.Nick = netNick + n.Username = netIdent + n.Realname = netRealname + n.Pass = pass + n.Enabled = section.Values.Get("IRCConnectEnabled") != "false" + + if err := db.StoreNetwork(ctx, userID, n); err != nil { + logger.Fatalf("failed to store network: %v", err) + } + + l, err := db.ListChannels(ctx, n.ID) + if err != nil { + logger.Fatalf("failed to list channels: %v", err) + } + existingChannels := make(map[string]*suika.Channel, len(l)) + for i, ch := range l { + existingChannels[ch.Name] = &l[i] + } + + section.ForEach("Chan", func(section *zncSection) { + chName := section.Name + + if section.Values.Get("Disabled") == "true" { + logger.Printf("skipping import of disabled channel %q", chName) + return + } + + channelsImported++ + + ch, ok := existingChannels[chName] + if ok { + logger.Printf("channel %q: updating existing channel", chName) + } else { + ch = &suika.Channel{Name: chName} + logger.Printf("channel %q: creating new channel", chName) + } + + ch.Key = section.Values.Get("Key") + ch.Detached = section.Values.Get("Detached") == "true" + + if err := db.StoreChannel(ctx, n.ID, ch); err != nil { + logger.Printf("channel %q: failed to store channel: %v", chName, err) + } + }) + }) + }) + + if err := db.Close(); err != nil { + log.Printf("failed to close database: %v", err) + } + + if usersCreated > 0 { + log.Printf("warning: user passwords haven't been imported, please set them with `suikactl change-password `") + } + + log.Printf("imported %v users, %v networks and %v channels", usersImported, networksImported, channelsImported) +} + +func importNetworkServer(s string) (u *url.URL, pass string, err error) { + parts := strings.Fields(s) + if len(parts) < 2 { + return nil, "", fmt.Errorf("expected space-separated host and port") + } + + scheme := "irc" + host := parts[0] + port := parts[1] + if strings.HasPrefix(port, "+") { + port = port[1:] + scheme = "ircs" + } + + if len(parts) > 2 { + pass = parts[2] + } + + u = &url.URL{ + Scheme: scheme, + Host: host + ":" + port, + } + return u, pass, nil +} + +type zncSection struct { + Type string + Name string + Values zncValues + Children []zncSection +} + +func (s *zncSection) ForEach(typ string, f func(*zncSection)) { + for _, section := range s.Children { + if section.Type == typ { + f(§ion) + } + } +} + +type zncValues map[string][]string + +func (zv zncValues) Get(k string) string { + if len(zv[k]) == 0 { + return "" + } + return zv[k][0] +} + +type zncParser struct { + br *bufio.Reader + line int +} + +func (zp *zncParser) readByte() (byte, error) { + b, err := zp.br.ReadByte() + if b == '\n' { + zp.line++ + } + return b, err +} + +func (zp *zncParser) readRune() (rune, int, error) { + r, n, err := zp.br.ReadRune() + if r == '\n' { + zp.line++ + } + return r, n, err +} + +func (zp *zncParser) sectionBody(typ, name string) (*zncSection, error) { + section := &zncSection{Type: typ, Name: name, Values: make(zncValues)} + +Loop: + for { + if err := zp.skipSpace(); err != nil { + return nil, err + } + + b, err := zp.br.Peek(2) + if err == io.EOF { + break + } else if err != nil { + return nil, err + } + + switch b[0] { + case '<': + if b[1] == '/' { + break Loop + } else { + childType, childName, err := zp.sectionHeader() + if err != nil { + return nil, err + } + child, err := zp.sectionBody(childType, childName) + if err != nil { + return nil, err + } + if footerType, err := zp.sectionFooter(); err != nil { + return nil, err + } else if footerType != childType { + return nil, fmt.Errorf("invalid section footer: expected type %q, got %q", childType, footerType) + } + section.Children = append(section.Children, *child) + } + case '/': + if b[1] == '/' { + if err := zp.skipComment(); err != nil { + return nil, err + } + break + } + fallthrough + default: + k, v, err := zp.keyValuePair() + if err != nil { + return nil, err + } + section.Values[k] = append(section.Values[k], v) + } + } + + return section, nil +} + +func (zp *zncParser) skipSpace() error { + for { + r, _, err := zp.readRune() + if err == io.EOF { + return nil + } else if err != nil { + return err + } + + if !unicode.IsSpace(r) { + zp.br.UnreadRune() + return nil + } + } +} + +func (zp *zncParser) skipComment() error { + if err := zp.expectRune('/'); err != nil { + return err + } + if err := zp.expectRune('/'); err != nil { + return err + } + + for { + b, err := zp.readByte() + if err == io.EOF { + return nil + } else if err != nil { + return err + } + + if b == '\n' { + return nil + } + } +} + +func (zp *zncParser) sectionHeader() (string, string, error) { + if err := zp.expectRune('<'); err != nil { + return "", "", err + } + typ, err := zp.readWord(' ') + if err != nil { + return "", "", err + } + name, err := zp.readWord('>') + return typ, name, err +} + +func (zp *zncParser) sectionFooter() (string, error) { + if err := zp.expectRune('<'); err != nil { + return "", err + } + if err := zp.expectRune('/'); err != nil { + return "", err + } + return zp.readWord('>') +} + +func (zp *zncParser) keyValuePair() (string, string, error) { + k, err := zp.readWord('=') + if err != nil { + return "", "", err + } + v, err := zp.readWord('\n') + return strings.TrimSpace(k), strings.TrimSpace(v), err +} + +func (zp *zncParser) expectRune(expected rune) error { + r, _, err := zp.readRune() + if err != nil { + return err + } else if r != expected { + return fmt.Errorf("expected %q, got %q", expected, r) + } + return nil +} + +func (zp *zncParser) readWord(delim byte) (string, error) { + var sb strings.Builder + for { + b, err := zp.readByte() + if err != nil { + return "", err + } + + if b == delim { + return sb.String(), nil + } + if b == '\n' { + return "", fmt.Errorf("expected %q before newline", delim) + } + + sb.WriteByte(b) + } +} + +type stringSetFlag map[string]bool + +func (v *stringSetFlag) String() string { + return fmt.Sprint(map[string]bool(*v)) +} + +func (v *stringSetFlag) Set(s string) error { + (*v)[s] = true + return nil +} diff --git a/trunk/cmd/suika/main.go b/trunk/cmd/suika/main.go new file mode 100644 index 0000000..9661859 --- /dev/null +++ b/trunk/cmd/suika/main.go @@ -0,0 +1,229 @@ +package main + +import ( + "context" + "crypto/tls" + "flag" + "fmt" + "log" + "net" + "net/url" + "os" + "os/signal" + "strings" + "sync/atomic" + "syscall" + "time" + + "marisa.chaotic.ninja/suika" + "marisa.chaotic.ninja/suika/config" +) + +// TCP keep-alive interval for downstream TCP connections +const downstreamKeepAlive = 1 * time.Hour + +type stringSliceFlag []string + +func (v *stringSliceFlag) String() string { + return fmt.Sprint([]string(*v)) +} + +func (v *stringSliceFlag) Set(s string) error { + *v = append(*v, s) + return nil +} + +func bumpOpenedFileLimit() error { + var rlimit syscall.Rlimit + if err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rlimit); err != nil { + return fmt.Errorf("failed to get RLIMIT_NOFILE: %v", err) + } + rlimit.Cur = rlimit.Max + if err := syscall.Setrlimit(syscall.RLIMIT_NOFILE, &rlimit); err != nil { + return fmt.Errorf("failed to set RLIMIT_NOFILE: %v", err) + } + return nil +} + +var ( + configPath string + debug bool + + tlsCert atomic.Value // *tls.Certificate +) + +func loadConfig() (*config.Server, *suika.Config, error) { + var raw *config.Server + if configPath != "" { + var err error + raw, err = config.Load(configPath) + if err != nil { + return nil, nil, fmt.Errorf("failed to load config file: %v", err) + } + } else { + raw = config.Defaults() + } + + var motd string + if raw.MOTDPath != "" { + b, err := os.ReadFile(raw.MOTDPath) + if err != nil { + return nil, nil, fmt.Errorf("failed to load MOTD: %v", err) + } + motd = strings.TrimSuffix(string(b), "\n") + } + + if raw.TLS != nil { + cert, err := tls.LoadX509KeyPair(raw.TLS.CertPath, raw.TLS.KeyPath) + if err != nil { + return nil, nil, fmt.Errorf("failed to load TLS certificate and key: %v", err) + } + tlsCert.Store(&cert) + } + + cfg := &suika.Config{ + Hostname: raw.Hostname, + Title: raw.Title, + LogPath: raw.LogPath, + MaxUserNetworks: raw.MaxUserNetworks, + MultiUpstream: raw.MultiUpstream, + UpstreamUserIPs: raw.UpstreamUserIPs, + MOTD: motd, + } + return raw, cfg, nil +} + +func main() { + var listen []string + flag.Var((*stringSliceFlag)(&listen), "listen", "listening address") + flag.StringVar(&configPath, "config", "", "path to configuration file") + flag.BoolVar(&debug, "debug", false, "enable debug logging") + flag.Parse() + + cfg, serverCfg, err := loadConfig() + if err != nil { + log.Fatal(err) + } + + cfg.Listen = append(cfg.Listen, listen...) + if len(cfg.Listen) == 0 { + cfg.Listen = []string{":6667"} + } + + if err := bumpOpenedFileLimit(); err != nil { + log.Printf("failed to bump max number of opened files: %v", err) + } + + db, err := suika.OpenDB(cfg.SQLDriver, cfg.SQLSource) + if err != nil { + log.Fatalf("failed to open database: %v", err) + } + + var tlsCfg *tls.Config + if cfg.TLS != nil { + tlsCfg = &tls.Config{ + GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) { + return tlsCert.Load().(*tls.Certificate), nil + }, + } + } + + srv := suika.NewServer(db) + srv.SetConfig(serverCfg) + srv.Logger = suika.NewLogger(log.Writer(), debug) + + for _, listen := range cfg.Listen { + listen := listen // copy + listenURI := listen + if !strings.Contains(listenURI, ":/") { + // This is a raw domain name, make it an URL with an empty scheme + listenURI = "//" + listenURI + } + u, err := url.Parse(listenURI) + if err != nil { + log.Fatalf("failed to parse listen URI %q: %v", listen, err) + } + + switch u.Scheme { + case "ircs", "": + if tlsCfg == nil { + log.Fatalf("failed to listen on %q: missing TLS configuration", listen) + } + host := u.Host + if _, _, err := net.SplitHostPort(host); err != nil { + host = host + ":6697" + } + ircsTLSCfg := tlsCfg.Clone() + ircsTLSCfg.NextProtos = []string{"irc"} + lc := net.ListenConfig{ + KeepAlive: downstreamKeepAlive, + } + l, err := lc.Listen(context.Background(), "tcp", host) + if err != nil { + log.Fatalf("failed to start TLS listener on %q: %v", listen, err) + } + ln := tls.NewListener(l, ircsTLSCfg) + go func() { + if err := srv.Serve(ln); err != nil { + log.Printf("serving %q: %v", listen, err) + } + }() + case "irc": + host := u.Host + if _, _, err := net.SplitHostPort(host); err != nil { + host = host + ":6667" + } + lc := net.ListenConfig{ + KeepAlive: downstreamKeepAlive, + } + ln, err := lc.Listen(context.Background(), "tcp", host) + if err != nil { + log.Fatalf("failed to start listener on %q: %v", listen, err) + } + go func() { + if err := srv.Serve(ln); err != nil { + log.Printf("serving %q: %v", listen, err) + } + }() + case "unix": + ln, err := net.Listen("unix", u.Path) + if err != nil { + log.Fatalf("failed to start listener on %q: %v", listen, err) + } + go func() { + if err := srv.Serve(ln); err != nil { + log.Printf("serving %q: %v", listen, err) + } + }() + default: + log.Fatalf("failed to listen on %q: unsupported scheme", listen) + } + + log.Printf("starting suika version %v\n", suika.FullVersion()) + log.Printf("server listening on %q", listen) + } + + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP) + + if err := srv.Start(); err != nil { + log.Fatal(err) + } + + for sig := range sigCh { + switch sig { + case syscall.SIGHUP: + log.Print("reloading configuration") + _, serverCfg, err := loadConfig() + if err != nil { + log.Printf("failed to reloading configuration: %v", err) + } else { + srv.SetConfig(serverCfg) + } + case syscall.SIGINT, syscall.SIGTERM: + log.Print("shutting down server") + srv.Shutdown() + return + } + } +} diff --git a/trunk/cmd/suikadb/main.go b/trunk/cmd/suikadb/main.go new file mode 100644 index 0000000..4d54481 --- /dev/null +++ b/trunk/cmd/suikadb/main.go @@ -0,0 +1,150 @@ +package main + +import ( + "bufio" + "context" + "flag" + "fmt" + "io" + "log" + "os" + + "marisa.chaotic.ninja/suika" + "marisa.chaotic.ninja/suika/config" + "golang.org/x/crypto/bcrypt" + "golang.org/x/term" +) + +const usage = `usage: suikadb [-config path] [options...] + + create-user [-admin] Create a new user + change-password Change password for a user + help Show this help message +` + +func init() { + flag.Usage = func() { + fmt.Fprintf(flag.CommandLine.Output(), usage) + } +} + +func main() { + var configPath string + flag.StringVar(&configPath, "config", "", "path to configuration file") + flag.Parse() + + var cfg *config.Server + if configPath != "" { + var err error + cfg, err = config.Load(configPath) + if err != nil { + log.Fatalf("failed to load config file: %v", err) + } + } else { + cfg = config.Defaults() + } + + db, err := suika.OpenDB(cfg.SQLDriver, cfg.SQLSource) + if err != nil { + log.Fatalf("failed to open database: %v", err) + } + + ctx := context.Background() + + switch cmd := flag.Arg(0); cmd { + case "create-user": + username := flag.Arg(1) + if username == "" { + flag.Usage() + os.Exit(1) + } + + fs := flag.NewFlagSet("", flag.ExitOnError) + admin := fs.Bool("admin", false, "make the new user admin") + fs.Parse(flag.Args()[2:]) + + password, err := readPassword() + if err != nil { + log.Fatalf("failed to read password: %v", err) + } + + hashed, err := bcrypt.GenerateFromPassword(password, bcrypt.DefaultCost) + if err != nil { + log.Fatalf("failed to hash password: %v", err) + } + + user := suika.User{ + Username: username, + Password: string(hashed), + Admin: *admin, + } + if err := db.StoreUser(ctx, &user); err != nil { + log.Fatalf("failed to create user: %v", err) + } + case "change-password": + username := flag.Arg(1) + if username == "" { + flag.Usage() + os.Exit(1) + } + + user, err := db.GetUser(ctx, username) + if err != nil { + log.Fatalf("failed to get user: %v", err) + } + + password, err := readPassword() + if err != nil { + log.Fatalf("failed to read password: %v", err) + } + + hashed, err := bcrypt.GenerateFromPassword(password, bcrypt.DefaultCost) + if err != nil { + log.Fatalf("failed to hash password: %v", err) + } + + user.Password = string(hashed) + if err := db.StoreUser(ctx, user); err != nil { + log.Fatalf("failed to update password: %v", err) + } + case "version": + fmt.Printf("%v\n", suika.FullVersion()) + default: + flag.Usage() + if cmd != "help" { + os.Exit(1) + } + } +} + +func readPassword() ([]byte, error) { + var password []byte + var err error + fd := int(os.Stdin.Fd()) + + if term.IsTerminal(fd) { + fmt.Printf("Password: ") + password, err = term.ReadPassword(int(os.Stdin.Fd())) + if err != nil { + return nil, err + } + fmt.Printf("\n") + } else { + fmt.Fprintf(os.Stderr, "Warning: Reading password from stdin.\n") + // TODO: the buffering messes up repeated calls to readPassword + scanner := bufio.NewScanner(os.Stdin) + if !scanner.Scan() { + if err := scanner.Err(); err != nil { + return nil, err + } + return nil, io.ErrUnexpectedEOF + } + password = scanner.Bytes() + + if len(password) == 0 { + return nil, fmt.Errorf("zero length password") + } + } + + return password, nil +} diff --git a/trunk/config.in b/trunk/config.in new file mode 100644 index 0000000..0b1ed4d --- /dev/null +++ b/trunk/config.in @@ -0,0 +1,2 @@ +db sqlite3 /var/lib/suika/main.db +log fs /var/lib/suika/logs/ diff --git a/trunk/config/config.go b/trunk/config/config.go new file mode 100644 index 0000000..94d4584 --- /dev/null +++ b/trunk/config/config.go @@ -0,0 +1,142 @@ +package config + +import ( + "fmt" + "net" + "os" + "strconv" + + "git.sr.ht/~emersion/go-scfg" +) + +type TLS struct { + CertPath, KeyPath string +} + +type Server struct { + Listen []string + TLS *TLS + Hostname string + Title string + MOTDPath string + + SQLDriver string + SQLSource string + LogPath string + + MaxUserNetworks int + MultiUpstream bool + UpstreamUserIPs []*net.IPNet +} + +func Defaults() *Server { + hostname, err := os.Hostname() + if err != nil { + hostname = "localhost" + } + return &Server{ + Hostname: hostname, + SQLDriver: "sqlite3", + SQLSource: "suika.db", + MaxUserNetworks: -1, + MultiUpstream: true, + } +} + +func Load(path string) (*Server, error) { + cfg, err := scfg.Load(path) + if err != nil { + return nil, err + } + return parse(cfg) +} + +func parse(cfg scfg.Block) (*Server, error) { + srv := Defaults() + for _, d := range cfg { + switch d.Name { + case "listen": + var uri string + if err := d.ParseParams(&uri); err != nil { + return nil, err + } + srv.Listen = append(srv.Listen, uri) + case "hostname": + if err := d.ParseParams(&srv.Hostname); err != nil { + return nil, err + } + case "title": + if err := d.ParseParams(&srv.Title); err != nil { + return nil, err + } + case "motd": + if err := d.ParseParams(&srv.MOTDPath); err != nil { + return nil, err + } + case "tls": + tls := &TLS{} + if err := d.ParseParams(&tls.CertPath, &tls.KeyPath); err != nil { + return nil, err + } + srv.TLS = tls + case "db": + if err := d.ParseParams(&srv.SQLDriver, &srv.SQLSource); err != nil { + return nil, err + } + case "log": + var driver string + if err := d.ParseParams(&driver, &srv.LogPath); err != nil { + return nil, err + } + if driver != "fs" { + return nil, fmt.Errorf("directive %q: unknown driver %q", d.Name, driver) + } + case "max-user-networks": + var max string + if err := d.ParseParams(&max); err != nil { + return nil, err + } + var err error + if srv.MaxUserNetworks, err = strconv.Atoi(max); err != nil { + return nil, fmt.Errorf("directive %q: %v", d.Name, err) + } + case "multi-upstream-mode": + var str string + if err := d.ParseParams(&str); err != nil { + return nil, err + } + v, err := strconv.ParseBool(str) + if err != nil { + return nil, fmt.Errorf("directive %q: %v", d.Name, err) + } + srv.MultiUpstream = v + case "upstream-user-ip": + if len(srv.UpstreamUserIPs) > 0 { + return nil, fmt.Errorf("directive %q: can only be specified once", d.Name) + } + var hasIPv4, hasIPv6 bool + for _, s := range d.Params { + _, n, err := net.ParseCIDR(s) + if err != nil { + return nil, fmt.Errorf("directive %q: failed to parse CIDR: %v", d.Name, err) + } + if n.IP.To4() == nil { + if hasIPv6 { + return nil, fmt.Errorf("directive %q: found two IPv6 CIDRs", d.Name) + } + hasIPv6 = true + } else { + if hasIPv4 { + return nil, fmt.Errorf("directive %q: found two IPv4 CIDRs", d.Name) + } + hasIPv4 = true + } + srv.UpstreamUserIPs = append(srv.UpstreamUserIPs, n) + } + default: + return nil, fmt.Errorf("unknown directive %q", d.Name) + } + } + + return srv, nil +} diff --git a/trunk/conn.go b/trunk/conn.go new file mode 100644 index 0000000..bb95e25 --- /dev/null +++ b/trunk/conn.go @@ -0,0 +1,180 @@ +package suika + +import ( + "context" + "fmt" + "io" + "net" + "sync" + "time" + + "golang.org/x/time/rate" + "gopkg.in/irc.v3" +) + +// ircConn is a generic IRC connection. It's similar to net.Conn but focuses on +// reading and writing IRC messages. +type ircConn interface { + ReadMessage() (*irc.Message, error) + WriteMessage(*irc.Message) error + Close() error + SetReadDeadline(time.Time) error + SetWriteDeadline(time.Time) error + RemoteAddr() net.Addr + LocalAddr() net.Addr +} + +func newNetIRCConn(c net.Conn) ircConn { + type netConn net.Conn + return struct { + *irc.Conn + netConn + }{irc.NewConn(c), c} +} + +type connOptions struct { + Logger Logger + RateLimitDelay time.Duration + RateLimitBurst int +} + +type conn struct { + conn ircConn + srv *Server + logger Logger + + lock sync.Mutex + outgoing chan<- *irc.Message + closed bool + closedCh chan struct{} +} + +func newConn(srv *Server, ic ircConn, options *connOptions) *conn { + outgoing := make(chan *irc.Message, 64) + c := &conn{ + conn: ic, + srv: srv, + outgoing: outgoing, + logger: options.Logger, + closedCh: make(chan struct{}), + } + + go func() { + ctx, cancel := c.NewContext(context.Background()) + defer cancel() + + rl := rate.NewLimiter(rate.Every(options.RateLimitDelay), options.RateLimitBurst) + for msg := range outgoing { + if err := rl.Wait(ctx); err != nil { + break + } + + c.logger.Debugf("sent: %v", msg) + c.conn.SetWriteDeadline(time.Now().Add(writeTimeout)) + if err := c.conn.WriteMessage(msg); err != nil { + c.logger.Printf("failed to write message: %v", err) + break + } + } + if err := c.conn.Close(); err != nil && !isErrClosed(err) { + c.logger.Printf("failed to close connection: %v", err) + } else { + c.logger.Debugf("connection closed") + } + // Drain the outgoing channel to prevent SendMessage from blocking + for range outgoing { + // This space is intentionally left blank + } + }() + + c.logger.Debugf("new connection") + return c +} + +func (c *conn) isClosed() bool { + c.lock.Lock() + defer c.lock.Unlock() + return c.closed +} + +// Close closes the connection. It is safe to call from any goroutine. +func (c *conn) Close() error { + c.lock.Lock() + defer c.lock.Unlock() + + if c.closed { + return fmt.Errorf("connection already closed") + } + + err := c.conn.Close() + c.closed = true + close(c.outgoing) + close(c.closedCh) + return err +} + +func (c *conn) ReadMessage() (*irc.Message, error) { + msg, err := c.conn.ReadMessage() + if isErrClosed(err) { + return nil, io.EOF + } else if err != nil { + return nil, err + } + + c.logger.Debugf("received: %v", msg) + return msg, nil +} + +// SendMessage queues a new outgoing message. It is safe to call from any +// goroutine. +// +// If the connection is closed before the message is sent, SendMessage silently +// drops the message. +func (c *conn) SendMessage(ctx context.Context, msg *irc.Message) { + c.lock.Lock() + defer c.lock.Unlock() + + if c.closed { + return + } + + select { + case c.outgoing <- msg: + // Success + case <-ctx.Done(): + c.logger.Printf("failed to send message: %v", ctx.Err()) + } +} + +func (c *conn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +func (c *conn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +// NewContext returns a copy of the parent context with a new Done channel. The +// returned context's Done channel is closed when the connection is closed, +// when the returned cancel function is called, or when the parent context's +// Done channel is closed, whichever happens first. +// +// Canceling this context releases resources associated with it, so code should +// call cancel as soon as the operations running in this Context complete. +func (c *conn) NewContext(parent context.Context) (context.Context, context.CancelFunc) { + ctx, cancel := context.WithCancel(parent) + + go func() { + defer cancel() + + select { + case <-ctx.Done(): + // The parent context has been cancelled, or the caller has called + // cancel() + case <-c.closedCh: + // The connection has been closed + } + }() + + return ctx, cancel +} diff --git a/trunk/contrib/casemap-logs.sh b/trunk/contrib/casemap-logs.sh new file mode 100644 index 0000000..709d581 --- /dev/null +++ b/trunk/contrib/casemap-logs.sh @@ -0,0 +1,42 @@ +#!/bin/sh -eu + +# Converts a log dir to its case-mapped form. +# +# suika needs to be stopped for this script to work properly. The script may +# re-order messages that happened within the same second interval if merging +# two daily log files is necessary. +# +# usage: casemap-logs.sh + +root="$1" + +for net_dir in "$root"/*/*; do + for chan in $(ls "$net_dir"); do + cm_chan="$(echo $chan | tr '[:upper:]' '[:lower:]')" + if [ "$chan" = "$cm_chan" ]; then + continue + fi + + if ! [ -d "$net_dir/$cm_chan" ]; then + echo >&2 "Moving case-mapped channel dir: '$net_dir/$chan' -> '$cm_chan'" + mv "$net_dir/$chan" "$net_dir/$cm_chan" + continue + fi + + echo "Merging case-mapped channel dir: '$net_dir/$chan' -> '$cm_chan'" + for day in $(ls "$net_dir/$chan"); do + if ! [ -e "$net_dir/$cm_chan/$day" ]; then + echo >&2 " Moving log file: '$day'" + mv "$net_dir/$chan/$day" "$net_dir/$cm_chan/$day" + continue + fi + + echo >&2 " Merging log file: '$day'" + sort "$net_dir/$chan/$day" "$net_dir/$cm_chan/$day" >"$net_dir/$cm_chan/$day.new" + mv "$net_dir/$cm_chan/$day.new" "$net_dir/$cm_chan/$day" + rm "$net_dir/$chan/$day" + done + + rmdir "$net_dir/$chan" + done +done diff --git a/trunk/contrib/clients.md b/trunk/contrib/clients.md new file mode 100644 index 0000000..1f8881d --- /dev/null +++ b/trunk/contrib/clients.md @@ -0,0 +1,100 @@ +# Clients + +This page describes how to configure IRC clients to better integrate with soju. + +Also see the [IRCv3 support tables] for a more general list of clients. + +# catgirl + +catgirl doesn't properly implement cap-3.2, so many capabilities will be +disabled. catgirl developers have publicly stated that supporting bouncers such +as soju is a non-goal. + +# [Emacs] + +There are two clients provided with Emacs. They require some setup to work +properly. + +## Erc + +You need to explicitly set the username, which is the defcustom +`erc-email-userid`. + +```elisp +(setq erc-email-userid "/irc.libera.chat") ;; Example with Libera.Chat +(defun run-erc () + (interactive) + (erc-tls :server "" + :port 6697 + :nick "" + :password "")) +``` + +Then run `M-x run-erc`. + +## Rcirc + +The only thing needed here is the general config: + +```elisp +(setq rcirc-server-alist + '(("" + :port 6697 + :encryption tls + :nick "" + :user-name "/irc.libera.chat" ;; Example with Libera.Chat + :password ""))) +``` + +Then run `M-x irc`. + +# [gamja] + +gamja has been designed together with soju, so should have excellent +integration. gamja supports many IRCv3 features including chat history. +gamja also provides UI to manage soju networks via the +`soju.im/bouncer-networks` extension. + +# [goguma] + +Much like gamja, goguma has been designed together with soju, so should have +excellent integration. goguma supports many IRCv3 features including chat +history. goguma should seamlessly connect to all networks configured in soju via +the `soju.im/bouncer-networks` extension. + +# [Hexchat] + +Hexchat has support for a small set of IRCv3 capabilities. To prevent +automatically reconnecting to channels parted from soju, and prevent buffering +outgoing messages: + + /set irc_reconnect_rejoin off + /set net_throttle off + +# [senpai] + +senpai is being developed with soju in mind, so should have excellent +integration. senpai supports many IRCv3 features including chat history. + +# [Weechat] + +A [Weechat script] is available to provide better integration with soju. +The script will automatically connect to all of your networks once a +single connection to soju is set up in Weechat. + +On WeeChat 3.2-, no IRCv3 capabilities are enabled by default. To enable them: + + /set irc.server_default.capabilities account-notify,away-notify,cap-notify,chghost,extended-join,invite-notify,multi-prefix,server-time,userhost-in-names + /save + /reconnect -all + +See `/help cap` for more information. + +[IRCv3 support tables]: https://ircv3.net/software/clients +[gamja]: https://sr.ht/~emersion/gamja/ +[goguma]: https://sr.ht/~emersion/goguma/ +[senpai]: https://sr.ht/~taiite/senpai/ +[Weechat]: https://weechat.org/ +[Weechat script]: https://github.com/weechat/scripts/blob/master/python/soju.py +[Hexchat]: https://hexchat.github.io/ +[Emacs]: https://www.gnu.org/software/emacs/ diff --git a/trunk/db.go b/trunk/db.go new file mode 100644 index 0000000..95ac964 --- /dev/null +++ b/trunk/db.go @@ -0,0 +1,184 @@ +package suika + +import ( + "context" + "fmt" + "net/url" + "strings" + "time" +) + +type Database interface { + Close() error + Stats(ctx context.Context) (*DatabaseStats, error) + + ListUsers(ctx context.Context) ([]User, error) + GetUser(ctx context.Context, username string) (*User, error) + StoreUser(ctx context.Context, user *User) error + DeleteUser(ctx context.Context, id int64) error + + ListNetworks(ctx context.Context, userID int64) ([]Network, error) + StoreNetwork(ctx context.Context, userID int64, network *Network) error + DeleteNetwork(ctx context.Context, id int64) error + ListChannels(ctx context.Context, networkID int64) ([]Channel, error) + StoreChannel(ctx context.Context, networKID int64, ch *Channel) error + DeleteChannel(ctx context.Context, id int64) error + + ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error) + StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error + + GetReadReceipt(ctx context.Context, networkID int64, name string) (*ReadReceipt, error) + StoreReadReceipt(ctx context.Context, networkID int64, receipt *ReadReceipt) error +} + +func OpenDB(driver, source string) (Database, error) { + switch driver { + case "sqlite3": + return OpenSqliteDB(source) + case "postgres": + return OpenPostgresDB(source) + default: + return nil, fmt.Errorf("unsupported database driver: %q", driver) + } +} + +type DatabaseStats struct { + Users int64 + Networks int64 + Channels int64 +} + +type User struct { + ID int64 + Username string + Password string // hashed + Realname string + Admin bool +} + +type SASL struct { + Mechanism string + + Plain struct { + Username string + Password string + } + + // TLS client certificate authentication. + External struct { + // X.509 certificate in DER form. + CertBlob []byte + // PKCS#8 private key in DER form. + PrivKeyBlob []byte + } +} + +type Network struct { + ID int64 + Name string + Addr string + Nick string + Username string + Realname string + Pass string + ConnectCommands []string + SASL SASL + Enabled bool +} + +func (net *Network) GetName() string { + if net.Name != "" { + return net.Name + } + return net.Addr +} + +func (net *Network) URL() (*url.URL, error) { + s := net.Addr + if !strings.Contains(s, "://") { + // This is a raw domain name, make it an URL with the default scheme + s = "ircs://" + s + } + + u, err := url.Parse(s) + if err != nil { + return nil, fmt.Errorf("failed to parse upstream server URL: %v", err) + } + + return u, nil +} + +func GetNick(user *User, net *Network) string { + if net.Nick != "" { + return net.Nick + } + return user.Username +} + +func GetUsername(user *User, net *Network) string { + if net.Username != "" { + return net.Username + } + return GetNick(user, net) +} + +func GetRealname(user *User, net *Network) string { + if net.Realname != "" { + return net.Realname + } + if user.Realname != "" { + return user.Realname + } + return GetNick(user, net) +} + +type MessageFilter int + +const ( + // TODO: use customizable user defaults for FilterDefault + FilterDefault MessageFilter = iota + FilterNone + FilterHighlight + FilterMessage +) + +func parseFilter(filter string) (MessageFilter, error) { + switch filter { + case "default": + return FilterDefault, nil + case "none": + return FilterNone, nil + case "highlight": + return FilterHighlight, nil + case "message": + return FilterMessage, nil + } + return 0, fmt.Errorf("unknown filter: %q", filter) +} + +type Channel struct { + ID int64 + Name string + Key string + + Detached bool + DetachedInternalMsgID string + + RelayDetached MessageFilter + ReattachOn MessageFilter + DetachAfter time.Duration + DetachOn MessageFilter +} + +type DeliveryReceipt struct { + ID int64 + Target string // channel or nick + Client string + InternalMsgID string +} + +type ReadReceipt struct { + ID int64 + Target string // channel or nick + Timestamp time.Time +} diff --git a/trunk/db_postgres.go b/trunk/db_postgres.go new file mode 100644 index 0000000..0849338 --- /dev/null +++ b/trunk/db_postgres.go @@ -0,0 +1,494 @@ +package suika + +import ( + "context" + "database/sql" + _ "embed" + "errors" + "fmt" + "math" + "strings" + "time" + + _ "github.com/lib/pq" +) + +const postgresQueryTimeout = 5 * time.Second + +const postgresConfigSchema = ` +CREATE TABLE IF NOT EXISTS "Config" ( + id SMALLINT PRIMARY KEY, + version INTEGER NOT NULL, + CHECK(id = 1) +); +` +//go:embed suika_psql_schema.sql +var postgresSchema string + +var postgresMigrations = []string{ + "", // migration #0 is reserved for schema initialization + `ALTER TABLE "Network" ALTER COLUMN nick DROP NOT NULL`, + ` + CREATE TYPE sasl_mechanism AS ENUM ('PLAIN', 'EXTERNAL'); + ALTER TABLE "Network" + ALTER COLUMN sasl_mechanism + TYPE sasl_mechanism + USING sasl_mechanism::sasl_mechanism; + `, + ` + CREATE TABLE IF NOT EXISTS "ReadReceipt" ( + id SERIAL PRIMARY KEY, + network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE, + target VARCHAR(255) NOT NULL, + timestamp TIMESTAMP WITH TIME ZONE NOT NULL, + UNIQUE(network, target) + ); + `, +} + +type PostgresDB struct { + db *sql.DB +} + +func OpenPostgresDB(source string) (Database, error) { + sqlPostgresDB, err := sql.Open("postgres", source) + if err != nil { + return nil, err + } + + db := &PostgresDB{db: sqlPostgresDB} + if err := db.upgrade(); err != nil { + sqlPostgresDB.Close() + return nil, err + } + + return db, nil +} + +func (db *PostgresDB) upgrade() error { + tx, err := db.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + if _, err := tx.Exec(postgresConfigSchema); err != nil { + return fmt.Errorf("failed to create Config table: %s", err) + } + + var version int + err = tx.QueryRow(`SELECT version FROM "Config"`).Scan(&version) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return fmt.Errorf("failed to query schema version: %s", err) + } + + if version == len(postgresMigrations) { + return nil + } + if version > len(postgresMigrations) { + return fmt.Errorf("suika (version %d) older than schema (version %d)", len(postgresMigrations), version) + } + + if version == 0 { + if _, err := tx.Exec(postgresSchema); err != nil { + return fmt.Errorf("failed to initialize schema: %s", err) + } + } else { + for i := version; i < len(postgresMigrations); i++ { + if _, err := tx.Exec(postgresMigrations[i]); err != nil { + return fmt.Errorf("failed to execute migration #%v: %v", i, err) + } + } + } + + _, err = tx.Exec(`INSERT INTO "Config" (id, version) VALUES (1, $1) + ON CONFLICT (id) DO UPDATE SET version = $1`, len(postgresMigrations)) + if err != nil { + return fmt.Errorf("failed to bump schema version: %v", err) + } + + return tx.Commit() +} + +func (db *PostgresDB) Close() error { + return db.db.Close() +} + +func (db *PostgresDB) Stats(ctx context.Context) (*DatabaseStats, error) { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + var stats DatabaseStats + row := db.db.QueryRowContext(ctx, `SELECT + (SELECT COUNT(*) FROM "User") AS users, + (SELECT COUNT(*) FROM "Network") AS networks, + (SELECT COUNT(*) FROM "Channel") AS channels`) + if err := row.Scan(&stats.Users, &stats.Networks, &stats.Channels); err != nil { + return nil, err + } + + return &stats, nil +} + +func (db *PostgresDB) ListUsers(ctx context.Context) ([]User, error) { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + rows, err := db.db.QueryContext(ctx, + `SELECT id, username, password, admin, realname FROM "User"`) + if err != nil { + return nil, err + } + defer rows.Close() + + var users []User + for rows.Next() { + var user User + var password, realname sql.NullString + if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &realname); err != nil { + return nil, err + } + user.Password = password.String + user.Realname = realname.String + users = append(users, user) + } + if err := rows.Err(); err != nil { + return nil, err + } + + return users, nil +} + +func (db *PostgresDB) GetUser(ctx context.Context, username string) (*User, error) { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + user := &User{Username: username} + + var password, realname sql.NullString + row := db.db.QueryRowContext(ctx, + `SELECT id, password, admin, realname FROM "User" WHERE username = $1`, + username) + if err := row.Scan(&user.ID, &password, &user.Admin, &realname); err != nil { + return nil, err + } + user.Password = password.String + user.Realname = realname.String + return user, nil +} + +func (db *PostgresDB) StoreUser(ctx context.Context, user *User) error { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + password := toNullString(user.Password) + realname := toNullString(user.Realname) + + var err error + if user.ID == 0 { + err = db.db.QueryRowContext(ctx, ` + INSERT INTO "User" (username, password, admin, realname) + VALUES ($1, $2, $3, $4) + RETURNING id`, + user.Username, password, user.Admin, realname).Scan(&user.ID) + } else { + _, err = db.db.ExecContext(ctx, ` + UPDATE "User" + SET password = $1, admin = $2, realname = $3 + WHERE id = $4`, + password, user.Admin, realname, user.ID) + } + return err +} + +func (db *PostgresDB) DeleteUser(ctx context.Context, id int64) error { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + _, err := db.db.ExecContext(ctx, `DELETE FROM "User" WHERE id = $1`, id) + return err +} + +func (db *PostgresDB) ListNetworks(ctx context.Context, userID int64) ([]Network, error) { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + rows, err := db.db.QueryContext(ctx, ` + SELECT id, name, addr, nick, username, realname, pass, connect_commands, sasl_mechanism, + sasl_plain_username, sasl_plain_password, sasl_external_cert, sasl_external_key, enabled + FROM "Network" + WHERE "user" = $1`, userID) + if err != nil { + return nil, err + } + defer rows.Close() + + var networks []Network + for rows.Next() { + var net Network + var name, nick, username, realname, pass, connectCommands sql.NullString + var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString + err := rows.Scan(&net.ID, &name, &net.Addr, &nick, &username, &realname, + &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword, + &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob, &net.Enabled) + if err != nil { + return nil, err + } + net.Name = name.String + net.Nick = nick.String + net.Username = username.String + net.Realname = realname.String + net.Pass = pass.String + if connectCommands.Valid { + net.ConnectCommands = strings.Split(connectCommands.String, "\r\n") + } + net.SASL.Mechanism = saslMechanism.String + net.SASL.Plain.Username = saslPlainUsername.String + net.SASL.Plain.Password = saslPlainPassword.String + networks = append(networks, net) + } + if err := rows.Err(); err != nil { + return nil, err + } + + return networks, nil +} + +func (db *PostgresDB) StoreNetwork(ctx context.Context, userID int64, network *Network) error { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + netName := toNullString(network.Name) + nick := toNullString(network.Nick) + netUsername := toNullString(network.Username) + realname := toNullString(network.Realname) + pass := toNullString(network.Pass) + connectCommands := toNullString(strings.Join(network.ConnectCommands, "\r\n")) + + var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString + if network.SASL.Mechanism != "" { + saslMechanism = toNullString(network.SASL.Mechanism) + switch network.SASL.Mechanism { + case "PLAIN": + saslPlainUsername = toNullString(network.SASL.Plain.Username) + saslPlainPassword = toNullString(network.SASL.Plain.Password) + network.SASL.External.CertBlob = nil + network.SASL.External.PrivKeyBlob = nil + case "EXTERNAL": + // keep saslPlain* nil + default: + return fmt.Errorf("suika: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism) + } + } + + var err error + if network.ID == 0 { + err = db.db.QueryRowContext(ctx, ` + INSERT INTO "Network" ("user", name, addr, nick, username, realname, pass, connect_commands, + sasl_mechanism, sasl_plain_username, sasl_plain_password, sasl_external_cert, + sasl_external_key, enabled) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) + RETURNING id`, + userID, netName, network.Addr, nick, netUsername, realname, pass, connectCommands, + saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob, + network.SASL.External.PrivKeyBlob, network.Enabled).Scan(&network.ID) + } else { + _, err = db.db.ExecContext(ctx, ` + UPDATE "Network" + SET name = $2, addr = $3, nick = $4, username = $5, realname = $6, pass = $7, + connect_commands = $8, sasl_mechanism = $9, sasl_plain_username = $10, + sasl_plain_password = $11, sasl_external_cert = $12, sasl_external_key = $13, + enabled = $14 + WHERE id = $1`, + network.ID, netName, network.Addr, nick, netUsername, realname, pass, connectCommands, + saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob, + network.SASL.External.PrivKeyBlob, network.Enabled) + } + return err +} + +func (db *PostgresDB) DeleteNetwork(ctx context.Context, id int64) error { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + _, err := db.db.ExecContext(ctx, `DELETE FROM "Network" WHERE id = $1`, id) + return err +} + +func (db *PostgresDB) ListChannels(ctx context.Context, networkID int64) ([]Channel, error) { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + rows, err := db.db.QueryContext(ctx, ` + SELECT id, name, key, detached, detached_internal_msgid, relay_detached, reattach_on, detach_after, + detach_on + FROM "Channel" + WHERE network = $1`, networkID) + if err != nil { + return nil, err + } + defer rows.Close() + + var channels []Channel + for rows.Next() { + var ch Channel + var key, detachedInternalMsgID sql.NullString + var detachAfter int64 + if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached, &detachedInternalMsgID, &ch.RelayDetached, &ch.ReattachOn, &detachAfter, &ch.DetachOn); err != nil { + return nil, err + } + ch.Key = key.String + ch.DetachedInternalMsgID = detachedInternalMsgID.String + ch.DetachAfter = time.Duration(detachAfter) * time.Second + channels = append(channels, ch) + } + if err := rows.Err(); err != nil { + return nil, err + } + + return channels, nil +} + +func (db *PostgresDB) StoreChannel(ctx context.Context, networkID int64, ch *Channel) error { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + key := toNullString(ch.Key) + detachAfter := int64(math.Ceil(ch.DetachAfter.Seconds())) + + var err error + if ch.ID == 0 { + err = db.db.QueryRowContext(ctx, ` + INSERT INTO "Channel" (network, name, key, detached, detached_internal_msgid, relay_detached, reattach_on, + detach_after, detach_on) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) + RETURNING id`, + networkID, ch.Name, key, ch.Detached, toNullString(ch.DetachedInternalMsgID), + ch.RelayDetached, ch.ReattachOn, detachAfter, ch.DetachOn).Scan(&ch.ID) + } else { + _, err = db.db.ExecContext(ctx, ` + UPDATE "Channel" + SET name = $2, key = $3, detached = $4, detached_internal_msgid = $5, + relay_detached = $6, reattach_on = $7, detach_after = $8, detach_on = $9 + WHERE id = $1`, + ch.ID, ch.Name, key, ch.Detached, toNullString(ch.DetachedInternalMsgID), + ch.RelayDetached, ch.ReattachOn, detachAfter, ch.DetachOn) + } + return err +} + +func (db *PostgresDB) DeleteChannel(ctx context.Context, id int64) error { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + _, err := db.db.ExecContext(ctx, `DELETE FROM "Channel" WHERE id = $1`, id) + return err +} + +func (db *PostgresDB) ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error) { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + rows, err := db.db.QueryContext(ctx, ` + SELECT id, target, client, internal_msgid + FROM "DeliveryReceipt" + WHERE network = $1`, networkID) + if err != nil { + return nil, err + } + defer rows.Close() + + var receipts []DeliveryReceipt + for rows.Next() { + var rcpt DeliveryReceipt + if err := rows.Scan(&rcpt.ID, &rcpt.Target, &rcpt.Client, &rcpt.InternalMsgID); err != nil { + return nil, err + } + receipts = append(receipts, rcpt) + } + if err := rows.Err(); err != nil { + return nil, err + } + + return receipts, nil +} + +func (db *PostgresDB) StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + tx, err := db.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + _, err = tx.ExecContext(ctx, + `DELETE FROM "DeliveryReceipt" WHERE network = $1 AND client = $2`, + networkID, client) + if err != nil { + return err + } + + stmt, err := tx.PrepareContext(ctx, ` + INSERT INTO "DeliveryReceipt" (network, target, client, internal_msgid) + VALUES ($1, $2, $3, $4) + RETURNING id`) + if err != nil { + return err + } + defer stmt.Close() + + for i := range receipts { + rcpt := &receipts[i] + err := stmt. + QueryRowContext(ctx, networkID, rcpt.Target, client, rcpt.InternalMsgID). + Scan(&rcpt.ID) + if err != nil { + return err + } + } + + return tx.Commit() +} + +func (db *PostgresDB) GetReadReceipt(ctx context.Context, networkID int64, name string) (*ReadReceipt, error) { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + receipt := &ReadReceipt{ + Target: name, + } + + row := db.db.QueryRowContext(ctx, + `SELECT id, timestamp FROM "ReadReceipt" WHERE network = $1 AND target = $2`, + networkID, name) + if err := row.Scan(&receipt.ID, &receipt.Timestamp); err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, err + } + return receipt, nil +} + +func (db *PostgresDB) StoreReadReceipt(ctx context.Context, networkID int64, receipt *ReadReceipt) error { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + var err error + if receipt.ID != 0 { + _, err = db.db.ExecContext(ctx, ` + UPDATE "ReadReceipt" + SET timestamp = $1 + WHERE id = $2`, + receipt.Timestamp, receipt.ID) + } else { + err = db.db.QueryRowContext(ctx, ` + INSERT INTO "ReadReceipt" (network, target, timestamp) + VALUES ($1, $2, $3) + RETURNING id`, + networkID, receipt.Target, receipt.Timestamp).Scan(&receipt.ID) + } + return err +} diff --git a/trunk/db_postgres_test.go b/trunk/db_postgres_test.go new file mode 100644 index 0000000..7692f31 --- /dev/null +++ b/trunk/db_postgres_test.go @@ -0,0 +1,104 @@ +package suika + +import ( + "database/sql" + "os" + "testing" +) + +// PostgreSQL version 0 schema. DO NOT EDIT. +const postgresV0Schema = ` +CREATE TABLE "Config" ( + id SMALLINT PRIMARY KEY, + version INTEGER NOT NULL, + CHECK(id = 1) +); + +INSERT INTO "Config" (id, version) VALUES (1, 1); + +CREATE TABLE "User" ( + id SERIAL PRIMARY KEY, + username VARCHAR(255) NOT NULL UNIQUE, + password VARCHAR(255), + admin BOOLEAN NOT NULL DEFAULT FALSE, + realname VARCHAR(255) +); + +CREATE TABLE "Network" ( + id SERIAL PRIMARY KEY, + name VARCHAR(255), + "user" INTEGER NOT NULL REFERENCES "User"(id) ON DELETE CASCADE, + addr VARCHAR(255) NOT NULL, + nick VARCHAR(255) NOT NULL, + username VARCHAR(255), + realname VARCHAR(255), + pass VARCHAR(255), + connect_commands VARCHAR(1023), + sasl_mechanism VARCHAR(255), + sasl_plain_username VARCHAR(255), + sasl_plain_password VARCHAR(255), + sasl_external_cert BYTEA DEFAULT NULL, + sasl_external_key BYTEA DEFAULT NULL, + enabled BOOLEAN NOT NULL DEFAULT TRUE, + UNIQUE("user", addr, nick), + UNIQUE("user", name) +); + +CREATE TABLE "Channel" ( + id SERIAL PRIMARY KEY, + network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE, + name VARCHAR(255) NOT NULL, + key VARCHAR(255), + detached BOOLEAN NOT NULL DEFAULT FALSE, + detached_internal_msgid VARCHAR(255), + relay_detached INTEGER NOT NULL DEFAULT 0, + reattach_on INTEGER NOT NULL DEFAULT 0, + detach_after INTEGER NOT NULL DEFAULT 0, + detach_on INTEGER NOT NULL DEFAULT 0, + UNIQUE(network, name) +); + +CREATE TABLE "DeliveryReceipt" ( + id SERIAL PRIMARY KEY, + network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE, + target VARCHAR(255) NOT NULL, + client VARCHAR(255) NOT NULL DEFAULT '', + internal_msgid VARCHAR(255) NOT NULL, + UNIQUE(network, target, client) +); +` + +func openTempPostgresDB(t *testing.T) *sql.DB { + source, ok := os.LookupEnv("SOJU_TEST_POSTGRES") + if !ok { + t.Skip("set SOJU_TEST_POSTGRES to a connection string to execute PostgreSQL tests") + } + + db, err := sql.Open("postgres", source) + if err != nil { + t.Fatalf("failed to connect to PostgreSQL: %v", err) + } + + // Store all tables in a temporary schema which will be dropped when the + // connection to PostgreSQL is closed. + db.SetMaxOpenConns(1) + if _, err := db.Exec("SET search_path TO pg_temp"); err != nil { + t.Fatalf("failed to set PostgreSQL search_path: %v", err) + } + + return db +} + +func TestPostgresMigrations(t *testing.T) { + sqlDB := openTempPostgresDB(t) + if _, err := sqlDB.Exec(postgresV0Schema); err != nil { + t.Fatalf("DB.Exec() failed for v0 schema: %v", err) + } + + db := &PostgresDB{db: sqlDB} + defer db.Close() + + if err := db.upgrade(); err != nil { + t.Fatalf("PostgresDB.Upgrade() failed: %v", err) + } +} diff --git a/trunk/db_sqlite.go b/trunk/db_sqlite.go new file mode 100644 index 0000000..7ebebb0 --- /dev/null +++ b/trunk/db_sqlite.go @@ -0,0 +1,755 @@ +package suika + +import ( + "context" + "database/sql" + _ "embed" + "fmt" + "math" + "strings" + "sync" + "time" + + _ "modernc.org/sqlite" +) + +const sqliteQueryTimeout = 5 * time.Second + +//go:embed suika_sqlite_schema.sql +var sqliteSchema string + +var sqliteMigrations = []string{ + "", // migration #0 is reserved for schema initialization + "ALTER TABLE Network ADD COLUMN connect_commands VARCHAR(1023)", + "ALTER TABLE Channel ADD COLUMN detached INTEGER NOT NULL DEFAULT 0", + "ALTER TABLE Network ADD COLUMN sasl_external_cert BLOB DEFAULT NULL", + "ALTER TABLE Network ADD COLUMN sasl_external_key BLOB DEFAULT NULL", + "ALTER TABLE User ADD COLUMN admin INTEGER NOT NULL DEFAULT 0", + ` + CREATE TABLE IF NOT EXISTS UserNew ( + id INTEGER PRIMARY KEY, + username VARCHAR(255) NOT NULL UNIQUE, + password VARCHAR(255), + admin INTEGER NOT NULL DEFAULT 0 + ); + INSERT INTO UserNew SELECT rowid, username, password, admin FROM User; + DROP TABLE User; + ALTER TABLE UserNew RENAME TO User; + `, + ` + CREATE TABLE IF NOT EXISTS NetworkNew ( + id INTEGER PRIMARY KEY, + name VARCHAR(255), + user INTEGER NOT NULL, + addr VARCHAR(255) NOT NULL, + nick VARCHAR(255) NOT NULL, + username VARCHAR(255), + realname VARCHAR(255), + pass VARCHAR(255), + connect_commands VARCHAR(1023), + sasl_mechanism VARCHAR(255), + sasl_plain_username VARCHAR(255), + sasl_plain_password VARCHAR(255), + sasl_external_cert BLOB DEFAULT NULL, + sasl_external_key BLOB DEFAULT NULL, + FOREIGN KEY(user) REFERENCES User(id), + UNIQUE(user, addr, nick), + UNIQUE(user, name) + ); + INSERT INTO NetworkNew + SELECT Network.id, name, User.id as user, addr, nick, + Network.username, realname, pass, connect_commands, + sasl_mechanism, sasl_plain_username, sasl_plain_password, + sasl_external_cert, sasl_external_key + FROM Network + JOIN User ON Network.user = User.username; + DROP TABLE Network; + ALTER TABLE NetworkNew RENAME TO Network; + `, + ` + ALTER TABLE Channel ADD COLUMN relay_detached INTEGER NOT NULL DEFAULT 0; + ALTER TABLE Channel ADD COLUMN reattach_on INTEGER NOT NULL DEFAULT 0; + ALTER TABLE Channel ADD COLUMN detach_after INTEGER NOT NULL DEFAULT 0; + ALTER TABLE Channel ADD COLUMN detach_on INTEGER NOT NULL DEFAULT 0; + `, + ` + CREATE TABLE IF NOT EXISTS DeliveryReceipt ( + id INTEGER PRIMARY KEY, + network INTEGER NOT NULL, + target VARCHAR(255) NOT NULL, + client VARCHAR(255), + internal_msgid VARCHAR(255) NOT NULL, + FOREIGN KEY(network) REFERENCES Network(id), + UNIQUE(network, target, client) + ); + `, + "ALTER TABLE Channel ADD COLUMN detached_internal_msgid VARCHAR(255)", + "ALTER TABLE Network ADD COLUMN enabled INTEGER NOT NULL DEFAULT 1", + "ALTER TABLE User ADD COLUMN realname VARCHAR(255)", + ` + CREATE TABLE IF NOT EXISTS NetworkNew ( + id INTEGER PRIMARY KEY, + name TEXT, + user INTEGER NOT NULL, + addr TEXT NOT NULL, + nick TEXT, + username TEXT, + realname TEXT, + pass TEXT, + connect_commands TEXT, + sasl_mechanism TEXT, + sasl_plain_username TEXT, + sasl_plain_password TEXT, + sasl_external_cert BLOB, + sasl_external_key BLOB, + enabled INTEGER NOT NULL DEFAULT 1, + FOREIGN KEY(user) REFERENCES User(id), + UNIQUE(user, addr, nick), + UNIQUE(user, name) + ); + INSERT INTO NetworkNew + SELECT id, name, user, addr, nick, username, realname, pass, + connect_commands, sasl_mechanism, sasl_plain_username, + sasl_plain_password, sasl_external_cert, sasl_external_key, + enabled + FROM Network; + DROP TABLE Network; + ALTER TABLE NetworkNew RENAME TO Network; + `, + ` + CREATE TABLE IF NOT EXISTS ReadReceipt ( + id INTEGER PRIMARY KEY, + network INTEGER NOT NULL, + target TEXT NOT NULL, + timestamp TEXT NOT NULL, + FOREIGN KEY(network) REFERENCES Network(id), + UNIQUE(network, target) + ); + `, +} + +type SqliteDB struct { + lock sync.RWMutex + db *sql.DB +} + +func OpenSqliteDB(source string) (Database, error) { + sqlSqliteDB, err := sql.Open("sqlite", source) + if err != nil { + return nil, err + } + + db := &SqliteDB{db: sqlSqliteDB} + if err := db.upgrade(); err != nil { + sqlSqliteDB.Close() + return nil, err + } + + return db, nil +} + +func (db *SqliteDB) Close() error { + db.lock.Lock() + defer db.lock.Unlock() + return db.db.Close() +} + +func (db *SqliteDB) upgrade() error { + db.lock.Lock() + defer db.lock.Unlock() + + var version int + if err := db.db.QueryRow("PRAGMA user_version").Scan(&version); err != nil { + return fmt.Errorf("failed to query schema version: %v", err) + } + + if version == len(sqliteMigrations) { + return nil + } else if version > len(sqliteMigrations) { + return fmt.Errorf("suika (version %d) older than schema (version %d)", len(sqliteMigrations), version) + } + + tx, err := db.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + if version == 0 { + if _, err := tx.Exec(sqliteSchema); err != nil { + return fmt.Errorf("failed to initialize schema: %v", err) + } + } else { + for i := version; i < len(sqliteMigrations); i++ { + if _, err := tx.Exec(sqliteMigrations[i]); err != nil { + return fmt.Errorf("failed to execute migration #%v: %v", i, err) + } + } + } + + // For some reason prepared statements don't work here + _, err = tx.Exec(fmt.Sprintf("PRAGMA user_version = %d", len(sqliteMigrations))) + if err != nil { + return fmt.Errorf("failed to bump schema version: %v", err) + } + + return tx.Commit() +} + +func (db *SqliteDB) Stats(ctx context.Context) (*DatabaseStats, error) { + db.lock.RLock() + defer db.lock.RUnlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + var stats DatabaseStats + row := db.db.QueryRowContext(ctx, `SELECT + (SELECT COUNT(*) FROM User) AS users, + (SELECT COUNT(*) FROM Network) AS networks, + (SELECT COUNT(*) FROM Channel) AS channels`) + if err := row.Scan(&stats.Users, &stats.Networks, &stats.Channels); err != nil { + return nil, err + } + + return &stats, nil +} + +func toNullString(s string) sql.NullString { + return sql.NullString{ + String: s, + Valid: s != "", + } +} + +func (db *SqliteDB) ListUsers(ctx context.Context) ([]User, error) { + db.lock.RLock() + defer db.lock.RUnlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + rows, err := db.db.QueryContext(ctx, + "SELECT id, username, password, admin, realname FROM User") + if err != nil { + return nil, err + } + defer rows.Close() + + var users []User + for rows.Next() { + var user User + var password, realname sql.NullString + if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &realname); err != nil { + return nil, err + } + user.Password = password.String + user.Realname = realname.String + users = append(users, user) + } + if err := rows.Err(); err != nil { + return nil, err + } + + return users, nil +} + +func (db *SqliteDB) GetUser(ctx context.Context, username string) (*User, error) { + db.lock.RLock() + defer db.lock.RUnlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + user := &User{Username: username} + + var password, realname sql.NullString + row := db.db.QueryRowContext(ctx, + "SELECT id, password, admin, realname FROM User WHERE username = ?", + username) + if err := row.Scan(&user.ID, &password, &user.Admin, &realname); err != nil { + return nil, err + } + user.Password = password.String + user.Realname = realname.String + return user, nil +} + +func (db *SqliteDB) StoreUser(ctx context.Context, user *User) error { + db.lock.Lock() + defer db.lock.Unlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + args := []interface{}{ + sql.Named("username", user.Username), + sql.Named("password", toNullString(user.Password)), + sql.Named("admin", user.Admin), + sql.Named("realname", toNullString(user.Realname)), + } + + var err error + if user.ID != 0 { + _, err = db.db.ExecContext(ctx, ` + UPDATE User SET password = :password, admin = :admin, + realname = :realname WHERE username = :username`, + args...) + } else { + var res sql.Result + res, err = db.db.ExecContext(ctx, ` + INSERT INTO + User(username, password, admin, realname) + VALUES (:username, :password, :admin, :realname)`, + args...) + if err != nil { + return err + } + user.ID, err = res.LastInsertId() + } + + return err +} + +func (db *SqliteDB) DeleteUser(ctx context.Context, id int64) error { + db.lock.Lock() + defer db.lock.Unlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + tx, err := db.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + _, err = tx.ExecContext(ctx, `DELETE FROM DeliveryReceipt + WHERE id IN ( + SELECT DeliveryReceipt.id + FROM DeliveryReceipt + JOIN Network ON DeliveryReceipt.network = Network.id + WHERE Network.user = ? + )`, id) + if err != nil { + return err + } + + _, err = tx.ExecContext(ctx, `DELETE FROM ReadReceipt + WHERE id IN ( + SELECT ReadReceipt.id + FROM ReadReceipt + JOIN Network ON ReadReceipt.network = Network.id + WHERE Network.user = ? + )`, id) + if err != nil { + return err + } + + _, err = tx.ExecContext(ctx, `DELETE FROM Channel + WHERE id IN ( + SELECT Channel.id + FROM Channel + JOIN Network ON Channel.network = Network.id + WHERE Network.user = ? + )`, id) + if err != nil { + return err + } + + _, err = tx.ExecContext(ctx, "DELETE FROM Network WHERE user = ?", id) + if err != nil { + return err + } + + _, err = tx.ExecContext(ctx, "DELETE FROM User WHERE id = ?", id) + if err != nil { + return err + } + + return tx.Commit() +} + +func (db *SqliteDB) ListNetworks(ctx context.Context, userID int64) ([]Network, error) { + db.lock.RLock() + defer db.lock.RUnlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + rows, err := db.db.QueryContext(ctx, ` + SELECT id, name, addr, nick, username, realname, pass, + connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password, + sasl_external_cert, sasl_external_key, enabled + FROM Network + WHERE user = ?`, + userID) + if err != nil { + return nil, err + } + defer rows.Close() + + var networks []Network + for rows.Next() { + var net Network + var name, nick, username, realname, pass, connectCommands sql.NullString + var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString + err := rows.Scan(&net.ID, &name, &net.Addr, &nick, &username, &realname, + &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword, + &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob, &net.Enabled) + if err != nil { + return nil, err + } + net.Name = name.String + net.Nick = nick.String + net.Username = username.String + net.Realname = realname.String + net.Pass = pass.String + if connectCommands.Valid { + net.ConnectCommands = strings.Split(connectCommands.String, "\r\n") + } + net.SASL.Mechanism = saslMechanism.String + net.SASL.Plain.Username = saslPlainUsername.String + net.SASL.Plain.Password = saslPlainPassword.String + networks = append(networks, net) + } + if err := rows.Err(); err != nil { + return nil, err + } + + return networks, nil +} + +func (db *SqliteDB) StoreNetwork(ctx context.Context, userID int64, network *Network) error { + db.lock.Lock() + defer db.lock.Unlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString + if network.SASL.Mechanism != "" { + saslMechanism = toNullString(network.SASL.Mechanism) + switch network.SASL.Mechanism { + case "PLAIN": + saslPlainUsername = toNullString(network.SASL.Plain.Username) + saslPlainPassword = toNullString(network.SASL.Plain.Password) + network.SASL.External.CertBlob = nil + network.SASL.External.PrivKeyBlob = nil + case "EXTERNAL": + // keep saslPlain* nil + default: + return fmt.Errorf("suika: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism) + } + } + + args := []interface{}{ + sql.Named("name", toNullString(network.Name)), + sql.Named("addr", network.Addr), + sql.Named("nick", toNullString(network.Nick)), + sql.Named("username", toNullString(network.Username)), + sql.Named("realname", toNullString(network.Realname)), + sql.Named("pass", toNullString(network.Pass)), + sql.Named("connect_commands", toNullString(strings.Join(network.ConnectCommands, "\r\n"))), + sql.Named("sasl_mechanism", saslMechanism), + sql.Named("sasl_plain_username", saslPlainUsername), + sql.Named("sasl_plain_password", saslPlainPassword), + sql.Named("sasl_external_cert", network.SASL.External.CertBlob), + sql.Named("sasl_external_key", network.SASL.External.PrivKeyBlob), + sql.Named("enabled", network.Enabled), + + sql.Named("id", network.ID), // only for UPDATE + sql.Named("user", userID), // only for INSERT + } + + var err error + if network.ID != 0 { + _, err = db.db.ExecContext(ctx, ` + UPDATE Network + SET name = :name, addr = :addr, nick = :nick, username = :username, + realname = :realname, pass = :pass, connect_commands = :connect_commands, + sasl_mechanism = :sasl_mechanism, sasl_plain_username = :sasl_plain_username, sasl_plain_password = :sasl_plain_password, + sasl_external_cert = :sasl_external_cert, sasl_external_key = :sasl_external_key, + enabled = :enabled + WHERE id = :id`, args...) + } else { + var res sql.Result + res, err = db.db.ExecContext(ctx, ` + INSERT INTO Network(user, name, addr, nick, username, realname, pass, + connect_commands, sasl_mechanism, sasl_plain_username, + sasl_plain_password, sasl_external_cert, sasl_external_key, enabled) + VALUES (:user, :name, :addr, :nick, :username, :realname, :pass, + :connect_commands, :sasl_mechanism, :sasl_plain_username, + :sasl_plain_password, :sasl_external_cert, :sasl_external_key, :enabled)`, + args...) + if err != nil { + return err + } + network.ID, err = res.LastInsertId() + } + return err +} + +func (db *SqliteDB) DeleteNetwork(ctx context.Context, id int64) error { + db.lock.Lock() + defer db.lock.Unlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + tx, err := db.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + _, err = tx.ExecContext(ctx, "DELETE FROM DeliveryReceipt WHERE network = ?", id) + if err != nil { + return err + } + + _, err = tx.ExecContext(ctx, "DELETE FROM ReadReceipt WHERE network = ?", id) + if err != nil { + return err + } + + _, err = tx.ExecContext(ctx, "DELETE FROM Channel WHERE network = ?", id) + if err != nil { + return err + } + + _, err = tx.ExecContext(ctx, "DELETE FROM Network WHERE id = ?", id) + if err != nil { + return err + } + + return tx.Commit() +} + +func (db *SqliteDB) ListChannels(ctx context.Context, networkID int64) ([]Channel, error) { + db.lock.RLock() + defer db.lock.RUnlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + rows, err := db.db.QueryContext(ctx, `SELECT + id, name, key, detached, detached_internal_msgid, + relay_detached, reattach_on, detach_after, detach_on + FROM Channel + WHERE network = ?`, networkID) + if err != nil { + return nil, err + } + defer rows.Close() + + var channels []Channel + for rows.Next() { + var ch Channel + var key, detachedInternalMsgID sql.NullString + var detachAfter int64 + if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached, &detachedInternalMsgID, &ch.RelayDetached, &ch.ReattachOn, &detachAfter, &ch.DetachOn); err != nil { + return nil, err + } + ch.Key = key.String + ch.DetachedInternalMsgID = detachedInternalMsgID.String + ch.DetachAfter = time.Duration(detachAfter) * time.Second + channels = append(channels, ch) + } + if err := rows.Err(); err != nil { + return nil, err + } + + return channels, nil +} + +func (db *SqliteDB) StoreChannel(ctx context.Context, networkID int64, ch *Channel) error { + db.lock.Lock() + defer db.lock.Unlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + args := []interface{}{ + sql.Named("network", networkID), + sql.Named("name", ch.Name), + sql.Named("key", toNullString(ch.Key)), + sql.Named("detached", ch.Detached), + sql.Named("detached_internal_msgid", toNullString(ch.DetachedInternalMsgID)), + sql.Named("relay_detached", ch.RelayDetached), + sql.Named("reattach_on", ch.ReattachOn), + sql.Named("detach_after", int64(math.Ceil(ch.DetachAfter.Seconds()))), + sql.Named("detach_on", ch.DetachOn), + + sql.Named("id", ch.ID), // only for UPDATE + } + + var err error + if ch.ID != 0 { + _, err = db.db.ExecContext(ctx, `UPDATE Channel + SET network = :network, name = :name, key = :key, detached = :detached, + detached_internal_msgid = :detached_internal_msgid, relay_detached = :relay_detached, + reattach_on = :reattach_on, detach_after = :detach_after, detach_on = :detach_on + WHERE id = :id`, args...) + } else { + var res sql.Result + res, err = db.db.ExecContext(ctx, `INSERT INTO Channel(network, name, key, detached, detached_internal_msgid, relay_detached, reattach_on, detach_after, detach_on) + VALUES (:network, :name, :key, :detached, :detached_internal_msgid, :relay_detached, :reattach_on, :detach_after, :detach_on)`, args...) + if err != nil { + return err + } + ch.ID, err = res.LastInsertId() + } + return err +} + +func (db *SqliteDB) DeleteChannel(ctx context.Context, id int64) error { + db.lock.Lock() + defer db.lock.Unlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + _, err := db.db.ExecContext(ctx, "DELETE FROM Channel WHERE id = ?", id) + return err +} + +func (db *SqliteDB) ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error) { + db.lock.RLock() + defer db.lock.RUnlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + rows, err := db.db.QueryContext(ctx, ` + SELECT id, target, client, internal_msgid + FROM DeliveryReceipt + WHERE network = ?`, networkID) + if err != nil { + return nil, err + } + defer rows.Close() + + var receipts []DeliveryReceipt + for rows.Next() { + var rcpt DeliveryReceipt + var client sql.NullString + if err := rows.Scan(&rcpt.ID, &rcpt.Target, &client, &rcpt.InternalMsgID); err != nil { + return nil, err + } + rcpt.Client = client.String + receipts = append(receipts, rcpt) + } + if err := rows.Err(); err != nil { + return nil, err + } + + return receipts, nil +} + +func (db *SqliteDB) StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error { + db.lock.Lock() + defer db.lock.Unlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + tx, err := db.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + _, err = tx.ExecContext(ctx, "DELETE FROM DeliveryReceipt WHERE network = ? AND client IS ?", + networkID, toNullString(client)) + if err != nil { + return err + } + + for i := range receipts { + rcpt := &receipts[i] + + res, err := tx.ExecContext(ctx, ` + INSERT INTO DeliveryReceipt(network, target, client, internal_msgid) + VALUES (:network, :target, :client, :internal_msgid)`, + sql.Named("network", networkID), + sql.Named("target", rcpt.Target), + sql.Named("client", toNullString(client)), + sql.Named("internal_msgid", rcpt.InternalMsgID)) + if err != nil { + return err + } + rcpt.ID, err = res.LastInsertId() + if err != nil { + return err + } + } + + return tx.Commit() +} + +func (db *SqliteDB) GetReadReceipt(ctx context.Context, networkID int64, name string) (*ReadReceipt, error) { + db.lock.RLock() + defer db.lock.RUnlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + receipt := &ReadReceipt{ + Target: name, + } + + row := db.db.QueryRowContext(ctx, ` + SELECT id, timestamp FROM ReadReceipt WHERE network = :network AND target = :target`, + sql.Named("network", networkID), + sql.Named("target", name), + ) + var timestamp string + if err := row.Scan(&receipt.ID, ×tamp); err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, err + } + if t, err := time.Parse(serverTimeLayout, timestamp); err != nil { + return nil, err + } else { + receipt.Timestamp = t + } + return receipt, nil +} + +func (db *SqliteDB) StoreReadReceipt(ctx context.Context, networkID int64, receipt *ReadReceipt) error { + db.lock.Lock() + defer db.lock.Unlock() + + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + args := []interface{}{ + sql.Named("id", receipt.ID), + sql.Named("timestamp", formatServerTime(receipt.Timestamp)), + sql.Named("network", networkID), + sql.Named("target", receipt.Target), + } + + var err error + if receipt.ID != 0 { + _, err = db.db.ExecContext(ctx, ` + UPDATE ReadReceipt SET timestamp = :timestamp WHERE id = :id`, + args...) + } else { + var res sql.Result + res, err = db.db.ExecContext(ctx, ` + INSERT INTO + ReadReceipt(network, target, timestamp) + VALUES (:network, :target, :timestamp)`, + args...) + if err != nil { + return err + } + receipt.ID, err = res.LastInsertId() + } + + return err +} diff --git a/trunk/db_sqlite_test.go b/trunk/db_sqlite_test.go new file mode 100644 index 0000000..a7c9794 --- /dev/null +++ b/trunk/db_sqlite_test.go @@ -0,0 +1,59 @@ +package suika + +import ( + "database/sql" + "testing" +) + +// SQLite version 0 schema. DO NOT EDIT. +const sqliteV0Schema = ` +CREATE TABLE User ( + username VARCHAR(255) NOT NULL UNIQUE, + password VARCHAR(255) +); + +CREATE TABLE Network ( + id INTEGER PRIMARY KEY, + name VARCHAR(255), + user VARCHAR(255) NOT NULL, + addr VARCHAR(255) NOT NULL, + nick VARCHAR(255) NOT NULL, + username VARCHAR(255), + realname VARCHAR(255), + pass VARCHAR(255), + sasl_mechanism VARCHAR(255), + sasl_plain_username VARCHAR(255), + sasl_plain_password VARCHAR(255), + UNIQUE(user, addr, nick), + UNIQUE(user, name) +); + +CREATE TABLE Channel ( + id INTEGER PRIMARY KEY, + network INTEGER NOT NULL, + name VARCHAR(255) NOT NULL, + key VARCHAR(255), + FOREIGN KEY(network) REFERENCES Network(id), + UNIQUE(network, name) +); + +PRAGMA user_version = 1; +` + +func TestSqliteMigrations(t *testing.T) { + sqlDB, err := sql.Open("sqlite", ":memory:") + if err != nil { + t.Fatalf("failed to create temporary SQLite database: %v", err) + } + + if _, err := sqlDB.Exec(sqliteV0Schema); err != nil { + t.Fatalf("DB.Exec() failed for v0 schema: %v", err) + } + + db := &SqliteDB{db: sqlDB} + defer db.Close() + + if err := db.upgrade(); err != nil { + t.Fatalf("SqliteDB.Upgrade() failed: %v", err) + } +} diff --git a/trunk/doc.go b/trunk/doc.go new file mode 100644 index 0000000..3f7af51 --- /dev/null +++ b/trunk/doc.go @@ -0,0 +1,20 @@ +// Package suika is a hard-fork of the 0.3 series of soju, an user-friendly IRC bouncer in Go. +// +// # Copyright (C) 2020 The soju Contributors +// # Copyright (C) 2023-present Izuru Yakumo et al. +// +// suika is covered by the AGPLv3 license: +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . +package suika diff --git a/trunk/doc/suika-config.5 b/trunk/doc/suika-config.5 new file mode 100644 index 0000000..ac0e9ef --- /dev/null +++ b/trunk/doc/suika-config.5 @@ -0,0 +1,70 @@ +.Dd $Mdocdate$ +.Dt SUIKA-CONFIG 5 +.Os +.Sh NAME +.Nm suika-config +.Nd Configuration file for the IRC bouncer +.Sh SYNOPSIS +.Bk -words +listen ircs:// +.Pp +tls cert.pem key.pem +.Pp +hostname example.org +.Ek +.Sh DESCRIPTION +This document describes the format of the configuration +file used by +.Xr suika 1 +.Sh OPTIONS +.Bl -tag -width Ds +.It listen Ar uri +With this you can control on what +ports/protocols +.Xr suika 1 +listens on, it supports +irc (cleartext IRC), ircs (IRC with TLS), and unix +(IRC over Unix domain sockets) +.It hostname Ar hostname +Server hostname, if unset, the system one is used. +.It title Ar title +Server title, this will be sent as the ISUPPORT NETWORK value when +clients don't select a specific network. +.It tls Ar cert Ar key +Enable TLS support, the certificate and key files must be +PEM-encoded. +.It db Ar driver Ar path +Set the database driver for user, network and channel storage. +By default a SQLite 3 database is opened in +.Pa ./suika.db +Supported drivers are sqlite and postgres, the former +expects a path to the database file, and the latter +a space-separated list of key=value parameters, +e.g. host=localhost dbname=suika +.It log fs Ar path +Path to the bouncer logs directory, or empty to disable +logging. +By default, logging is disabled. +.It max-user-networks Ar limit +Maximum number of networks per user, by default +there is no limit. +.It motd Ar path +Path to the MOTD file, its contents are sent to clients +which aren't bound to a particular network. +By default, no MOTD is sent. +.It multi-upstream-mode Ar bool +Globally enable or disable multi-upstream mode. +By default, it is enabled. +.It upstream-user-ip Ar cidr +Enable per-user-IP addresses. +One IPv4 and/or one IPv6 range can be specified in CIDR notation. +One IP address per range will be assigned to each user as the +source address when connecting to an upstream network. +This can be useful to avoid having the whole bouncer banned from +an upstream network because of one malicious user. +.El +.Sh AUTHORS +.An Simon Ser Aq Mt contact@emersion.fr +.An The soju Contributors +.Sh MAINTAINERS +.An Izuru Yakumo Aq Mt yakumo.izuru@chaotic.ninja diff --git a/trunk/doc/suika-znc-import.1 b/trunk/doc/suika-znc-import.1 new file mode 100644 index 0000000..7acc618 --- /dev/null +++ b/trunk/doc/suika-znc-import.1 @@ -0,0 +1,34 @@ +.Dd $Mdocdate$ +.Dt SUIKA-ZNC-IMPORT 1 +.Os +.Sh NAME +.Nm suika-znc-import +.Nd Migration utility for moving from ZNC +.Sh SYNOPSIS +.Nm +.Op Fl config Ar suika config file +.Op Fl user Ar username +.Op Fl network Ar name +.Sh DESCRIPTION +Imports configuration from a ZNC file. +Users and networks are merged if they already exist in the +.Xr suika 1 +database. +ZNC settings overwrite existing +.Xr suika 1 +settings +.Sh OPTIONS +.Bl -tag -width Ds +.It config Ar suika config file +Path to +.Xr suika-config 5 +.It user Ar username +Limit import to username, may be specified multiple times. +It network Ar name +Limit import to network, may be specified multiple times. +.El +.Sh AUTHORS +.An Simon Ser Aq Mt contact@emersion.fr +.An The soju Contributors +.Sh MAINTAINERS +.An Izuru Yakumo Aq Mt yakumo.izuru@chaotic.ninja diff --git a/trunk/doc/suika.1 b/trunk/doc/suika.1 new file mode 100644 index 0000000..afe6607 --- /dev/null +++ b/trunk/doc/suika.1 @@ -0,0 +1,68 @@ +.Dd $Mdocdate$ +.Dt SUIKA 1 +.Os +.Sh NAME +.Nm suika +.Nd A drunk as hell IRC bouncer, named after Suika Ibuki from Touhou Project +.Sh SYNOPSIS +.Nm +.Op Fl config Ar path +.Op Fl debug +.Op Fl listen Ar uri +.Sh DESCRIPTION +.Nm +is an user-friendly IRC bouncer, it connects to upstream +IRC servers on behalf of the user to provide extra features. +.Bl -tag -width 6n +.It Multiple separate users sharing the same bouncer +.It Clients connecting to multiple upstream servers (via a single connection) +.It Sending the backlog with per-client buffers +.El +.Pp +When joining a channel, the channel will be saved +and automatically joined on the next connection. +When registering or authenticating with NickServ, the credentials will be saved +and automatically used on the next connection if the server supports SASL. +When parting a channel with the reason "detach", the channel will be +detached instead of being left. +When all clients are disconnected from the bouncer, +the user is automatically marked as away. +.Pp +.Nm +supports two connection modes: +.Bl -tag -width 6n +.It Single upstream mode +One downstream connection maps to one upstream connection +.Pp +To enable this mode, connect to the bouncer +with the username "/". +.Pp +If the bouncer isn't connected to the upstream server, +it will get automatically added. +.Pp +Then channels can be joined and parted as if +you were directly connected to the upstream server. +.It Multiple upstream mode +One downstream connection maps to multiple upstream connections. +Channels and nicks are suffixed with the network name. +To join a channel, you need to use the suffix too: /join #channel/network. +Same applies to messages sent to users. +.El +.Pp +For per-client history to work, clients need to indicate their name. +This can be done by adding a "@" suffix to the username. +.Pp +.Nm +will reload the configuration file, the TLS certificate/key and +the MOTD file when it receives the HUP signal. +The configuration options listen, db and log cannot be reloaded. +.Pp +Administrators can broadcast a message to all bouncer users via +/notice $ , or via /notice $ in multi-upstream mode. +All currently connected bouncer users will receive the message +from the special BouncerServ service. +.Sh AUTHORS +.An Simon Ser Aq Mt contact@emersion.fr +.An The soju Contributors +.Sh MAINTAINERS +.An Izuru Yakumo Aq Mt yakumo.izuru@chaotic.ninja diff --git a/trunk/doc/suikadb.1 b/trunk/doc/suikadb.1 new file mode 100644 index 0000000..5ff4769 --- /dev/null +++ b/trunk/doc/suikadb.1 @@ -0,0 +1,16 @@ +.Dd $Mdocdate$ +.Dt SUIKADB 1 +.Os +.Sh NAME +.Nm suikadb +.Nd Basic user manipulation for +.Xr suika 1 +.Sh SYNOPSIS +.Nm +.Op create-user +.Op change-password +.Sh AUTHORS +.An Simon Ser Aq Mt contact@emersion.fr +.An The soju Contributors +.Sh MAINTAINERS +.An Izuru Yakumo Aq Mt yakumo.izuru@chaotic.ninja diff --git a/trunk/downstream.go b/trunk/downstream.go new file mode 100644 index 0000000..3b9adac --- /dev/null +++ b/trunk/downstream.go @@ -0,0 +1,3047 @@ +package suika + +import ( + "bytes" + "context" + "crypto/tls" + "encoding/base64" + "errors" + "fmt" + "io" + "net" + "strconv" + "strings" + "time" + + "github.com/emersion/go-sasl" + "golang.org/x/crypto/bcrypt" + "gopkg.in/irc.v3" +) + +type ircError struct { + Message *irc.Message +} + +func (err ircError) Error() string { + return err.Message.String() +} + +func newUnknownCommandError(cmd string) ircError { + return ircError{&irc.Message{ + Command: irc.ERR_UNKNOWNCOMMAND, + Params: []string{ + "*", + cmd, + "Unknown command", + }, + }} +} + +func newNeedMoreParamsError(cmd string) ircError { + return ircError{&irc.Message{ + Command: irc.ERR_NEEDMOREPARAMS, + Params: []string{ + "*", + cmd, + "Not enough parameters", + }, + }} +} + +func newChatHistoryError(subcommand string, target string) ircError { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"CHATHISTORY", "MESSAGE_ERROR", subcommand, target, "Messages could not be retrieved"}, + }} +} + +// authError is an authentication error. +type authError struct { + // Internal error cause. This will not be revealed to the user. + err error + // Error cause which can safely be sent to the user without compromising + // security. + reason string +} + +func (err *authError) Error() string { + return err.err.Error() +} + +func (err *authError) Unwrap() error { + return err.err +} + +// authErrorReason returns the user-friendly reason of an authentication +// failure. +func authErrorReason(err error) string { + if authErr, ok := err.(*authError); ok { + return authErr.reason + } else { + return "Authentication failed" + } +} + +func newInvalidUsernameOrPasswordError(err error) error { + return &authError{ + err: err, + reason: "Invalid username or password", + } +} + +func parseBouncerNetID(subcommand, s string) (int64, error) { + id, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return 0, ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"BOUNCER", "INVALID_NETID", subcommand, s, "Invalid network ID"}, + }} + } + return id, nil +} + +func fillNetworkAddrAttrs(attrs irc.Tags, network *Network) { + u, err := network.URL() + if err != nil { + return + } + + hasHostPort := true + switch u.Scheme { + case "ircs": + attrs["tls"] = irc.TagValue("1") + case "irc": + attrs["tls"] = irc.TagValue("0") + default: // e.g. unix:// + hasHostPort = false + } + if host, port, err := net.SplitHostPort(u.Host); err == nil && hasHostPort { + attrs["host"] = irc.TagValue(host) + attrs["port"] = irc.TagValue(port) + } else if hasHostPort { + attrs["host"] = irc.TagValue(u.Host) + } +} + +func getNetworkAttrs(network *network) irc.Tags { + state := "disconnected" + if uc := network.conn; uc != nil { + state = "connected" + } + + attrs := irc.Tags{ + "name": irc.TagValue(network.GetName()), + "state": irc.TagValue(state), + "nickname": irc.TagValue(GetNick(&network.user.User, &network.Network)), + } + + if network.Username != "" { + attrs["username"] = irc.TagValue(network.Username) + } + if realname := GetRealname(&network.user.User, &network.Network); realname != "" { + attrs["realname"] = irc.TagValue(realname) + } + + fillNetworkAddrAttrs(attrs, &network.Network) + + return attrs +} + +func networkAddrFromAttrs(attrs irc.Tags) string { + host, ok := attrs.GetTag("host") + if !ok { + return "" + } + + addr := host + if port, ok := attrs.GetTag("port"); ok { + addr += ":" + port + } + + if tlsStr, ok := attrs.GetTag("tls"); ok && tlsStr == "0" { + addr = "irc://" + tlsStr + } + + return addr +} + +func updateNetworkAttrs(record *Network, attrs irc.Tags, subcommand string) error { + addrAttrs := irc.Tags{} + fillNetworkAddrAttrs(addrAttrs, record) + + updateAddr := false + for k, v := range attrs { + s := string(v) + switch k { + case "host", "port", "tls": + updateAddr = true + addrAttrs[k] = v + case "name": + record.Name = s + case "nickname": + record.Nick = s + case "username": + record.Username = s + case "realname": + record.Realname = s + case "pass": + record.Pass = s + default: + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"BOUNCER", "UNKNOWN_ATTRIBUTE", subcommand, k, "Unknown attribute"}, + }} + } + } + + if updateAddr { + record.Addr = networkAddrFromAttrs(addrAttrs) + if record.Addr == "" { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"BOUNCER", "NEED_ATTRIBUTE", subcommand, "host", "Missing required host attribute"}, + }} + } + } + + return nil +} + +// illegalNickChars is the list of characters forbidden in a nickname. +// +// ' ' and ':' break the IRC message wire format +// '@' and '!' break prefixes +// '*' breaks masks and is the reserved nickname for registration +// '?' breaks masks +// '$' breaks server masks in PRIVMSG/NOTICE +// ',' breaks lists +// '.' is reserved for server names +const illegalNickChars = " :@!*?$,." + +// permanentDownstreamCaps is the list of always-supported downstream +// capabilities. +var permanentDownstreamCaps = map[string]string{ + "batch": "", + "cap-notify": "", + "echo-message": "", + "invite-notify": "", + "message-tags": "", + "server-time": "", + "setname": "", + + "soju.im/bouncer-networks": "", + "soju.im/bouncer-networks-notify": "", + "soju.im/read": "", +} + +// needAllDownstreamCaps is the list of downstream capabilities that +// require support from all upstreams to be enabled +var needAllDownstreamCaps = map[string]string{ + "account-notify": "", + "account-tag": "", + "away-notify": "", + "extended-join": "", + "multi-prefix": "", + + "draft/extended-monitor": "", +} + +// passthroughIsupport is the set of ISUPPORT tokens that are directly passed +// through from the upstream server to downstream clients. +// +// This is only effective in single-upstream mode. +var passthroughIsupport = map[string]bool{ + "AWAYLEN": true, + "BOT": true, + "CHANLIMIT": true, + "CHANMODES": true, + "CHANNELLEN": true, + "CHANTYPES": true, + "CLIENTTAGDENY": true, + "ELIST": true, + "EXCEPTS": true, + "EXTBAN": true, + "HOSTLEN": true, + "INVEX": true, + "KICKLEN": true, + "MAXLIST": true, + "MAXTARGETS": true, + "MODES": true, + "MONITOR": true, + "NAMELEN": true, + "NETWORK": true, + "NICKLEN": true, + "PREFIX": true, + "SAFELIST": true, + "TARGMAX": true, + "TOPICLEN": true, + "USERLEN": true, + "UTF8ONLY": true, + "WHOX": true, +} + +type downstreamSASL struct { + server sasl.Server + plainUsername, plainPassword string + pendingResp bytes.Buffer +} + +type downstreamConn struct { + conn + + id uint64 + + registered bool + user *user + nick string + nickCM string + rawUsername string + networkName string + clientName string + realname string + hostname string + account string // RPL_LOGGEDIN/OUT state + password string // empty after authentication + network *network // can be nil + isMultiUpstream bool + + negotiatingCaps bool + capVersion int + supportedCaps map[string]string + caps map[string]bool + sasl *downstreamSASL + + lastBatchRef uint64 + + monitored casemapMap +} + +func newDownstreamConn(srv *Server, ic ircConn, id uint64) *downstreamConn { + remoteAddr := ic.RemoteAddr().String() + logger := &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", remoteAddr)} + options := connOptions{Logger: logger} + dc := &downstreamConn{ + conn: *newConn(srv, ic, &options), + id: id, + nick: "*", + nickCM: "*", + supportedCaps: make(map[string]string), + caps: make(map[string]bool), + monitored: newCasemapMap(0), + } + dc.hostname = remoteAddr + if host, _, err := net.SplitHostPort(dc.hostname); err == nil { + dc.hostname = host + } + for k, v := range permanentDownstreamCaps { + dc.supportedCaps[k] = v + } + dc.supportedCaps["sasl"] = "PLAIN" + // TODO: this is racy, we should only enable chathistory after + // authentication and then check that user.msgStore implements + // chatHistoryMessageStore + if srv.Config().LogPath != "" { + dc.supportedCaps["draft/chathistory"] = "" + } + return dc +} + +func (dc *downstreamConn) prefix() *irc.Prefix { + return &irc.Prefix{ + Name: dc.nick, + User: dc.user.Username, + Host: dc.hostname, + } +} + +func (dc *downstreamConn) forEachNetwork(f func(*network)) { + if dc.network != nil { + f(dc.network) + } else if dc.isMultiUpstream { + for _, network := range dc.user.networks { + f(network) + } + } +} + +func (dc *downstreamConn) forEachUpstream(f func(*upstreamConn)) { + if dc.network == nil && !dc.isMultiUpstream { + return + } + dc.user.forEachUpstream(func(uc *upstreamConn) { + if dc.network != nil && uc.network != dc.network { + return + } + f(uc) + }) +} + +// upstream returns the upstream connection, if any. If there are zero or if +// there are multiple upstream connections, it returns nil. +func (dc *downstreamConn) upstream() *upstreamConn { + if dc.network == nil { + return nil + } + return dc.network.conn +} + +func isOurNick(net *network, nick string) bool { + // TODO: this doesn't account for nick changes + if net.conn != nil { + return net.casemap(nick) == net.conn.nickCM + } + // We're not currently connected to the upstream connection, so we don't + // know whether this name is our nickname. Best-effort: use the network's + // configured nickname and hope it was the one being used when we were + // connected. + return net.casemap(nick) == net.casemap(GetNick(&net.user.User, &net.Network)) +} + +// marshalEntity converts an upstream entity name (ie. channel or nick) into a +// downstream entity name. +// +// This involves adding a "/" suffix if the entity isn't the current +// user. +func (dc *downstreamConn) marshalEntity(net *network, name string) string { + if isOurNick(net, name) { + return dc.nick + } + name = partialCasemap(net.casemap, name) + if dc.network != nil { + if dc.network != net { + panic("suika: tried to marshal an entity for another network") + } + return name + } + return name + "/" + net.GetName() +} + +func (dc *downstreamConn) marshalUserPrefix(net *network, prefix *irc.Prefix) *irc.Prefix { + if isOurNick(net, prefix.Name) { + return dc.prefix() + } + prefix.Name = partialCasemap(net.casemap, prefix.Name) + if dc.network != nil { + if dc.network != net { + panic("suika: tried to marshal a user prefix for another network") + } + return prefix + } + return &irc.Prefix{ + Name: prefix.Name + "/" + net.GetName(), + User: prefix.User, + Host: prefix.Host, + } +} + +// unmarshalEntityNetwork converts a downstream entity name (ie. channel or +// nick) into an upstream entity name. +// +// This involves removing the "/" suffix. +func (dc *downstreamConn) unmarshalEntityNetwork(name string) (*network, string, error) { + if dc.network != nil { + return dc.network, name, nil + } + if !dc.isMultiUpstream { + return nil, "", ircError{&irc.Message{ + Command: irc.ERR_NOSUCHCHANNEL, + Params: []string{dc.nick, name, "Cannot interact with channels and users on the bouncer connection. Did you mean to use a specific network?"}, + }} + } + + var net *network + if i := strings.LastIndexByte(name, '/'); i >= 0 { + network := name[i+1:] + name = name[:i] + + for _, n := range dc.user.networks { + if network == n.GetName() { + net = n + break + } + } + } + + if net == nil { + return nil, "", ircError{&irc.Message{ + Command: irc.ERR_NOSUCHCHANNEL, + Params: []string{dc.nick, name, "Missing network suffix in name"}, + }} + } + + return net, name, nil +} + +// unmarshalEntity is the same as unmarshalEntityNetwork, but returns the +// upstream connection and fails if the upstream is disconnected. +func (dc *downstreamConn) unmarshalEntity(name string) (*upstreamConn, string, error) { + net, name, err := dc.unmarshalEntityNetwork(name) + if err != nil { + return nil, "", err + } + + if net.conn == nil { + return nil, "", ircError{&irc.Message{ + Command: irc.ERR_NOSUCHCHANNEL, + Params: []string{dc.nick, name, "Disconnected from upstream network"}, + }} + } + + return net.conn, name, nil +} + +func (dc *downstreamConn) unmarshalText(uc *upstreamConn, text string) string { + if dc.upstream() != nil { + return text + } + // TODO: smarter parsing that ignores URLs + return strings.ReplaceAll(text, "/"+uc.network.GetName(), "") +} + +func (dc *downstreamConn) ReadMessage() (*irc.Message, error) { + msg, err := dc.conn.ReadMessage() + if err != nil { + return nil, err + } + return msg, nil +} + +func (dc *downstreamConn) readMessages(ch chan<- event) error { + for { + msg, err := dc.ReadMessage() + if errors.Is(err, io.EOF) { + break + } else if err != nil { + return fmt.Errorf("failed to read IRC command: %v", err) + } + + ch <- eventDownstreamMessage{msg, dc} + } + + return nil +} + +// SendMessage sends an outgoing message. +// +// This can only called from the user goroutine. +func (dc *downstreamConn) SendMessage(msg *irc.Message) { + if !dc.caps["message-tags"] { + if msg.Command == "TAGMSG" { + return + } + msg = msg.Copy() + for name := range msg.Tags { + supported := false + switch name { + case "time": + supported = dc.caps["server-time"] + case "account": + supported = dc.caps["account"] + } + if !supported { + delete(msg.Tags, name) + } + } + } + if !dc.caps["batch"] && msg.Tags["batch"] != "" { + msg = msg.Copy() + delete(msg.Tags, "batch") + } + if msg.Command == "JOIN" && !dc.caps["extended-join"] { + msg.Params = msg.Params[:1] + } + if msg.Command == "SETNAME" && !dc.caps["setname"] { + return + } + if msg.Command == "AWAY" && !dc.caps["away-notify"] { + return + } + if msg.Command == "ACCOUNT" && !dc.caps["account-notify"] { + return + } + if msg.Command == "READ" && !dc.caps["soju.im/read"] { + return + } + + dc.conn.SendMessage(context.TODO(), msg) +} + +func (dc *downstreamConn) SendBatch(typ string, params []string, tags irc.Tags, f func(batchRef irc.TagValue)) { + dc.lastBatchRef++ + ref := fmt.Sprintf("%v", dc.lastBatchRef) + + if dc.caps["batch"] { + dc.SendMessage(&irc.Message{ + Tags: tags, + Prefix: dc.srv.prefix(), + Command: "BATCH", + Params: append([]string{"+" + ref, typ}, params...), + }) + } + + f(irc.TagValue(ref)) + + if dc.caps["batch"] { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "BATCH", + Params: []string{"-" + ref}, + }) + } +} + +// sendMessageWithID sends an outgoing message with the specified internal ID. +func (dc *downstreamConn) sendMessageWithID(msg *irc.Message, id string) { + dc.SendMessage(msg) + + if id == "" || !dc.messageSupportsBacklog(msg) { + return + } + + dc.sendPing(id) +} + +// advanceMessageWithID advances history to the specified message ID without +// sending a message. This is useful e.g. for self-messages when echo-message +// isn't enabled. +func (dc *downstreamConn) advanceMessageWithID(msg *irc.Message, id string) { + if id == "" || !dc.messageSupportsBacklog(msg) { + return + } + + dc.sendPing(id) +} + +// ackMsgID acknowledges that a message has been received. +func (dc *downstreamConn) ackMsgID(id string) { + netID, entity, err := parseMsgID(id, nil) + if err != nil { + dc.logger.Printf("failed to ACK message ID %q: %v", id, err) + return + } + + network := dc.user.getNetworkByID(netID) + if network == nil { + return + } + + network.delivered.StoreID(entity, dc.clientName, id) +} + +func (dc *downstreamConn) sendPing(msgID string) { + token := "suika-msgid-" + msgID + dc.SendMessage(&irc.Message{ + Command: "PING", + Params: []string{token}, + }) +} + +func (dc *downstreamConn) handlePong(token string) { + if !strings.HasPrefix(token, "suika-msgid-") { + dc.logger.Printf("received unrecognized PONG token %q", token) + return + } + msgID := strings.TrimPrefix(token, "suika-msgid-") + dc.ackMsgID(msgID) +} + +// marshalMessage re-formats a message coming from an upstream connection so +// that it's suitable for being sent on this downstream connection. Only +// messages that may appear in logs are supported, except MODE messages which +// may only appear in single-upstream mode. +func (dc *downstreamConn) marshalMessage(msg *irc.Message, net *network) *irc.Message { + msg = msg.Copy() + msg.Prefix = dc.marshalUserPrefix(net, msg.Prefix) + + if dc.network != nil { + return msg + } + + switch msg.Command { + case "PRIVMSG", "NOTICE", "TAGMSG": + msg.Params[0] = dc.marshalEntity(net, msg.Params[0]) + case "NICK": + // Nick change for another user + msg.Params[0] = dc.marshalEntity(net, msg.Params[0]) + case "JOIN", "PART": + msg.Params[0] = dc.marshalEntity(net, msg.Params[0]) + case "KICK": + msg.Params[0] = dc.marshalEntity(net, msg.Params[0]) + msg.Params[1] = dc.marshalEntity(net, msg.Params[1]) + case "TOPIC": + msg.Params[0] = dc.marshalEntity(net, msg.Params[0]) + case "QUIT", "SETNAME": + // This space is intentionally left blank + default: + panic(fmt.Sprintf("unexpected %q message", msg.Command)) + } + + return msg +} + +func (dc *downstreamConn) handleMessage(ctx context.Context, msg *irc.Message) error { + ctx, cancel := dc.conn.NewContext(ctx) + defer cancel() + + ctx, cancel = context.WithTimeout(ctx, handleDownstreamMessageTimeout) + defer cancel() + + switch msg.Command { + case "QUIT": + return dc.Close() + default: + if dc.registered { + return dc.handleMessageRegistered(ctx, msg) + } else { + return dc.handleMessageUnregistered(ctx, msg) + } + } +} + +func (dc *downstreamConn) handleMessageUnregistered(ctx context.Context, msg *irc.Message) error { + switch msg.Command { + case "NICK": + var nick string + if err := parseMessageParams(msg, &nick); err != nil { + return err + } + if nick == "" || strings.ContainsAny(nick, illegalNickChars) { + return ircError{&irc.Message{ + Command: irc.ERR_ERRONEUSNICKNAME, + Params: []string{dc.nick, nick, "contains illegal characters"}, + }} + } + nickCM := casemapASCII(nick) + if nickCM == serviceNickCM { + return ircError{&irc.Message{ + Command: irc.ERR_NICKNAMEINUSE, + Params: []string{dc.nick, nick, "Nickname reserved for bouncer service"}, + }} + } + dc.nick = nick + dc.nickCM = nickCM + case "USER": + if err := parseMessageParams(msg, &dc.rawUsername, nil, nil, &dc.realname); err != nil { + return err + } + case "PASS": + if err := parseMessageParams(msg, &dc.password); err != nil { + return err + } + case "CAP": + var subCmd string + if err := parseMessageParams(msg, &subCmd); err != nil { + return err + } + if err := dc.handleCapCommand(subCmd, msg.Params[1:]); err != nil { + return err + } + case "AUTHENTICATE": + credentials, err := dc.handleAuthenticateCommand(msg) + if err != nil { + return err + } else if credentials == nil { + break + } + + if err := dc.authenticate(ctx, credentials.plainUsername, credentials.plainPassword); err != nil { + dc.logger.Printf("SASL authentication error for user %q: %v", credentials.plainUsername, err) + dc.endSASL(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_SASLFAIL, + Params: []string{dc.nick, authErrorReason(err)}, + }) + break + } + + // Technically we should send RPL_LOGGEDIN here. However we use + // RPL_LOGGEDIN to mirror the upstream connection status. Let's + // see how many clients that breaks. See: + // https://github.com/ircv3/ircv3-specifications/pull/476 + dc.endSASL(nil) + case "BOUNCER": + var subcommand string + if err := parseMessageParams(msg, &subcommand); err != nil { + return err + } + + switch strings.ToUpper(subcommand) { + case "BIND": + var idStr string + if err := parseMessageParams(msg, nil, &idStr); err != nil { + return err + } + + if dc.user == nil { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"BOUNCER", "ACCOUNT_REQUIRED", "BIND", "Authentication needed to bind to bouncer network"}, + }} + } + + id, err := parseBouncerNetID(subcommand, idStr) + if err != nil { + return err + } + + var match *network + for _, net := range dc.user.networks { + if net.ID == id { + match = net + break + } + } + if match == nil { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"BOUNCER", "INVALID_NETID", idStr, "Unknown network ID"}, + }} + } + + dc.networkName = match.GetName() + } + default: + dc.logger.Printf("unhandled message: %v", msg) + return newUnknownCommandError(msg.Command) + } + if dc.rawUsername != "" && dc.nick != "*" && !dc.negotiatingCaps { + return dc.register(ctx) + } + return nil +} + +func (dc *downstreamConn) handleCapCommand(cmd string, args []string) error { + cmd = strings.ToUpper(cmd) + + switch cmd { + case "LS": + if len(args) > 0 { + var err error + if dc.capVersion, err = strconv.Atoi(args[0]); err != nil { + return err + } + } + if !dc.registered && dc.capVersion >= 302 { + // Let downstream show everything it supports, and trim + // down the available capabilities when upstreams are + // known. + for k, v := range needAllDownstreamCaps { + dc.supportedCaps[k] = v + } + } + + caps := make([]string, 0, len(dc.supportedCaps)) + for k, v := range dc.supportedCaps { + if dc.capVersion >= 302 && v != "" { + caps = append(caps, k+"="+v) + } else { + caps = append(caps, k) + } + } + + // TODO: multi-line replies + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "CAP", + Params: []string{dc.nick, "LS", strings.Join(caps, " ")}, + }) + + if dc.capVersion >= 302 { + // CAP version 302 implicitly enables cap-notify + dc.caps["cap-notify"] = true + } + + if !dc.registered { + dc.negotiatingCaps = true + } + case "LIST": + var caps []string + for name, enabled := range dc.caps { + if enabled { + caps = append(caps, name) + } + } + + // TODO: multi-line replies + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "CAP", + Params: []string{dc.nick, "LIST", strings.Join(caps, " ")}, + }) + case "REQ": + if len(args) == 0 { + return ircError{&irc.Message{ + Command: err_invalidcapcmd, + Params: []string{dc.nick, cmd, "Missing argument in CAP REQ command"}, + }} + } + + // TODO: atomically ack/nak the whole capability set + caps := strings.Fields(args[0]) + ack := true + for _, name := range caps { + name = strings.ToLower(name) + enable := !strings.HasPrefix(name, "-") + if !enable { + name = strings.TrimPrefix(name, "-") + } + + if enable == dc.caps[name] { + continue + } + + _, ok := dc.supportedCaps[name] + if !ok { + ack = false + break + } + + if name == "cap-notify" && dc.capVersion >= 302 && !enable { + // cap-notify cannot be disabled with CAP version 302 + ack = false + break + } + + dc.caps[name] = enable + } + + reply := "NAK" + if ack { + reply = "ACK" + } + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "CAP", + Params: []string{dc.nick, reply, args[0]}, + }) + + if !dc.registered { + dc.negotiatingCaps = true + } + case "END": + dc.negotiatingCaps = false + default: + return ircError{&irc.Message{ + Command: err_invalidcapcmd, + Params: []string{dc.nick, cmd, "Unknown CAP command"}, + }} + } + return nil +} + +func (dc *downstreamConn) handleAuthenticateCommand(msg *irc.Message) (result *downstreamSASL, err error) { + defer func() { + if err != nil { + dc.sasl = nil + } + }() + + if !dc.caps["sasl"] { + return nil, ircError{&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_SASLFAIL, + Params: []string{dc.nick, "AUTHENTICATE requires the \"sasl\" capability to be enabled"}, + }} + } + if len(msg.Params) == 0 { + return nil, ircError{&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_SASLFAIL, + Params: []string{dc.nick, "Missing AUTHENTICATE argument"}, + }} + } + if msg.Params[0] == "*" { + return nil, ircError{&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_SASLABORTED, + Params: []string{dc.nick, "SASL authentication aborted"}, + }} + } + + var resp []byte + if dc.sasl == nil { + mech := strings.ToUpper(msg.Params[0]) + var server sasl.Server + switch mech { + case "PLAIN": + server = sasl.NewPlainServer(sasl.PlainAuthenticator(func(identity, username, password string) error { + dc.sasl.plainUsername = username + dc.sasl.plainPassword = password + return nil + })) + default: + return nil, ircError{&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_SASLFAIL, + Params: []string{dc.nick, fmt.Sprintf("Unsupported SASL mechanism %q", mech)}, + }} + } + + dc.sasl = &downstreamSASL{server: server} + } else { + chunk := msg.Params[0] + if chunk == "+" { + chunk = "" + } + + if dc.sasl.pendingResp.Len()+len(chunk) > 10*1024 { + return nil, ircError{&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_SASLFAIL, + Params: []string{dc.nick, "Response too long"}, + }} + } + + dc.sasl.pendingResp.WriteString(chunk) + + if len(chunk) == maxSASLLength { + return nil, nil // Multi-line response, wait for the next command + } + + resp, err = base64.StdEncoding.DecodeString(dc.sasl.pendingResp.String()) + if err != nil { + return nil, ircError{&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_SASLFAIL, + Params: []string{dc.nick, "Invalid base64-encoded response"}, + }} + } + + dc.sasl.pendingResp.Reset() + } + + challenge, done, err := dc.sasl.server.Next(resp) + if err != nil { + return nil, err + } else if done { + return dc.sasl, nil + } else { + challengeStr := "+" + if len(challenge) > 0 { + challengeStr = base64.StdEncoding.EncodeToString(challenge) + } + + // TODO: multi-line messages + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "AUTHENTICATE", + Params: []string{challengeStr}, + }) + return nil, nil + } +} + +func (dc *downstreamConn) endSASL(msg *irc.Message) { + if dc.sasl == nil { + return + } + + dc.sasl = nil + + if msg != nil { + dc.SendMessage(msg) + } else { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_SASLSUCCESS, + Params: []string{dc.nick, "SASL authentication successful"}, + }) + } +} + +func (dc *downstreamConn) setSupportedCap(name, value string) { + prevValue, hasPrev := dc.supportedCaps[name] + changed := !hasPrev || prevValue != value + dc.supportedCaps[name] = value + + if !dc.caps["cap-notify"] || !changed { + return + } + + cap := name + if value != "" && dc.capVersion >= 302 { + cap = name + "=" + value + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "CAP", + Params: []string{dc.nick, "NEW", cap}, + }) +} + +func (dc *downstreamConn) unsetSupportedCap(name string) { + _, hasPrev := dc.supportedCaps[name] + delete(dc.supportedCaps, name) + delete(dc.caps, name) + + if !dc.caps["cap-notify"] || !hasPrev { + return + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "CAP", + Params: []string{dc.nick, "DEL", name}, + }) +} + +func (dc *downstreamConn) updateSupportedCaps() { + supportedCaps := make(map[string]bool) + for cap := range needAllDownstreamCaps { + supportedCaps[cap] = true + } + dc.forEachUpstream(func(uc *upstreamConn) { + for cap, supported := range supportedCaps { + supportedCaps[cap] = supported && uc.caps[cap] + } + }) + + for cap, supported := range supportedCaps { + if supported { + dc.setSupportedCap(cap, needAllDownstreamCaps[cap]) + } else { + dc.unsetSupportedCap(cap) + } + } + + if uc := dc.upstream(); uc != nil && uc.supportsSASL("PLAIN") { + dc.setSupportedCap("sasl", "PLAIN") + } else if dc.network != nil { + dc.unsetSupportedCap("sasl") + } + + if uc := dc.upstream(); uc != nil && uc.caps["draft/account-registration"] { + // Strip "before-connect", because we require downstreams to be fully + // connected before attempting account registration. + values := strings.Split(uc.supportedCaps["draft/account-registration"], ",") + for i, v := range values { + if v == "before-connect" { + values = append(values[:i], values[i+1:]...) + break + } + } + dc.setSupportedCap("draft/account-registration", strings.Join(values, ",")) + } else { + dc.unsetSupportedCap("draft/account-registration") + } + + if _, ok := dc.user.msgStore.(chatHistoryMessageStore); ok && dc.network != nil { + dc.setSupportedCap("draft/event-playback", "") + } else { + dc.unsetSupportedCap("draft/event-playback") + } +} + +func (dc *downstreamConn) updateNick() { + if uc := dc.upstream(); uc != nil && uc.nick != dc.nick { + dc.SendMessage(&irc.Message{ + Prefix: dc.prefix(), + Command: "NICK", + Params: []string{uc.nick}, + }) + dc.nick = uc.nick + dc.nickCM = casemapASCII(dc.nick) + } +} + +func (dc *downstreamConn) updateRealname() { + if uc := dc.upstream(); uc != nil && uc.realname != dc.realname && dc.caps["setname"] { + dc.SendMessage(&irc.Message{ + Prefix: dc.prefix(), + Command: "SETNAME", + Params: []string{uc.realname}, + }) + dc.realname = uc.realname + } +} + +func (dc *downstreamConn) updateAccount() { + var account string + if dc.network == nil { + account = dc.user.Username + } else if uc := dc.upstream(); uc != nil { + account = uc.account + } else { + return + } + + if dc.account == account || !dc.caps["sasl"] { + return + } + + if account != "" { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_LOGGEDIN, + Params: []string{dc.nick, dc.prefix().String(), account, "You are logged in as " + account}, + }) + } else { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_LOGGEDOUT, + Params: []string{dc.nick, dc.prefix().String(), "You are logged out"}, + }) + } + + dc.account = account +} + +func sanityCheckServer(ctx context.Context, addr string) error { + ctx, cancel := context.WithTimeout(ctx, 15*time.Second) + defer cancel() + + conn, err := new(tls.Dialer).DialContext(ctx, "tcp", addr) + if err != nil { + return err + } + + return conn.Close() +} + +func unmarshalUsername(rawUsername string) (username, client, network string) { + username = rawUsername + + i := strings.IndexAny(username, "/@") + j := strings.LastIndexAny(username, "/@") + if i >= 0 { + username = rawUsername[:i] + } + if j >= 0 { + if rawUsername[j] == '@' { + client = rawUsername[j+1:] + } else { + network = rawUsername[j+1:] + } + } + if i >= 0 && j >= 0 && i < j { + if rawUsername[i] == '@' { + client = rawUsername[i+1 : j] + } else { + network = rawUsername[i+1 : j] + } + } + + return username, client, network +} + +func (dc *downstreamConn) authenticate(ctx context.Context, username, password string) error { + username, clientName, networkName := unmarshalUsername(username) + + u, err := dc.srv.db.GetUser(ctx, username) + if err != nil { + return newInvalidUsernameOrPasswordError(fmt.Errorf("user not found: %w", err)) + } + + // Password auth disabled + if u.Password == "" { + return newInvalidUsernameOrPasswordError(fmt.Errorf("password auth disabled")) + } + + err = bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(password)) + if err != nil { + return newInvalidUsernameOrPasswordError(fmt.Errorf("wrong password")) + } + + dc.user = dc.srv.getUser(username) + if dc.user == nil { + return fmt.Errorf("user not active") + } + dc.clientName = clientName + dc.networkName = networkName + return nil +} + +func (dc *downstreamConn) register(ctx context.Context) error { + if dc.registered { + panic("tried to register twice") + } + + if dc.sasl != nil { + dc.endSASL(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_SASLABORTED, + Params: []string{dc.nick, "SASL authentication aborted"}, + }) + } + + password := dc.password + dc.password = "" + if dc.user == nil { + if password == "" { + if dc.caps["sasl"] { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"*", "ACCOUNT_REQUIRED", "Authentication required"}, + }} + } else { + return ircError{&irc.Message{ + Command: irc.ERR_PASSWDMISMATCH, + Params: []string{dc.nick, "Authentication required"}, + }} + } + } + + if err := dc.authenticate(ctx, dc.rawUsername, password); err != nil { + dc.logger.Printf("PASS authentication error for user %q: %v", dc.rawUsername, err) + return ircError{&irc.Message{ + Command: irc.ERR_PASSWDMISMATCH, + Params: []string{dc.nick, authErrorReason(err)}, + }} + } + } + + _, fallbackClientName, fallbackNetworkName := unmarshalUsername(dc.rawUsername) + if dc.clientName == "" { + dc.clientName = fallbackClientName + } else if fallbackClientName != "" && dc.clientName != fallbackClientName { + return ircError{&irc.Message{ + Command: irc.ERR_ERRONEUSNICKNAME, + Params: []string{dc.nick, "Client name mismatch in usernames"}, + }} + } + if dc.networkName == "" { + dc.networkName = fallbackNetworkName + } else if fallbackNetworkName != "" && dc.networkName != fallbackNetworkName { + return ircError{&irc.Message{ + Command: irc.ERR_ERRONEUSNICKNAME, + Params: []string{dc.nick, "Network name mismatch in usernames"}, + }} + } + + dc.registered = true + dc.logger.Printf("registration complete for user %q", dc.user.Username) + return nil +} + +func (dc *downstreamConn) loadNetwork(ctx context.Context) error { + if dc.networkName == "" { + return nil + } + + network := dc.user.getNetwork(dc.networkName) + if network == nil { + addr := dc.networkName + if !strings.ContainsRune(addr, ':') { + addr = addr + ":6697" + } + + dc.logger.Printf("trying to connect to new network %q", addr) + if err := sanityCheckServer(ctx, addr); err != nil { + dc.logger.Printf("failed to connect to %q: %v", addr, err) + return ircError{&irc.Message{ + Command: irc.ERR_PASSWDMISMATCH, + Params: []string{dc.nick, fmt.Sprintf("Failed to connect to %q", dc.networkName)}, + }} + } + + // Some clients only allow specifying the nickname (and use the + // nickname as a username too). Strip the network name from the + // nickname when auto-saving networks. + nick, _, _ := unmarshalUsername(dc.nick) + + dc.logger.Printf("auto-saving network %q", dc.networkName) + var err error + network, err = dc.user.createNetwork(ctx, &Network{ + Addr: dc.networkName, + Nick: nick, + Enabled: true, + }) + if err != nil { + return err + } + } + + dc.network = network + return nil +} + +func (dc *downstreamConn) welcome(ctx context.Context) error { + if dc.user == nil || !dc.registered { + panic("tried to welcome an unregistered connection") + } + + remoteAddr := dc.conn.RemoteAddr().String() + dc.logger = &prefixLogger{dc.srv.Logger, fmt.Sprintf("user %q: downstream %q: ", dc.user.Username, remoteAddr)} + + // TODO: doing this might take some time. We should do it in dc.register + // instead, but we'll potentially be adding a new network and this must be + // done in the user goroutine. + if err := dc.loadNetwork(ctx); err != nil { + return err + } + + if dc.network == nil && !dc.caps["soju.im/bouncer-networks"] && dc.srv.Config().MultiUpstream { + dc.isMultiUpstream = true + } + + dc.updateSupportedCaps() + + isupport := []string{ + fmt.Sprintf("CHATHISTORY=%v", chatHistoryLimit), + "CASEMAPPING=ascii", + } + + if dc.network != nil { + isupport = append(isupport, fmt.Sprintf("BOUNCER_NETID=%v", dc.network.ID)) + } + if title := dc.srv.Config().Title; dc.network == nil && title != "" { + isupport = append(isupport, "NETWORK="+encodeISUPPORT(title)) + } + if dc.network == nil && !dc.isMultiUpstream { + isupport = append(isupport, "WHOX") + } + + if uc := dc.upstream(); uc != nil { + for k := range passthroughIsupport { + v, ok := uc.isupport[k] + if !ok { + continue + } + if v != nil { + isupport = append(isupport, fmt.Sprintf("%v=%v", k, *v)) + } else { + isupport = append(isupport, k) + } + } + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_WELCOME, + Params: []string{dc.nick, "Welcome to suika, " + dc.nick}, + }) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_YOURHOST, + Params: []string{dc.nick, "Your host is " + dc.srv.Config().Hostname}, + }) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_MYINFO, + Params: []string{dc.nick, dc.srv.Config().Hostname, "suika", "aiwroO", "OovaimnqpsrtklbeI"}, + }) + for _, msg := range generateIsupport(dc.srv.prefix(), dc.nick, isupport) { + dc.SendMessage(msg) + } + if uc := dc.upstream(); uc != nil { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_UMODEIS, + Params: []string{dc.nick, "+" + string(uc.modes)}, + }) + } + if dc.network == nil && !dc.isMultiUpstream && dc.user.Admin { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_UMODEIS, + Params: []string{dc.nick, "+o"}, + }) + } + + dc.updateNick() + dc.updateRealname() + dc.updateAccount() + + if motd := dc.user.srv.Config().MOTD; motd != "" && dc.network == nil { + for _, msg := range generateMOTD(dc.srv.prefix(), dc.nick, motd) { + dc.SendMessage(msg) + } + } else { + motdHint := "No MOTD" + if dc.network != nil { + motdHint = "Use /motd to read the message of the day" + } + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_NOMOTD, + Params: []string{dc.nick, motdHint}, + }) + } + + if dc.caps["soju.im/bouncer-networks-notify"] { + dc.SendBatch("soju.im/bouncer-networks", nil, nil, func(batchRef irc.TagValue) { + for _, network := range dc.user.networks { + idStr := fmt.Sprintf("%v", network.ID) + attrs := getNetworkAttrs(network) + dc.SendMessage(&irc.Message{ + Tags: irc.Tags{"batch": batchRef}, + Prefix: dc.srv.prefix(), + Command: "BOUNCER", + Params: []string{"NETWORK", idStr, attrs.String()}, + }) + } + }) + } + + dc.forEachUpstream(func(uc *upstreamConn) { + for _, entry := range uc.channels.innerMap { + ch := entry.value.(*upstreamChannel) + if !ch.complete { + continue + } + record := uc.network.channels.Value(ch.Name) + if record != nil && record.Detached { + continue + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.prefix(), + Command: "JOIN", + Params: []string{dc.marshalEntity(ch.conn.network, ch.Name)}, + }) + + forwardChannel(ctx, dc, ch) + } + }) + + dc.forEachNetwork(func(net *network) { + if dc.caps["draft/chathistory"] || dc.user.msgStore == nil { + return + } + + // Only send history if we're the first connected client with that name + // for the network + firstClient := true + dc.user.forEachDownstream(func(c *downstreamConn) { + if c != dc && c.clientName == dc.clientName && c.network == dc.network { + firstClient = false + } + }) + if firstClient { + net.delivered.ForEachTarget(func(target string) { + lastDelivered := net.delivered.LoadID(target, dc.clientName) + if lastDelivered == "" { + return + } + + dc.sendTargetBacklog(ctx, net, target, lastDelivered) + + // Fast-forward history to last message + targetCM := net.casemap(target) + lastID, err := dc.user.msgStore.LastMsgID(&net.Network, targetCM, time.Now()) + if err != nil { + dc.logger.Printf("failed to get last message ID: %v", err) + return + } + net.delivered.StoreID(target, dc.clientName, lastID) + }) + } + }) + + return nil +} + +// messageSupportsBacklog checks whether the provided message can be sent as +// part of an history batch. +func (dc *downstreamConn) messageSupportsBacklog(msg *irc.Message) bool { + // Don't replay all messages, because that would mess up client + // state. For instance we just sent the list of users, sending + // PART messages for one of these users would be incorrect. + switch msg.Command { + case "PRIVMSG", "NOTICE": + return true + } + return false +} + +func (dc *downstreamConn) sendTargetBacklog(ctx context.Context, net *network, target, msgID string) { + if dc.caps["draft/chathistory"] || dc.user.msgStore == nil { + return + } + + ch := net.channels.Value(target) + + ctx, cancel := context.WithTimeout(ctx, backlogTimeout) + defer cancel() + + targetCM := net.casemap(target) + history, err := dc.user.msgStore.LoadLatestID(ctx, &net.Network, targetCM, msgID, backlogLimit) + if err != nil { + dc.logger.Printf("failed to send backlog for %q: %v", target, err) + return + } + + dc.SendBatch("chathistory", []string{dc.marshalEntity(net, target)}, nil, func(batchRef irc.TagValue) { + for _, msg := range history { + if ch != nil && ch.Detached { + if net.detachedMessageNeedsRelay(ch, msg) { + dc.relayDetachedMessage(net, msg) + } + } else { + msg.Tags["batch"] = batchRef + dc.SendMessage(dc.marshalMessage(msg, net)) + } + } + }) +} + +func (dc *downstreamConn) relayDetachedMessage(net *network, msg *irc.Message) { + if msg.Command != "PRIVMSG" && msg.Command != "NOTICE" { + return + } + + sender := msg.Prefix.Name + target, text := msg.Params[0], msg.Params[1] + if net.isHighlight(msg) { + sendServiceNOTICE(dc, fmt.Sprintf("highlight in %v: <%v> %v", dc.marshalEntity(net, target), sender, text)) + } else { + sendServiceNOTICE(dc, fmt.Sprintf("message in %v: <%v> %v", dc.marshalEntity(net, target), sender, text)) + } +} + +func (dc *downstreamConn) runUntilRegistered() error { + ctx, cancel := context.WithTimeout(context.TODO(), downstreamRegisterTimeout) + defer cancel() + + // Close the connection with an error if the deadline is exceeded + go func() { + <-ctx.Done() + if err := ctx.Err(); err == context.DeadlineExceeded { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "ERROR", + Params: []string{"Connection registration timed out"}, + }) + dc.Close() + } + }() + + for !dc.registered { + msg, err := dc.ReadMessage() + if err != nil { + return fmt.Errorf("failed to read IRC command: %w", err) + } + + err = dc.handleMessage(ctx, msg) + if ircErr, ok := err.(ircError); ok { + ircErr.Message.Prefix = dc.srv.prefix() + dc.SendMessage(ircErr.Message) + } else if err != nil { + return fmt.Errorf("failed to handle IRC command %q: %v", msg, err) + } + } + + return nil +} + +func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.Message) error { + switch msg.Command { + case "CAP": + var subCmd string + if err := parseMessageParams(msg, &subCmd); err != nil { + return err + } + if err := dc.handleCapCommand(subCmd, msg.Params[1:]); err != nil { + return err + } + case "PING": + var source, destination string + if err := parseMessageParams(msg, &source); err != nil { + return err + } + if len(msg.Params) > 1 { + destination = msg.Params[1] + } + hostname := dc.srv.Config().Hostname + if destination != "" && destination != hostname { + return ircError{&irc.Message{ + Command: irc.ERR_NOSUCHSERVER, + Params: []string{dc.nick, destination, "No such server"}, + }} + } + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "PONG", + Params: []string{hostname, source}, + }) + return nil + case "PONG": + if len(msg.Params) == 0 { + return newNeedMoreParamsError(msg.Command) + } + token := msg.Params[len(msg.Params)-1] + dc.handlePong(token) + case "USER": + return ircError{&irc.Message{ + Command: irc.ERR_ALREADYREGISTERED, + Params: []string{dc.nick, "You may not reregister"}, + }} + case "NICK": + var rawNick string + if err := parseMessageParams(msg, &rawNick); err != nil { + return err + } + + nick := rawNick + var upstream *upstreamConn + if dc.upstream() == nil { + uc, unmarshaledNick, err := dc.unmarshalEntity(nick) + if err == nil { // NICK nick/network: NICK only on a specific upstream + upstream = uc + nick = unmarshaledNick + } + } + + if nick == "" || strings.ContainsAny(nick, illegalNickChars) { + return ircError{&irc.Message{ + Command: irc.ERR_ERRONEUSNICKNAME, + Params: []string{dc.nick, rawNick, "contains illegal characters"}, + }} + } + if casemapASCII(nick) == serviceNickCM { + return ircError{&irc.Message{ + Command: irc.ERR_NICKNAMEINUSE, + Params: []string{dc.nick, rawNick, "Nickname reserved for bouncer service"}, + }} + } + + var err error + dc.forEachNetwork(func(n *network) { + if err != nil || (upstream != nil && upstream.network != n) { + return + } + n.Nick = nick + err = dc.srv.db.StoreNetwork(ctx, dc.user.ID, &n.Network) + }) + if err != nil { + return err + } + + dc.forEachUpstream(func(uc *upstreamConn) { + if upstream != nil && upstream != uc { + return + } + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + Command: "NICK", + Params: []string{nick}, + }) + }) + + if dc.upstream() == nil && upstream == nil && dc.nick != nick { + dc.SendMessage(&irc.Message{ + Prefix: dc.prefix(), + Command: "NICK", + Params: []string{nick}, + }) + dc.nick = nick + dc.nickCM = casemapASCII(dc.nick) + } + case "SETNAME": + var realname string + if err := parseMessageParams(msg, &realname); err != nil { + return err + } + + // If the client just resets to the default, just wipe the per-network + // preference + storeRealname := realname + if realname == dc.user.Realname { + storeRealname = "" + } + + var storeErr error + var needUpdate []Network + dc.forEachNetwork(func(n *network) { + // We only need to call updateNetwork for upstreams that don't + // support setname + if uc := n.conn; uc != nil && uc.caps["setname"] { + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + Command: "SETNAME", + Params: []string{realname}, + }) + + n.Realname = storeRealname + if err := dc.srv.db.StoreNetwork(ctx, dc.user.ID, &n.Network); err != nil { + dc.logger.Printf("failed to store network realname: %v", err) + storeErr = err + } + return + } + + record := n.Network // copy network record because we'll mutate it + record.Realname = storeRealname + needUpdate = append(needUpdate, record) + }) + + // Walk the network list as a second step, because updateNetwork + // mutates the original list + for _, record := range needUpdate { + if _, err := dc.user.updateNetwork(ctx, &record); err != nil { + dc.logger.Printf("failed to update network realname: %v", err) + storeErr = err + } + } + if storeErr != nil { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"SETNAME", "CANNOT_CHANGE_REALNAME", "Failed to update realname"}, + }} + } + + if dc.upstream() == nil { + dc.SendMessage(&irc.Message{ + Prefix: dc.prefix(), + Command: "SETNAME", + Params: []string{realname}, + }) + } + case "JOIN": + var namesStr string + if err := parseMessageParams(msg, &namesStr); err != nil { + return err + } + + var keys []string + if len(msg.Params) > 1 { + keys = strings.Split(msg.Params[1], ",") + } + + for i, name := range strings.Split(namesStr, ",") { + uc, upstreamName, err := dc.unmarshalEntity(name) + if err != nil { + return err + } + + var key string + if len(keys) > i { + key = keys[i] + } + + if !uc.isChannel(upstreamName) { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_NOSUCHCHANNEL, + Params: []string{name, "Not a channel name"}, + }) + continue + } + + // Most servers ignore duplicate JOIN messages. We ignore them here + // because some clients automatically send JOIN messages in bulk + // when reconnecting to the bouncer. We don't want to flood the + // upstream connection with these. + if !uc.channels.Has(upstreamName) { + params := []string{upstreamName} + if key != "" { + params = append(params, key) + } + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + Command: "JOIN", + Params: params, + }) + } + + ch := uc.network.channels.Value(upstreamName) + if ch != nil { + // Don't clear the channel key if there's one set + // TODO: add a way to unset the channel key + if key != "" { + ch.Key = key + } + uc.network.attach(ctx, ch) + } else { + ch = &Channel{ + Name: upstreamName, + Key: key, + } + uc.network.channels.SetValue(upstreamName, ch) + } + if err := dc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil { + dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err) + } + } + case "PART": + var namesStr string + if err := parseMessageParams(msg, &namesStr); err != nil { + return err + } + + var reason string + if len(msg.Params) > 1 { + reason = msg.Params[1] + } + + for _, name := range strings.Split(namesStr, ",") { + uc, upstreamName, err := dc.unmarshalEntity(name) + if err != nil { + return err + } + + if strings.EqualFold(reason, "detach") { + ch := uc.network.channels.Value(upstreamName) + if ch != nil { + uc.network.detach(ch) + } else { + ch = &Channel{ + Name: name, + Detached: true, + } + uc.network.channels.SetValue(upstreamName, ch) + } + if err := dc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil { + dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err) + } + } else { + params := []string{upstreamName} + if reason != "" { + params = append(params, reason) + } + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + Command: "PART", + Params: params, + }) + + if err := uc.network.deleteChannel(ctx, upstreamName); err != nil { + dc.logger.Printf("failed to delete channel %q: %v", upstreamName, err) + } + } + } + case "KICK": + var channelStr, userStr string + if err := parseMessageParams(msg, &channelStr, &userStr); err != nil { + return err + } + + channels := strings.Split(channelStr, ",") + users := strings.Split(userStr, ",") + + var reason string + if len(msg.Params) > 2 { + reason = msg.Params[2] + } + + if len(channels) != 1 && len(channels) != len(users) { + return ircError{&irc.Message{ + Command: irc.ERR_BADCHANMASK, + Params: []string{dc.nick, channelStr, "Bad channel mask"}, + }} + } + + for i, user := range users { + var channel string + if len(channels) == 1 { + channel = channels[0] + } else { + channel = channels[i] + } + + ucChannel, upstreamChannel, err := dc.unmarshalEntity(channel) + if err != nil { + return err + } + + ucUser, upstreamUser, err := dc.unmarshalEntity(user) + if err != nil { + return err + } + + if ucChannel != ucUser { + return ircError{&irc.Message{ + Command: irc.ERR_USERNOTINCHANNEL, + Params: []string{dc.nick, user, channel, "They are on another network"}, + }} + } + uc := ucChannel + + params := []string{upstreamChannel, upstreamUser} + if reason != "" { + params = append(params, reason) + } + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + Command: "KICK", + Params: params, + }) + } + case "MODE": + var name string + if err := parseMessageParams(msg, &name); err != nil { + return err + } + + var modeStr string + if len(msg.Params) > 1 { + modeStr = msg.Params[1] + } + + if casemapASCII(name) == dc.nickCM { + if modeStr != "" { + if uc := dc.upstream(); uc != nil { + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + Command: "MODE", + Params: []string{uc.nick, modeStr}, + }) + } else { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_UMODEUNKNOWNFLAG, + Params: []string{dc.nick, "Cannot change user mode in multi-upstream mode"}, + }) + } + } else { + var userMode string + if uc := dc.upstream(); uc != nil { + userMode = string(uc.modes) + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_UMODEIS, + Params: []string{dc.nick, "+" + userMode}, + }) + } + return nil + } + + uc, upstreamName, err := dc.unmarshalEntity(name) + if err != nil { + return err + } + + if !uc.isChannel(upstreamName) { + return ircError{&irc.Message{ + Command: irc.ERR_USERSDONTMATCH, + Params: []string{dc.nick, "Cannot change mode for other users"}, + }} + } + + if modeStr != "" { + params := []string{upstreamName, modeStr} + params = append(params, msg.Params[2:]...) + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + Command: "MODE", + Params: params, + }) + } else { + ch := uc.channels.Value(upstreamName) + if ch == nil { + return ircError{&irc.Message{ + Command: irc.ERR_NOSUCHCHANNEL, + Params: []string{dc.nick, name, "No such channel"}, + }} + } + + if ch.modes == nil { + // we haven't received the initial RPL_CHANNELMODEIS yet + // ignore the request, we will broadcast the modes later when we receive RPL_CHANNELMODEIS + return nil + } + + modeStr, modeParams := ch.modes.Format() + params := []string{dc.nick, name, modeStr} + params = append(params, modeParams...) + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_CHANNELMODEIS, + Params: params, + }) + if ch.creationTime != "" { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: rpl_creationtime, + Params: []string{dc.nick, name, ch.creationTime}, + }) + } + } + case "TOPIC": + var channel string + if err := parseMessageParams(msg, &channel); err != nil { + return err + } + + uc, upstreamName, err := dc.unmarshalEntity(channel) + if err != nil { + return err + } + + if len(msg.Params) > 1 { // setting topic + topic := msg.Params[1] + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + Command: "TOPIC", + Params: []string{upstreamName, topic}, + }) + } else { // getting topic + ch := uc.channels.Value(upstreamName) + if ch == nil { + return ircError{&irc.Message{ + Command: irc.ERR_NOSUCHCHANNEL, + Params: []string{dc.nick, upstreamName, "No such channel"}, + }} + } + sendTopic(dc, ch) + } + case "LIST": + network := dc.network + if network == nil && len(msg.Params) > 0 { + var err error + network, msg.Params[0], err = dc.unmarshalEntityNetwork(msg.Params[0]) + if err != nil { + return err + } + } + if network == nil { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_LISTEND, + Params: []string{dc.nick, "LIST without a network suffix is not supported in multi-upstream mode"}, + }) + return nil + } + + uc := network.conn + if uc == nil { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_LISTEND, + Params: []string{dc.nick, "Disconnected from upstream server"}, + }) + return nil + } + + uc.enqueueCommand(dc, msg) + case "NAMES": + if len(msg.Params) == 0 { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_ENDOFNAMES, + Params: []string{dc.nick, "*", "End of /NAMES list"}, + }) + return nil + } + + channels := strings.Split(msg.Params[0], ",") + for _, channel := range channels { + uc, upstreamName, err := dc.unmarshalEntity(channel) + if err != nil { + return err + } + + ch := uc.channels.Value(upstreamName) + if ch != nil { + sendNames(dc, ch) + } else { + // NAMES on a channel we have not joined, ask upstream + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + Command: "NAMES", + Params: []string{upstreamName}, + }) + } + } + // For WHOX docs, see: + // - http://faerion.sourceforge.net/doc/irc/whox.var + // - https://github.com/quakenet/snircd/blob/master/doc/readme.who + // Note, many features aren't widely implemented, such as flags and mask2 + case "WHO": + if len(msg.Params) == 0 { + // TODO: support WHO without parameters + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_ENDOFWHO, + Params: []string{dc.nick, "*", "End of /WHO list"}, + }) + return nil + } + + // Clients will use the first mask to match RPL_ENDOFWHO + endOfWhoToken := msg.Params[0] + + // TODO: add support for WHOX mask2 + mask := msg.Params[0] + var options string + if len(msg.Params) > 1 { + options = msg.Params[1] + } + + optionsParts := strings.SplitN(options, "%", 2) + // TODO: add support for WHOX flags in optionsParts[0] + var fields, whoxToken string + if len(optionsParts) == 2 { + optionsParts := strings.SplitN(optionsParts[1], ",", 2) + fields = strings.ToLower(optionsParts[0]) + if len(optionsParts) == 2 && strings.Contains(fields, "t") { + whoxToken = optionsParts[1] + } + } + + // TODO: support mixed bouncer/upstream WHO queries + maskCM := casemapASCII(mask) + if dc.network == nil && maskCM == dc.nickCM { + // TODO: support AWAY (H/G) in self WHO reply + flags := "H" + if dc.user.Admin { + flags += "*" + } + info := whoxInfo{ + Token: whoxToken, + Username: dc.user.Username, + Hostname: dc.hostname, + Server: dc.srv.Config().Hostname, + Nickname: dc.nick, + Flags: flags, + Account: dc.user.Username, + Realname: dc.realname, + } + dc.SendMessage(generateWHOXReply(dc.srv.prefix(), dc.nick, fields, &info)) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_ENDOFWHO, + Params: []string{dc.nick, endOfWhoToken, "End of /WHO list"}, + }) + return nil + } + if maskCM == serviceNickCM { + info := whoxInfo{ + Token: whoxToken, + Username: servicePrefix.User, + Hostname: servicePrefix.Host, + Server: dc.srv.Config().Hostname, + Nickname: serviceNick, + Flags: "H*", + Account: serviceNick, + Realname: serviceRealname, + } + dc.SendMessage(generateWHOXReply(dc.srv.prefix(), dc.nick, fields, &info)) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_ENDOFWHO, + Params: []string{dc.nick, endOfWhoToken, "End of /WHO list"}, + }) + return nil + } + + // TODO: properly support WHO masks + uc, upstreamMask, err := dc.unmarshalEntity(mask) + if err != nil { + return err + } + + params := []string{upstreamMask} + if options != "" { + params = append(params, options) + } + + uc.enqueueCommand(dc, &irc.Message{ + Command: "WHO", + Params: params, + }) + case "WHOIS": + if len(msg.Params) == 0 { + return ircError{&irc.Message{ + Command: irc.ERR_NONICKNAMEGIVEN, + Params: []string{dc.nick, "No nickname given"}, + }} + } + + var target, mask string + if len(msg.Params) == 1 { + target = "" + mask = msg.Params[0] + } else { + target = msg.Params[0] + mask = msg.Params[1] + } + // TODO: support multiple WHOIS users + if i := strings.IndexByte(mask, ','); i >= 0 { + mask = mask[:i] + } + + if dc.network == nil && casemapASCII(mask) == dc.nickCM { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_WHOISUSER, + Params: []string{dc.nick, dc.nick, dc.user.Username, dc.hostname, "*", dc.realname}, + }) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_WHOISSERVER, + Params: []string{dc.nick, dc.nick, dc.srv.Config().Hostname, "suika"}, + }) + if dc.user.Admin { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_WHOISOPERATOR, + Params: []string{dc.nick, dc.nick, "is a bouncer administrator"}, + }) + } + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: rpl_whoisaccount, + Params: []string{dc.nick, dc.nick, dc.user.Username, "is logged in as"}, + }) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_ENDOFWHOIS, + Params: []string{dc.nick, dc.nick, "End of /WHOIS list"}, + }) + return nil + } + if casemapASCII(mask) == serviceNickCM { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_WHOISUSER, + Params: []string{dc.nick, serviceNick, servicePrefix.User, servicePrefix.Host, "*", serviceRealname}, + }) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_WHOISSERVER, + Params: []string{dc.nick, serviceNick, dc.srv.Config().Hostname, "suika"}, + }) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_WHOISOPERATOR, + Params: []string{dc.nick, serviceNick, "is the bouncer service"}, + }) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: rpl_whoisaccount, + Params: []string{dc.nick, serviceNick, serviceNick, "is logged in as"}, + }) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_ENDOFWHOIS, + Params: []string{dc.nick, serviceNick, "End of /WHOIS list"}, + }) + return nil + } + + // TODO: support WHOIS masks + uc, upstreamNick, err := dc.unmarshalEntity(mask) + if err != nil { + return err + } + + var params []string + if target != "" { + if target == mask { // WHOIS nick nick + params = []string{upstreamNick, upstreamNick} + } else { + params = []string{target, upstreamNick} + } + } else { + params = []string{upstreamNick} + } + + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + Command: "WHOIS", + Params: params, + }) + case "PRIVMSG", "NOTICE": + var targetsStr, text string + if err := parseMessageParams(msg, &targetsStr, &text); err != nil { + return err + } + tags := copyClientTags(msg.Tags) + + for _, name := range strings.Split(targetsStr, ",") { + if name == "$"+dc.srv.Config().Hostname || (name == "$*" && dc.network == nil) { + // "$" means a server mask follows. If it's the bouncer's + // hostname, broadcast the message to all bouncer users. + if !dc.user.Admin { + return ircError{&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_BADMASK, + Params: []string{dc.nick, name, "Permission denied to broadcast message to all bouncer users"}, + }} + } + + dc.logger.Printf("broadcasting bouncer-wide %v: %v", msg.Command, text) + + broadcastTags := tags.Copy() + broadcastTags["time"] = irc.TagValue(formatServerTime(time.Now())) + broadcastMsg := &irc.Message{ + Tags: broadcastTags, + Prefix: servicePrefix, + Command: msg.Command, + Params: []string{name, text}, + } + dc.srv.forEachUser(func(u *user) { + u.events <- eventBroadcast{broadcastMsg} + }) + continue + } + + if dc.network == nil && casemapASCII(name) == dc.nickCM { + dc.SendMessage(&irc.Message{ + Tags: msg.Tags.Copy(), + Prefix: dc.prefix(), + Command: msg.Command, + Params: []string{name, text}, + }) + continue + } + + if msg.Command == "PRIVMSG" && casemapASCII(name) == serviceNickCM { + if dc.caps["echo-message"] { + echoTags := tags.Copy() + echoTags["time"] = irc.TagValue(formatServerTime(time.Now())) + dc.SendMessage(&irc.Message{ + Tags: echoTags, + Prefix: dc.prefix(), + Command: msg.Command, + Params: []string{name, text}, + }) + } + handleServicePRIVMSG(ctx, dc, text) + continue + } + + uc, upstreamName, err := dc.unmarshalEntity(name) + if err != nil { + return err + } + + if msg.Command == "PRIVMSG" && uc.network.casemap(upstreamName) == "nickserv" { + dc.handleNickServPRIVMSG(ctx, uc, text) + } + + unmarshaledText := text + if uc.isChannel(upstreamName) { + unmarshaledText = dc.unmarshalText(uc, text) + } + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + Tags: tags, + Command: msg.Command, + Params: []string{upstreamName, unmarshaledText}, + }) + + echoTags := tags.Copy() + echoTags["time"] = irc.TagValue(formatServerTime(time.Now())) + if uc.account != "" { + echoTags["account"] = irc.TagValue(uc.account) + } + echoMsg := &irc.Message{ + Tags: echoTags, + Prefix: &irc.Prefix{Name: uc.nick}, + Command: msg.Command, + Params: []string{upstreamName, text}, + } + uc.produce(upstreamName, echoMsg, dc) + + uc.updateChannelAutoDetach(upstreamName) + } + case "TAGMSG": + var targetsStr string + if err := parseMessageParams(msg, &targetsStr); err != nil { + return err + } + tags := copyClientTags(msg.Tags) + + for _, name := range strings.Split(targetsStr, ",") { + if dc.network == nil && casemapASCII(name) == dc.nickCM { + dc.SendMessage(&irc.Message{ + Tags: msg.Tags.Copy(), + Prefix: dc.prefix(), + Command: "TAGMSG", + Params: []string{name}, + }) + continue + } + + if casemapASCII(name) == serviceNickCM { + continue + } + + uc, upstreamName, err := dc.unmarshalEntity(name) + if err != nil { + return err + } + if _, ok := uc.caps["message-tags"]; !ok { + continue + } + + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + Tags: tags, + Command: "TAGMSG", + Params: []string{upstreamName}, + }) + + echoTags := tags.Copy() + echoTags["time"] = irc.TagValue(formatServerTime(time.Now())) + if uc.account != "" { + echoTags["account"] = irc.TagValue(uc.account) + } + echoMsg := &irc.Message{ + Tags: echoTags, + Prefix: &irc.Prefix{Name: uc.nick}, + Command: "TAGMSG", + Params: []string{upstreamName}, + } + uc.produce(upstreamName, echoMsg, dc) + + uc.updateChannelAutoDetach(upstreamName) + } + case "INVITE": + var user, channel string + if err := parseMessageParams(msg, &user, &channel); err != nil { + return err + } + + ucChannel, upstreamChannel, err := dc.unmarshalEntity(channel) + if err != nil { + return err + } + + ucUser, upstreamUser, err := dc.unmarshalEntity(user) + if err != nil { + return err + } + + if ucChannel != ucUser { + return ircError{&irc.Message{ + Command: irc.ERR_USERNOTINCHANNEL, + Params: []string{dc.nick, user, channel, "They are on another network"}, + }} + } + uc := ucChannel + + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ + Command: "INVITE", + Params: []string{upstreamUser, upstreamChannel}, + }) + case "AUTHENTICATE": + // Post-connection-registration AUTHENTICATE is unsupported in + // multi-upstream mode, or if the upstream doesn't support SASL + uc := dc.upstream() + if uc == nil || !uc.caps["sasl"] { + return ircError{&irc.Message{ + Command: irc.ERR_SASLFAIL, + Params: []string{dc.nick, "Upstream network authentication not supported"}, + }} + } + + credentials, err := dc.handleAuthenticateCommand(msg) + if err != nil { + return err + } + + if credentials != nil { + if uc.saslClient != nil { + dc.endSASL(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_SASLFAIL, + Params: []string{dc.nick, "Another authentication attempt is already in progress"}, + }) + return nil + } + + uc.logger.Printf("starting post-registration SASL PLAIN authentication with username %q", credentials.plainUsername) + uc.saslClient = sasl.NewPlainClient("", credentials.plainUsername, credentials.plainPassword) + uc.enqueueCommand(dc, &irc.Message{ + Command: "AUTHENTICATE", + Params: []string{"PLAIN"}, + }) + } + case "REGISTER", "VERIFY": + // Check number of params here, since we'll use that to save the + // credentials on command success + if (msg.Command == "REGISTER" && len(msg.Params) < 3) || (msg.Command == "VERIFY" && len(msg.Params) < 2) { + return newNeedMoreParamsError(msg.Command) + } + + uc := dc.upstream() + if uc == nil || !uc.caps["draft/account-registration"] { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{msg.Command, "TEMPORARILY_UNAVAILABLE", "*", "Upstream network account registration not supported"}, + }} + } + + uc.logger.Printf("starting %v with account name %v", msg.Command, msg.Params[0]) + uc.enqueueCommand(dc, msg) + case "MONITOR": + // MONITOR is unsupported in multi-upstream mode + uc := dc.upstream() + if uc == nil { + return newUnknownCommandError(msg.Command) + } + if _, ok := uc.isupport["MONITOR"]; !ok { + return newUnknownCommandError(msg.Command) + } + + var subcommand string + if err := parseMessageParams(msg, &subcommand); err != nil { + return err + } + + switch strings.ToUpper(subcommand) { + case "+", "-": + var targets string + if err := parseMessageParams(msg, nil, &targets); err != nil { + return err + } + for _, target := range strings.Split(targets, ",") { + if subcommand == "+" { + // Hard limit, just to avoid having downstreams fill our map + if len(dc.monitored.innerMap) >= 1000 { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_MONLISTFULL, + Params: []string{dc.nick, "1000", target, "Bouncer monitor list is full"}, + }) + continue + } + + dc.monitored.SetValue(target, nil) + + if uc.monitored.Has(target) { + cmd := irc.RPL_MONOFFLINE + if online := uc.monitored.Value(target); online { + cmd = irc.RPL_MONONLINE + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: cmd, + Params: []string{dc.nick, target}, + }) + } + } else { + dc.monitored.Delete(target) + } + } + uc.updateMonitor() + case "C": // clear + dc.monitored = newCasemapMap(0) + uc.updateMonitor() + case "L": // list + // TODO: be less lazy and pack the list + for _, entry := range dc.monitored.innerMap { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_MONLIST, + Params: []string{dc.nick, entry.originalKey}, + }) + } + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_ENDOFMONLIST, + Params: []string{dc.nick, "End of MONITOR list"}, + }) + case "S": // status + // TODO: be less lazy and pack the lists + for _, entry := range dc.monitored.innerMap { + target := entry.originalKey + + cmd := irc.RPL_MONOFFLINE + if online := uc.monitored.Value(target); online { + cmd = irc.RPL_MONONLINE + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: cmd, + Params: []string{dc.nick, target}, + }) + } + } + case "CHATHISTORY": + var subcommand string + if err := parseMessageParams(msg, &subcommand); err != nil { + return err + } + var target, limitStr string + var boundsStr [2]string + switch subcommand { + case "AFTER", "BEFORE", "LATEST": + if err := parseMessageParams(msg, nil, &target, &boundsStr[0], &limitStr); err != nil { + return err + } + case "BETWEEN": + if err := parseMessageParams(msg, nil, &target, &boundsStr[0], &boundsStr[1], &limitStr); err != nil { + return err + } + case "TARGETS": + if dc.network == nil { + // Either an unbound bouncer network, in which case we should return no targets, + // or a multi-upstream downstream, but we don't support CHATHISTORY TARGETS for those yet. + dc.SendBatch("draft/chathistory-targets", nil, nil, func(batchRef irc.TagValue) {}) + return nil + } + if err := parseMessageParams(msg, nil, &boundsStr[0], &boundsStr[1], &limitStr); err != nil { + return err + } + default: + // TODO: support AROUND + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"CHATHISTORY", "INVALID_PARAMS", subcommand, "Unknown command"}, + }} + } + + // We don't save history for our service + if casemapASCII(target) == serviceNickCM { + dc.SendBatch("chathistory", []string{target}, nil, func(batchRef irc.TagValue) {}) + return nil + } + + store, ok := dc.user.msgStore.(chatHistoryMessageStore) + if !ok { + return ircError{&irc.Message{ + Command: irc.ERR_UNKNOWNCOMMAND, + Params: []string{dc.nick, "CHATHISTORY", "Unknown command"}, + }} + } + + network, entity, err := dc.unmarshalEntityNetwork(target) + if err != nil { + return err + } + entity = network.casemap(entity) + + // TODO: support msgid criteria + var bounds [2]time.Time + bounds[0] = parseChatHistoryBound(boundsStr[0]) + if subcommand == "LATEST" && boundsStr[0] == "*" { + bounds[0] = time.Now() + } else if bounds[0].IsZero() { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"CHATHISTORY", "INVALID_PARAMS", subcommand, boundsStr[0], "Invalid first bound"}, + }} + } + + if boundsStr[1] != "" { + bounds[1] = parseChatHistoryBound(boundsStr[1]) + if bounds[1].IsZero() { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"CHATHISTORY", "INVALID_PARAMS", subcommand, boundsStr[1], "Invalid second bound"}, + }} + } + } + + limit, err := strconv.Atoi(limitStr) + if err != nil || limit < 0 || limit > chatHistoryLimit { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"CHATHISTORY", "INVALID_PARAMS", subcommand, limitStr, "Invalid limit"}, + }} + } + + eventPlayback := dc.caps["draft/event-playback"] + + var history []*irc.Message + switch subcommand { + case "BEFORE", "LATEST": + history, err = store.LoadBeforeTime(ctx, &network.Network, entity, bounds[0], time.Time{}, limit, eventPlayback) + case "AFTER": + history, err = store.LoadAfterTime(ctx, &network.Network, entity, bounds[0], time.Now(), limit, eventPlayback) + case "BETWEEN": + if bounds[0].Before(bounds[1]) { + history, err = store.LoadAfterTime(ctx, &network.Network, entity, bounds[0], bounds[1], limit, eventPlayback) + } else { + history, err = store.LoadBeforeTime(ctx, &network.Network, entity, bounds[0], bounds[1], limit, eventPlayback) + } + case "TARGETS": + // TODO: support TARGETS in multi-upstream mode + targets, err := store.ListTargets(ctx, &network.Network, bounds[0], bounds[1], limit, eventPlayback) + if err != nil { + dc.logger.Printf("failed fetching targets for chathistory: %v", err) + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"CHATHISTORY", "MESSAGE_ERROR", subcommand, "Failed to retrieve targets"}, + }} + } + + dc.SendBatch("draft/chathistory-targets", nil, nil, func(batchRef irc.TagValue) { + for _, target := range targets { + if ch := network.channels.Value(target.Name); ch != nil && ch.Detached { + continue + } + + dc.SendMessage(&irc.Message{ + Tags: irc.Tags{"batch": batchRef}, + Prefix: dc.srv.prefix(), + Command: "CHATHISTORY", + Params: []string{"TARGETS", target.Name, formatServerTime(target.LatestMessage)}, + }) + } + }) + + return nil + } + if err != nil { + dc.logger.Printf("failed fetching %q messages for chathistory: %v", target, err) + return newChatHistoryError(subcommand, target) + } + + dc.SendBatch("chathistory", []string{target}, nil, func(batchRef irc.TagValue) { + for _, msg := range history { + msg.Tags["batch"] = batchRef + dc.SendMessage(dc.marshalMessage(msg, network)) + } + }) + case "READ": + var target, criteria string + if err := parseMessageParams(msg, &target); err != nil { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"READ", "NEED_MORE_PARAMS", "Missing parameters"}, + }} + } + if len(msg.Params) > 1 { + criteria = msg.Params[1] + } + + // We don't save read receipts for our service + if casemapASCII(target) == serviceNickCM { + dc.SendMessage(&irc.Message{ + Prefix: dc.prefix(), + Command: "READ", + Params: []string{target, "*"}, + }) + return nil + } + + uc, entity, err := dc.unmarshalEntity(target) + if err != nil { + return err + } + entityCM := uc.network.casemap(entity) + + r, err := dc.srv.db.GetReadReceipt(ctx, uc.network.ID, entityCM) + if err != nil { + dc.logger.Printf("failed to get the read receipt for %q: %v", entity, err) + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"READ", "INTERNAL_ERROR", target, "Internal error"}, + }} + } else if r == nil { + r = &ReadReceipt{ + Target: entityCM, + } + } + + broadcast := false + if len(criteria) > 0 { + // TODO: support msgid criteria + criteriaParts := strings.SplitN(criteria, "=", 2) + if len(criteriaParts) != 2 || criteriaParts[0] != "timestamp" { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"READ", "INVALID_PARAMS", criteria, "Unknown criteria"}, + }} + } + + timestamp, err := time.Parse(serverTimeLayout, criteriaParts[1]) + if err != nil { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"READ", "INVALID_PARAMS", criteria, "Invalid criteria"}, + }} + } + now := time.Now() + if timestamp.After(now) { + timestamp = now + } + if r.Timestamp.Before(timestamp) { + r.Timestamp = timestamp + if err := dc.srv.db.StoreReadReceipt(ctx, uc.network.ID, r); err != nil { + dc.logger.Printf("failed to store receipt for %q: %v", entity, err) + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"READ", "INTERNAL_ERROR", target, "Internal error"}, + }} + } + broadcast = true + } + } + + timestampStr := "*" + if !r.Timestamp.IsZero() { + timestampStr = fmt.Sprintf("timestamp=%s", formatServerTime(r.Timestamp)) + } + uc.forEachDownstream(func(d *downstreamConn) { + if broadcast || dc.id == d.id { + d.SendMessage(&irc.Message{ + Prefix: d.prefix(), + Command: "READ", + Params: []string{d.marshalEntity(uc.network, entity), timestampStr}, + }) + } + }) + case "BOUNCER": + var subcommand string + if err := parseMessageParams(msg, &subcommand); err != nil { + return err + } + + switch strings.ToUpper(subcommand) { + case "BIND": + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"BOUNCER", "REGISTRATION_IS_COMPLETED", "BIND", "Cannot bind to a network after registration"}, + }} + case "LISTNETWORKS": + dc.SendBatch("soju.im/bouncer-networks", nil, nil, func(batchRef irc.TagValue) { + for _, network := range dc.user.networks { + idStr := fmt.Sprintf("%v", network.ID) + attrs := getNetworkAttrs(network) + dc.SendMessage(&irc.Message{ + Tags: irc.Tags{"batch": batchRef}, + Prefix: dc.srv.prefix(), + Command: "BOUNCER", + Params: []string{"NETWORK", idStr, attrs.String()}, + }) + } + }) + case "ADDNETWORK": + var attrsStr string + if err := parseMessageParams(msg, nil, &attrsStr); err != nil { + return err + } + attrs := irc.ParseTags(attrsStr) + + record := &Network{Nick: dc.nick, Enabled: true} + if err := updateNetworkAttrs(record, attrs, subcommand); err != nil { + return err + } + + if record.Nick == dc.user.Username { + record.Nick = "" + } + if record.Realname == dc.user.Realname { + record.Realname = "" + } + + network, err := dc.user.createNetwork(ctx, record) + if err != nil { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"BOUNCER", "UNKNOWN_ERROR", subcommand, fmt.Sprintf("Failed to create network: %v", err)}, + }} + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "BOUNCER", + Params: []string{"ADDNETWORK", fmt.Sprintf("%v", network.ID)}, + }) + case "CHANGENETWORK": + var idStr, attrsStr string + if err := parseMessageParams(msg, nil, &idStr, &attrsStr); err != nil { + return err + } + id, err := parseBouncerNetID(subcommand, idStr) + if err != nil { + return err + } + attrs := irc.ParseTags(attrsStr) + + net := dc.user.getNetworkByID(id) + if net == nil { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"BOUNCER", "INVALID_NETID", subcommand, idStr, "Invalid network ID"}, + }} + } + + record := net.Network // copy network record because we'll mutate it + if err := updateNetworkAttrs(&record, attrs, subcommand); err != nil { + return err + } + + if record.Nick == dc.user.Username { + record.Nick = "" + } + if record.Realname == dc.user.Realname { + record.Realname = "" + } + + _, err = dc.user.updateNetwork(ctx, &record) + if err != nil { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"BOUNCER", "UNKNOWN_ERROR", subcommand, fmt.Sprintf("Failed to update network: %v", err)}, + }} + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "BOUNCER", + Params: []string{"CHANGENETWORK", idStr}, + }) + case "DELNETWORK": + var idStr string + if err := parseMessageParams(msg, nil, &idStr); err != nil { + return err + } + id, err := parseBouncerNetID(subcommand, idStr) + if err != nil { + return err + } + + net := dc.user.getNetworkByID(id) + if net == nil { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"BOUNCER", "INVALID_NETID", subcommand, idStr, "Invalid network ID"}, + }} + } + + if err := dc.user.deleteNetwork(ctx, net.ID); err != nil { + return err + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "BOUNCER", + Params: []string{"DELNETWORK", idStr}, + }) + default: + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{"BOUNCER", "UNKNOWN_COMMAND", subcommand, "Unknown subcommand"}, + }} + } + default: + dc.logger.Printf("unhandled message: %v", msg) + + // Only forward unknown commands in single-upstream mode + uc := dc.upstream() + if uc == nil { + return newUnknownCommandError(msg.Command) + } + + uc.SendMessageLabeled(ctx, dc.id, msg) + } + return nil +} + +func (dc *downstreamConn) handleNickServPRIVMSG(ctx context.Context, uc *upstreamConn, text string) { + username, password, ok := parseNickServCredentials(text, uc.nick) + if ok { + uc.network.autoSaveSASLPlain(ctx, username, password) + } +} + +func parseNickServCredentials(text, nick string) (username, password string, ok bool) { + fields := strings.Fields(text) + if len(fields) < 2 { + return "", "", false + } + cmd := strings.ToUpper(fields[0]) + params := fields[1:] + switch cmd { + case "REGISTER": + username = nick + password = params[0] + case "IDENTIFY": + if len(params) == 1 { + username = nick + password = params[0] + } else { + username = params[0] + password = params[1] + } + case "SET": + if len(params) == 2 && strings.EqualFold(params[0], "PASSWORD") { + username = nick + password = params[1] + } + default: + return "", "", false + } + return username, password, true +} diff --git a/trunk/go.mod b/trunk/go.mod new file mode 100644 index 0000000..cc08167 --- /dev/null +++ b/trunk/go.mod @@ -0,0 +1,39 @@ +module marisa.chaotic.ninja/suika + +go 1.20 + +require ( + git.sr.ht/~emersion/go-scfg v0.0.0-20211215104734-c2c7a15d6c99 + git.sr.ht/~sircmpwn/go-bare v0.0.0-20210406120253-ab86bc2846d9 + github.com/emersion/go-sasl v0.0.0-20220912192320-0145f2c60ead + github.com/lib/pq v1.10.7 + golang.org/x/crypto v0.7.0 + golang.org/x/term v0.6.0 + golang.org/x/time v0.3.0 + gopkg.in/irc.v3 v3.1.4 + modernc.org/sqlite v1.21.0 +) + +require ( + github.com/dustin/go-humanize v1.0.0 // indirect + github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect + github.com/google/uuid v1.3.0 // indirect + github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect + github.com/mattn/go-isatty v0.0.16 // indirect + github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + github.com/stretchr/testify v1.8.0 // indirect + golang.org/x/mod v0.3.0 // indirect + golang.org/x/sys v0.6.0 // indirect + golang.org/x/tools v0.0.0-20201124115921-2c860bdd6e78 // indirect + golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect + lukechampine.com/uint128 v1.2.0 // indirect + modernc.org/cc/v3 v3.40.0 // indirect + modernc.org/ccgo/v3 v3.16.13 // indirect + modernc.org/libc v1.22.3 // indirect + modernc.org/mathutil v1.5.0 // indirect + modernc.org/memory v1.5.0 // indirect + modernc.org/opt v0.1.3 // indirect + modernc.org/strutil v1.1.3 // indirect + modernc.org/token v1.0.1 // indirect +) diff --git a/trunk/go.sum b/trunk/go.sum new file mode 100644 index 0000000..090246d --- /dev/null +++ b/trunk/go.sum @@ -0,0 +1,105 @@ +git.sr.ht/~emersion/go-scfg v0.0.0-20211215104734-c2c7a15d6c99 h1:1s8n5uisqkR+BzPgaum6xxIjKmzGrTykJdh+Y3f5Xao= +git.sr.ht/~emersion/go-scfg v0.0.0-20211215104734-c2c7a15d6c99/go.mod h1:t+Ww6SR24yYnXzEWiNlOY0AFo5E9B73X++10lrSpp4U= +git.sr.ht/~sircmpwn/getopt v0.0.0-20191230200459-23622cc906b3/go.mod h1:wMEGFFFNuPos7vHmWXfszqImLppbc0wEhh6JBfJIUgw= +git.sr.ht/~sircmpwn/go-bare v0.0.0-20210406120253-ab86bc2846d9 h1:Ahny8Ud1LjVMMAlt8utUFKhhxJtwBAualvsbc/Sk7cE= +git.sr.ht/~sircmpwn/go-bare v0.0.0-20210406120253-ab86bc2846d9/go.mod h1:BVJwbDfVjCjoFiKrhkei6NdGcZYpkDkdyCdg1ukytRA= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo= +github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= +github.com/emersion/go-sasl v0.0.0-20220912192320-0145f2c60ead h1:fI1Jck0vUrXT8bnphprS1EoVRe2Q5CKCX8iDlpqjQ/Y= +github.com/emersion/go-sasl v0.0.0-20220912192320-0145f2c60ead/go.mod h1:iL2twTeMvZnrg54ZoPDNfJaJaqy0xIQFuBdrLsmspwQ= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ= +github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4= +github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs= +github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= +github.com/lib/pq v1.10.7 h1:p7ZhMD+KsSRozJr34udlUrhboJwWAgCg34+/ZZNvZZw= +github.com/lib/pq v1.10.7/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/mattn/go-isatty v0.0.16 h1:bq3VjFmv/sOjHtdEhmkEV4x1AJtvUvOJ2PFAZ5+peKQ= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A= +golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= +golang.org/x/mod v0.3.0 h1:RM4zey1++hCTbCVQfnWeKs9/IEsaBLA8vTkd0WVtmH4= +golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.6.0 h1:clScbb1cHjoCkyRbWwBEUZ5H/tIFu5TAXIqaZD0Gcjw= +golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= +golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20201124115921-2c860bdd6e78 h1:M8tBwCtWD/cZV9DZpFYRUgaymAYAr+aIUTWzDaM3uPs= +golang.org/x/tools v0.0.0-20201124115921-2c860bdd6e78/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/irc.v3 v3.1.4 h1:DYGMRFbtseXEh+NadmMUFzMraqyuUj4I3iWYFEzDZPc= +gopkg.in/irc.v3 v3.1.4/go.mod h1:shO2gz8+PVeS+4E6GAny88Z0YVVQSxQghdrMVGQsR9s= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +lukechampine.com/uint128 v1.2.0 h1:mBi/5l91vocEN8otkC5bDLhi2KdCticRiwbdB0O+rjI= +lukechampine.com/uint128 v1.2.0/go.mod h1:c4eWIwlEGaxC/+H1VguhU4PHXNWDCDMUlWdIWl2j1gk= +modernc.org/cc/v3 v3.40.0 h1:P3g79IUS/93SYhtoeaHW+kRCIrYaxJ27MFPv+7kaTOw= +modernc.org/cc/v3 v3.40.0/go.mod h1:/bTg4dnWkSXowUO6ssQKnOV0yMVxDYNIsIrzqTFDGH0= +modernc.org/ccgo/v3 v3.16.13 h1:Mkgdzl46i5F/CNR/Kj80Ri59hC8TKAhZrYSaqvkwzUw= +modernc.org/ccgo/v3 v3.16.13/go.mod h1:2Quk+5YgpImhPjv2Qsob1DnZ/4som1lJTodubIcoUkY= +modernc.org/ccorpus v1.11.6 h1:J16RXiiqiCgua6+ZvQot4yUuUy8zxgqbqEEUuGPlISk= +modernc.org/httpfs v1.0.6 h1:AAgIpFZRXuYnkjftxTAZwMIiwEqAfk8aVB2/oA6nAeM= +modernc.org/libc v1.22.3 h1:D/g6O5ftAfavceqlLOFwaZuA5KYafKwmr30A6iSqoyY= +modernc.org/libc v1.22.3/go.mod h1:MQrloYP209xa2zHome2a8HLiLm6k0UT8CoHpV74tOFw= +modernc.org/mathutil v1.5.0 h1:rV0Ko/6SfM+8G+yKiyI830l3Wuz1zRutdslNoQ0kfiQ= +modernc.org/mathutil v1.5.0/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E= +modernc.org/memory v1.5.0 h1:N+/8c5rE6EqugZwHii4IFsaJ7MUhoWX07J5tC/iI5Ds= +modernc.org/memory v1.5.0/go.mod h1:PkUhL0Mugw21sHPeskwZW4D6VscE/GQJOnIpCnW6pSU= +modernc.org/opt v0.1.3 h1:3XOZf2yznlhC+ibLltsDGzABUGVx8J6pnFMS3E4dcq4= +modernc.org/opt v0.1.3/go.mod h1:WdSiB5evDcignE70guQKxYUl14mgWtbClRi5wmkkTX0= +modernc.org/sqlite v1.21.0 h1:4aP4MdUf15i3R3M2mx6Q90WHKz3nZLoz96zlB6tNdow= +modernc.org/sqlite v1.21.0/go.mod h1:XwQ0wZPIh1iKb5mkvCJ3szzbhk+tykC8ZWqTRTgYRwI= +modernc.org/strutil v1.1.3 h1:fNMm+oJklMGYfU9Ylcywl0CO5O6nTfaowNsh2wpPjzY= +modernc.org/strutil v1.1.3/go.mod h1:MEHNA7PdEnEwLvspRMtWTNnp2nnyvMfkimT1NKNAGbw= +modernc.org/tcl v1.15.1 h1:mOQwiEK4p7HruMZcwKTZPw/aqtGM4aY00uzWhlKKYws= +modernc.org/token v1.0.1 h1:A3qvTqOwexpfZZeyI0FeGPDlSWX5pjZu9hF4lU+EKWg= +modernc.org/token v1.0.1/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= +modernc.org/z v1.7.0 h1:xkDw/KepgEjeizO2sNco+hqYkU12taxQFqPEmgm1GWE= diff --git a/trunk/irc.go b/trunk/irc.go new file mode 100644 index 0000000..1799e32 --- /dev/null +++ b/trunk/irc.go @@ -0,0 +1,811 @@ +package suika + +import ( + "fmt" + "sort" + "strings" + "time" + "unicode" + "unicode/utf8" + + "gopkg.in/irc.v3" +) + +const ( + rpl_statsping = "246" + rpl_localusers = "265" + rpl_globalusers = "266" + rpl_creationtime = "329" + rpl_topicwhotime = "333" + rpl_whospcrpl = "354" + rpl_whoisaccount = "330" + err_invalidcapcmd = "410" +) + +const ( + maxMessageLength = 512 + maxMessageParams = 15 + maxSASLLength = 400 +) + +// The server-time layout, as defined in the IRCv3 spec. +const serverTimeLayout = "2006-01-02T15:04:05.000Z" + +func formatServerTime(t time.Time) string { + return t.UTC().Format(serverTimeLayout) +} + +type userModes string + +func (ms userModes) Has(c byte) bool { + return strings.IndexByte(string(ms), c) >= 0 +} + +func (ms *userModes) Add(c byte) { + if !ms.Has(c) { + *ms += userModes(c) + } +} + +func (ms *userModes) Del(c byte) { + i := strings.IndexByte(string(*ms), c) + if i >= 0 { + *ms = (*ms)[:i] + (*ms)[i+1:] + } +} + +func (ms *userModes) Apply(s string) error { + var plusMinus byte + for i := 0; i < len(s); i++ { + switch c := s[i]; c { + case '+', '-': + plusMinus = c + default: + switch plusMinus { + case '+': + ms.Add(c) + case '-': + ms.Del(c) + default: + return fmt.Errorf("malformed modestring %q: missing plus/minus", s) + } + } + } + return nil +} + +type channelModeType byte + +// standard channel mode types, as explained in https://modern.ircdocs.horse/#mode-message +const ( + // modes that add or remove an address to or from a list + modeTypeA channelModeType = iota + // modes that change a setting on a channel, and must always have a parameter + modeTypeB + // modes that change a setting on a channel, and must have a parameter when being set, and no parameter when being unset + modeTypeC + // modes that change a setting on a channel, and must not have a parameter + modeTypeD +) + +var stdChannelModes = map[byte]channelModeType{ + 'b': modeTypeA, // ban list + 'e': modeTypeA, // ban exception list + 'I': modeTypeA, // invite exception list + 'k': modeTypeB, // channel key + 'l': modeTypeC, // channel user limit + 'i': modeTypeD, // channel is invite-only + 'm': modeTypeD, // channel is moderated + 'n': modeTypeD, // channel has no external messages + 's': modeTypeD, // channel is secret + 't': modeTypeD, // channel has protected topic +} + +type channelModes map[byte]string + +// applyChannelModes parses a mode string and mode arguments from a MODE message, +// and applies the corresponding channel mode and user membership changes on that channel. +// +// If ch.modes is nil, channel modes are not updated. +// +// needMarshaling is a list of indexes of mode arguments that represent entities +// that must be marshaled when sent downstream. +func applyChannelModes(ch *upstreamChannel, modeStr string, arguments []string) (needMarshaling map[int]struct{}, err error) { + needMarshaling = make(map[int]struct{}, len(arguments)) + nextArgument := 0 + var plusMinus byte +outer: + for i := 0; i < len(modeStr); i++ { + mode := modeStr[i] + if mode == '+' || mode == '-' { + plusMinus = mode + continue + } + if plusMinus != '+' && plusMinus != '-' { + return nil, fmt.Errorf("malformed modestring %q: missing plus/minus", modeStr) + } + + for _, membership := range ch.conn.availableMemberships { + if membership.Mode == mode { + if nextArgument >= len(arguments) { + return nil, fmt.Errorf("malformed modestring %q: missing mode argument for %c%c", modeStr, plusMinus, mode) + } + member := arguments[nextArgument] + m := ch.Members.Value(member) + if m != nil { + if plusMinus == '+' { + m.Add(ch.conn.availableMemberships, membership) + } else { + // TODO: for upstreams without multi-prefix, query the user modes again + m.Remove(membership) + } + } + needMarshaling[nextArgument] = struct{}{} + nextArgument++ + continue outer + } + } + + mt, ok := ch.conn.availableChannelModes[mode] + if !ok { + continue + } + if mt == modeTypeA { + nextArgument++ + } else if mt == modeTypeB || (mt == modeTypeC && plusMinus == '+') { + if plusMinus == '+' { + var argument string + // some sentitive arguments (such as channel keys) can be omitted for privacy + // (this will only happen for RPL_CHANNELMODEIS, never for MODE messages) + if nextArgument < len(arguments) { + argument = arguments[nextArgument] + } + if ch.modes != nil { + ch.modes[mode] = argument + } + } else { + delete(ch.modes, mode) + } + nextArgument++ + } else if mt == modeTypeC || mt == modeTypeD { + if plusMinus == '+' { + if ch.modes != nil { + ch.modes[mode] = "" + } + } else { + delete(ch.modes, mode) + } + } + } + return needMarshaling, nil +} + +func (cm channelModes) Format() (modeString string, parameters []string) { + var modesWithValues strings.Builder + var modesWithoutValues strings.Builder + parameters = make([]string, 0, 16) + for mode, value := range cm { + if value != "" { + modesWithValues.WriteString(string(mode)) + parameters = append(parameters, value) + } else { + modesWithoutValues.WriteString(string(mode)) + } + } + modeString = "+" + modesWithValues.String() + modesWithoutValues.String() + return +} + +const stdChannelTypes = "#&+!" + +type channelStatus byte + +const ( + channelPublic channelStatus = '=' + channelSecret channelStatus = '@' + channelPrivate channelStatus = '*' +) + +func parseChannelStatus(s string) (channelStatus, error) { + if len(s) > 1 { + return 0, fmt.Errorf("invalid channel status %q: more than one character", s) + } + switch cs := channelStatus(s[0]); cs { + case channelPublic, channelSecret, channelPrivate: + return cs, nil + default: + return 0, fmt.Errorf("invalid channel status %q: unknown status", s) + } +} + +type membership struct { + Mode byte + Prefix byte +} + +var stdMemberships = []membership{ + {'q', '~'}, // founder + {'a', '&'}, // protected + {'o', '@'}, // operator + {'h', '%'}, // halfop + {'v', '+'}, // voice +} + +// memberships always sorted by descending membership rank +type memberships []membership + +func (m *memberships) Add(availableMemberships []membership, newMembership membership) { + l := *m + i := 0 + for _, availableMembership := range availableMemberships { + if i >= len(l) { + break + } + if l[i] == availableMembership { + if availableMembership == newMembership { + // we already have this membership + return + } + i++ + continue + } + if availableMembership == newMembership { + break + } + } + // insert newMembership at i + l = append(l, membership{}) + copy(l[i+1:], l[i:]) + l[i] = newMembership + *m = l +} + +func (m *memberships) Remove(oldMembership membership) { + l := *m + for i, currentMembership := range l { + if currentMembership == oldMembership { + *m = append(l[:i], l[i+1:]...) + return + } + } +} + +func (m memberships) Format(dc *downstreamConn) string { + if !dc.caps["multi-prefix"] { + if len(m) == 0 { + return "" + } + return string(m[0].Prefix) + } + prefixes := make([]byte, len(m)) + for i, membership := range m { + prefixes[i] = membership.Prefix + } + return string(prefixes) +} + +func parseMessageParams(msg *irc.Message, out ...*string) error { + if len(msg.Params) < len(out) { + return newNeedMoreParamsError(msg.Command) + } + for i := range out { + if out[i] != nil { + *out[i] = msg.Params[i] + } + } + return nil +} + +func copyClientTags(tags irc.Tags) irc.Tags { + t := make(irc.Tags, len(tags)) + for k, v := range tags { + if strings.HasPrefix(k, "+") { + t[k] = v + } + } + return t +} + +type batch struct { + Type string + Params []string + Outer *batch // if not-nil, this batch is nested in Outer + Label string +} + +func join(channels, keys []string) []*irc.Message { + // Put channels with a key first + js := joinSorter{channels, keys} + sort.Sort(&js) + + // Two spaces because there are three words (JOIN, channels and keys) + maxLength := maxMessageLength - (len("JOIN") + 2) + + var msgs []*irc.Message + var channelsBuf, keysBuf strings.Builder + for i, channel := range channels { + key := keys[i] + + n := channelsBuf.Len() + keysBuf.Len() + 1 + len(channel) + if key != "" { + n += 1 + len(key) + } + + if channelsBuf.Len() > 0 && n > maxLength { + // No room for the new channel in this message + params := []string{channelsBuf.String()} + if keysBuf.Len() > 0 { + params = append(params, keysBuf.String()) + } + msgs = append(msgs, &irc.Message{Command: "JOIN", Params: params}) + channelsBuf.Reset() + keysBuf.Reset() + } + + if channelsBuf.Len() > 0 { + channelsBuf.WriteByte(',') + } + channelsBuf.WriteString(channel) + if key != "" { + if keysBuf.Len() > 0 { + keysBuf.WriteByte(',') + } + keysBuf.WriteString(key) + } + } + if channelsBuf.Len() > 0 { + params := []string{channelsBuf.String()} + if keysBuf.Len() > 0 { + params = append(params, keysBuf.String()) + } + msgs = append(msgs, &irc.Message{Command: "JOIN", Params: params}) + } + + return msgs +} + +func generateIsupport(prefix *irc.Prefix, nick string, tokens []string) []*irc.Message { + maxTokens := maxMessageParams - 2 // 2 reserved params: nick + text + + var msgs []*irc.Message + for len(tokens) > 0 { + var msgTokens []string + if len(tokens) > maxTokens { + msgTokens = tokens[:maxTokens] + tokens = tokens[maxTokens:] + } else { + msgTokens = tokens + tokens = nil + } + + msgs = append(msgs, &irc.Message{ + Prefix: prefix, + Command: irc.RPL_ISUPPORT, + Params: append(append([]string{nick}, msgTokens...), "are supported"), + }) + } + + return msgs +} + +func generateMOTD(prefix *irc.Prefix, nick string, motd string) []*irc.Message { + var msgs []*irc.Message + msgs = append(msgs, &irc.Message{ + Prefix: prefix, + Command: irc.RPL_MOTDSTART, + Params: []string{nick, fmt.Sprintf("- Message of the Day -")}, + }) + + for _, l := range strings.Split(motd, "\n") { + msgs = append(msgs, &irc.Message{ + Prefix: prefix, + Command: irc.RPL_MOTD, + Params: []string{nick, l}, + }) + } + + msgs = append(msgs, &irc.Message{ + Prefix: prefix, + Command: irc.RPL_ENDOFMOTD, + Params: []string{nick, "End of /MOTD command."}, + }) + + return msgs +} + +func generateMonitor(subcmd string, targets []string) []*irc.Message { + maxLength := maxMessageLength - len("MONITOR "+subcmd+" ") + + var msgs []*irc.Message + var buf []string + n := 0 + for _, target := range targets { + if n+len(target)+1 > maxLength { + msgs = append(msgs, &irc.Message{ + Command: "MONITOR", + Params: []string{subcmd, strings.Join(buf, ",")}, + }) + buf = buf[:0] + n = 0 + } + + buf = append(buf, target) + n += len(target) + 1 + } + + if len(buf) > 0 { + msgs = append(msgs, &irc.Message{ + Command: "MONITOR", + Params: []string{subcmd, strings.Join(buf, ",")}, + }) + } + + return msgs +} + +type joinSorter struct { + channels []string + keys []string +} + +func (js *joinSorter) Len() int { + return len(js.channels) +} + +func (js *joinSorter) Less(i, j int) bool { + if (js.keys[i] != "") != (js.keys[j] != "") { + // Only one of the channels has a key + return js.keys[i] != "" + } + return js.channels[i] < js.channels[j] +} + +func (js *joinSorter) Swap(i, j int) { + js.channels[i], js.channels[j] = js.channels[j], js.channels[i] + js.keys[i], js.keys[j] = js.keys[j], js.keys[i] +} + +// parseCTCPMessage parses a CTCP message. CTCP is defined in +// https://tools.ietf.org/html/draft-oakley-irc-ctcp-02 +func parseCTCPMessage(msg *irc.Message) (cmd string, params string, ok bool) { + if (msg.Command != "PRIVMSG" && msg.Command != "NOTICE") || len(msg.Params) < 2 { + return "", "", false + } + text := msg.Params[1] + + if !strings.HasPrefix(text, "\x01") { + return "", "", false + } + text = strings.Trim(text, "\x01") + + words := strings.SplitN(text, " ", 2) + cmd = strings.ToUpper(words[0]) + if len(words) > 1 { + params = words[1] + } + + return cmd, params, true +} + +type casemapping func(string) string + +func casemapNone(name string) string { + return name +} + +// CasemapASCII of name is the canonical representation of name according to the +// ascii casemapping. +func casemapASCII(name string) string { + nameBytes := []byte(name) + for i, r := range nameBytes { + if 'A' <= r && r <= 'Z' { + nameBytes[i] = r + 'a' - 'A' + } + } + return string(nameBytes) +} + +// casemapRFC1459 of name is the canonical representation of name according to the +// rfc1459 casemapping. +func casemapRFC1459(name string) string { + nameBytes := []byte(name) + for i, r := range nameBytes { + if 'A' <= r && r <= 'Z' { + nameBytes[i] = r + 'a' - 'A' + } else if r == '{' { + nameBytes[i] = '[' + } else if r == '}' { + nameBytes[i] = ']' + } else if r == '\\' { + nameBytes[i] = '|' + } else if r == '~' { + nameBytes[i] = '^' + } + } + return string(nameBytes) +} + +// casemapRFC1459Strict of name is the canonical representation of name +// according to the rfc1459-strict casemapping. +func casemapRFC1459Strict(name string) string { + nameBytes := []byte(name) + for i, r := range nameBytes { + if 'A' <= r && r <= 'Z' { + nameBytes[i] = r + 'a' - 'A' + } else if r == '{' { + nameBytes[i] = '[' + } else if r == '}' { + nameBytes[i] = ']' + } else if r == '\\' { + nameBytes[i] = '|' + } + } + return string(nameBytes) +} + +func parseCasemappingToken(tokenValue string) (casemap casemapping, ok bool) { + switch tokenValue { + case "ascii": + casemap = casemapASCII + case "rfc1459": + casemap = casemapRFC1459 + case "rfc1459-strict": + casemap = casemapRFC1459Strict + default: + return nil, false + } + return casemap, true +} + +func partialCasemap(higher casemapping, name string) string { + nameFullyCM := []byte(higher(name)) + nameBytes := []byte(name) + for i, r := range nameBytes { + if !('A' <= r && r <= 'Z') && !('a' <= r && r <= 'z') { + nameBytes[i] = nameFullyCM[i] + } + } + return string(nameBytes) +} + +type casemapMap struct { + innerMap map[string]casemapEntry + casemap casemapping +} + +type casemapEntry struct { + originalKey string + value interface{} +} + +func newCasemapMap(size int) casemapMap { + return casemapMap{ + innerMap: make(map[string]casemapEntry, size), + casemap: casemapNone, + } +} + +func (cm *casemapMap) OriginalKey(name string) (key string, ok bool) { + entry, ok := cm.innerMap[cm.casemap(name)] + if !ok { + return "", false + } + return entry.originalKey, true +} + +func (cm *casemapMap) Has(name string) bool { + _, ok := cm.innerMap[cm.casemap(name)] + return ok +} + +func (cm *casemapMap) Len() int { + return len(cm.innerMap) +} + +func (cm *casemapMap) SetValue(name string, value interface{}) { + nameCM := cm.casemap(name) + entry, ok := cm.innerMap[nameCM] + if !ok { + cm.innerMap[nameCM] = casemapEntry{ + originalKey: name, + value: value, + } + return + } + entry.value = value + cm.innerMap[nameCM] = entry +} + +func (cm *casemapMap) Delete(name string) { + delete(cm.innerMap, cm.casemap(name)) +} + +func (cm *casemapMap) SetCasemapping(newCasemap casemapping) { + cm.casemap = newCasemap + newInnerMap := make(map[string]casemapEntry, len(cm.innerMap)) + for _, entry := range cm.innerMap { + newInnerMap[cm.casemap(entry.originalKey)] = entry + } + cm.innerMap = newInnerMap +} + +type upstreamChannelCasemapMap struct{ casemapMap } + +func (cm *upstreamChannelCasemapMap) Value(name string) *upstreamChannel { + entry, ok := cm.innerMap[cm.casemap(name)] + if !ok { + return nil + } + return entry.value.(*upstreamChannel) +} + +type channelCasemapMap struct{ casemapMap } + +func (cm *channelCasemapMap) Value(name string) *Channel { + entry, ok := cm.innerMap[cm.casemap(name)] + if !ok { + return nil + } + return entry.value.(*Channel) +} + +type membershipsCasemapMap struct{ casemapMap } + +func (cm *membershipsCasemapMap) Value(name string) *memberships { + entry, ok := cm.innerMap[cm.casemap(name)] + if !ok { + return nil + } + return entry.value.(*memberships) +} + +type deliveredCasemapMap struct{ casemapMap } + +func (cm *deliveredCasemapMap) Value(name string) deliveredClientMap { + entry, ok := cm.innerMap[cm.casemap(name)] + if !ok { + return nil + } + return entry.value.(deliveredClientMap) +} + +type monitorCasemapMap struct{ casemapMap } + +func (cm *monitorCasemapMap) Value(name string) (online bool) { + entry, ok := cm.innerMap[cm.casemap(name)] + if !ok { + return false + } + return entry.value.(bool) +} + +func isWordBoundary(r rune) bool { + switch r { + case '-', '_', '|': // inspired from weechat.look.highlight_regex + return false + default: + return !unicode.IsLetter(r) && !unicode.IsNumber(r) + } +} + +func isHighlight(text, nick string) bool { + for { + i := strings.Index(text, nick) + if i < 0 { + return false + } + + left, _ := utf8.DecodeLastRuneInString(text[:i]) + right, _ := utf8.DecodeRuneInString(text[i+len(nick):]) + if isWordBoundary(left) && isWordBoundary(right) { + return true + } + + text = text[i+len(nick):] + } +} + +// parseChatHistoryBound parses the given CHATHISTORY parameter as a bound. +// The zero time is returned on error. +func parseChatHistoryBound(param string) time.Time { + parts := strings.SplitN(param, "=", 2) + if len(parts) != 2 { + return time.Time{} + } + switch parts[0] { + case "timestamp": + timestamp, err := time.Parse(serverTimeLayout, parts[1]) + if err != nil { + return time.Time{} + } + return timestamp + default: + return time.Time{} + } +} + +// whoxFields is the list of all WHOX field letters, by order of appearance in +// RPL_WHOSPCRPL messages. +var whoxFields = []byte("tcuihsnfdlaor") + +type whoxInfo struct { + Token string + Username string + Hostname string + Server string + Nickname string + Flags string + Account string + Realname string +} + +func (info *whoxInfo) get(field byte) string { + switch field { + case 't': + return info.Token + case 'c': + return "*" + case 'u': + return info.Username + case 'i': + return "255.255.255.255" + case 'h': + return info.Hostname + case 's': + return info.Server + case 'n': + return info.Nickname + case 'f': + return info.Flags + case 'd': + return "0" + case 'l': // idle time + return "0" + case 'a': + account := "0" // WHOX uses "0" to mean "no account" + if info.Account != "" && info.Account != "*" { + account = info.Account + } + return account + case 'o': + return "0" + case 'r': + return info.Realname + } + return "" +} + +func generateWHOXReply(prefix *irc.Prefix, nick, fields string, info *whoxInfo) *irc.Message { + if fields == "" { + return &irc.Message{ + Prefix: prefix, + Command: irc.RPL_WHOREPLY, + Params: []string{nick, "*", info.Username, info.Hostname, info.Server, info.Nickname, info.Flags, "0 " + info.Realname}, + } + } + + fieldSet := make(map[byte]bool) + for i := 0; i < len(fields); i++ { + fieldSet[fields[i]] = true + } + + var values []string + for _, field := range whoxFields { + if !fieldSet[field] { + continue + } + values = append(values, info.get(field)) + } + + return &irc.Message{ + Prefix: prefix, + Command: rpl_whospcrpl, + Params: append([]string{nick}, values...), + } +} + +var isupportEncoder = strings.NewReplacer(" ", "\\x20", "\\", "\\x5C") + +func encodeISUPPORT(s string) string { + return isupportEncoder.Replace(s) +} diff --git a/trunk/irc_test.go b/trunk/irc_test.go new file mode 100644 index 0000000..8757bc8 --- /dev/null +++ b/trunk/irc_test.go @@ -0,0 +1,34 @@ +package suika + +import ( + "testing" +) + +func TestIsHighlight(t *testing.T) { + nick := "SojuUser" + testCases := []struct { + name string + text string + hl bool + }{ + {"noContains", "hi there Soju User!", false}, + {"middle", "hi there SojuUser!", true}, + {"start", "SojuUser: how are you doing?", true}, + {"end", "maybe ask SojuUser", true}, + {"inWord", "but OtherSojuUserSan is a different nick", false}, + {"startWord", "and OtherSojuUser is another different nick", false}, + {"endWord", "and SojuUserSan is yet a different nick", false}, + {"underscore", "and SojuUser_san has nothing to do with me", false}, + {"zeroWidthSpace", "writing S\u200BojuUser shouldn't trigger a highlight", false}, + } + + for _, tc := range testCases { + tc := tc // capture range variable + t.Run(tc.name, func(t *testing.T) { + hl := isHighlight(tc.text, nick) + if hl != tc.hl { + t.Errorf("isHighlight(%q, %q) = %v, but want %v", tc.text, nick, hl, tc.hl) + } + }) + } +} diff --git a/trunk/msgstore.go b/trunk/msgstore.go new file mode 100644 index 0000000..ed57429 --- /dev/null +++ b/trunk/msgstore.go @@ -0,0 +1,123 @@ +package suika + +import ( + "bytes" + "context" + "encoding/base64" + "fmt" + "time" + + "git.sr.ht/~sircmpwn/go-bare" + "gopkg.in/irc.v3" +) + +// messageStore is a per-user store for IRC messages. +type messageStore interface { + Close() error + // LastMsgID queries the last message ID for the given network, entity and + // date. The message ID returned may not refer to a valid message, but can be + // used in history queries. + LastMsgID(network *Network, entity string, t time.Time) (string, error) + // LoadLatestID queries the latest non-event messages for the given network, + // entity and date, up to a count of limit messages, sorted from oldest to newest. + LoadLatestID(ctx context.Context, network *Network, entity, id string, limit int) ([]*irc.Message, error) + Append(network *Network, entity string, msg *irc.Message) (id string, err error) +} + +type chatHistoryTarget struct { + Name string + LatestMessage time.Time +} + +// chatHistoryMessageStore is a message store that supports chat history +// operations. +type chatHistoryMessageStore interface { + messageStore + + // ListTargets lists channels and nicknames by time of the latest message. + // It returns up to limit targets, starting from start and ending on end, + // both excluded. end may be before or after start. + // If events is false, only PRIVMSG/NOTICE messages are considered. + ListTargets(ctx context.Context, network *Network, start, end time.Time, limit int, events bool) ([]chatHistoryTarget, error) + // LoadBeforeTime loads up to limit messages before start down to end. The + // returned messages must be between and excluding the provided bounds. + // end is before start. + // If events is false, only PRIVMSG/NOTICE messages are considered. + LoadBeforeTime(ctx context.Context, network *Network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error) + // LoadBeforeTime loads up to limit messages after start up to end. The + // returned messages must be between and excluding the provided bounds. + // end is after start. + // If events is false, only PRIVMSG/NOTICE messages are considered. + LoadAfterTime(ctx context.Context, network *Network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error) +} + +type msgIDType uint + +const ( + msgIDNone msgIDType = iota + msgIDMemory + msgIDFS +) + +const msgIDVersion uint = 0 + +type msgIDHeader struct { + Version uint + Network bare.Int + Target string + Type msgIDType +} + +type msgIDBody interface { + msgIDType() msgIDType +} + +func formatMsgID(netID int64, target string, body msgIDBody) string { + var buf bytes.Buffer + w := bare.NewWriter(&buf) + + header := msgIDHeader{ + Version: msgIDVersion, + Network: bare.Int(netID), + Target: target, + Type: body.msgIDType(), + } + if err := bare.MarshalWriter(w, &header); err != nil { + panic(err) + } + if err := bare.MarshalWriter(w, body); err != nil { + panic(err) + } + return base64.RawURLEncoding.EncodeToString(buf.Bytes()) +} + +func parseMsgID(s string, body msgIDBody) (netID int64, target string, err error) { + b, err := base64.RawURLEncoding.DecodeString(s) + if err != nil { + return 0, "", fmt.Errorf("invalid internal message ID: %v", err) + } + + r := bare.NewReader(bytes.NewReader(b)) + + var header msgIDHeader + if err := bare.UnmarshalBareReader(r, &header); err != nil { + return 0, "", fmt.Errorf("invalid internal message ID: %v", err) + } + + if header.Version != msgIDVersion { + return 0, "", fmt.Errorf("invalid internal message ID: got version %v, want %v", header.Version, msgIDVersion) + } + + if body != nil { + typ := body.msgIDType() + if header.Type != typ { + return 0, "", fmt.Errorf("invalid internal message ID: got type %v, want %v", header.Type, typ) + } + + if err := bare.UnmarshalBareReader(r, body); err != nil { + return 0, "", fmt.Errorf("invalid internal message ID: %v", err) + } + } + + return int64(header.Network), header.Target, nil +} diff --git a/trunk/msgstore_fs.go b/trunk/msgstore_fs.go new file mode 100644 index 0000000..90bd76a --- /dev/null +++ b/trunk/msgstore_fs.go @@ -0,0 +1,698 @@ +package suika + +import ( + "bufio" + "context" + "fmt" + "io" + "os" + "path/filepath" + "sort" + "strings" + "time" + + "git.sr.ht/~sircmpwn/go-bare" + "gopkg.in/irc.v3" +) + +const ( + fsMessageStoreMaxFiles = 20 + fsMessageStoreMaxTries = 100 +) + +func escapeFilename(unsafe string) (safe string) { + if unsafe == "." { + return "-" + } else if unsafe == ".." { + return "--" + } else { + return strings.NewReplacer("/", "-", "\\", "-").Replace(unsafe) + } +} + +type date struct { + Year, Month, Day int +} + +func newDate(t time.Time) date { + year, month, day := t.Date() + return date{year, int(month), day} +} + +func (d date) Time() time.Time { + return time.Date(d.Year, time.Month(d.Month), d.Day, 0, 0, 0, 0, time.Local) +} + +type fsMsgID struct { + Date date + Offset bare.Int +} + +func (fsMsgID) msgIDType() msgIDType { + return msgIDFS +} + +func parseFSMsgID(s string) (netID int64, entity string, t time.Time, offset int64, err error) { + var id fsMsgID + netID, entity, err = parseMsgID(s, &id) + if err != nil { + return 0, "", time.Time{}, 0, err + } + return netID, entity, id.Date.Time(), int64(id.Offset), nil +} + +func formatFSMsgID(netID int64, entity string, t time.Time, offset int64) string { + id := fsMsgID{ + Date: newDate(t), + Offset: bare.Int(offset), + } + return formatMsgID(netID, entity, &id) +} + +type fsMessageStoreFile struct { + *os.File + lastUse time.Time +} + +// fsMessageStore is a per-user on-disk store for IRC messages. +// +// It mimicks the ZNC log layout and format. See the ZNC source: +// https://github.com/znc/znc/blob/master/modules/log.cpp +type fsMessageStore struct { + root string + user *User + + // Write-only files used by Append + files map[string]*fsMessageStoreFile // indexed by entity +} + +var _ messageStore = (*fsMessageStore)(nil) +var _ chatHistoryMessageStore = (*fsMessageStore)(nil) + +func newFSMessageStore(root string, user *User) *fsMessageStore { + return &fsMessageStore{ + root: filepath.Join(root, escapeFilename(user.Username)), + user: user, + files: make(map[string]*fsMessageStoreFile), + } +} + +func (ms *fsMessageStore) logPath(network *Network, entity string, t time.Time) string { + year, month, day := t.Date() + filename := fmt.Sprintf("%04d-%02d-%02d.log", year, month, day) + return filepath.Join(ms.root, escapeFilename(network.GetName()), escapeFilename(entity), filename) +} + +// nextMsgID queries the message ID for the next message to be written to f. +func nextFSMsgID(network *Network, entity string, t time.Time, f *os.File) (string, error) { + offset, err := f.Seek(0, io.SeekEnd) + if err != nil { + return "", fmt.Errorf("failed to query next FS message ID: %v", err) + } + return formatFSMsgID(network.ID, entity, t, offset), nil +} + +func (ms *fsMessageStore) LastMsgID(network *Network, entity string, t time.Time) (string, error) { + p := ms.logPath(network, entity, t) + fi, err := os.Stat(p) + if os.IsNotExist(err) { + return formatFSMsgID(network.ID, entity, t, -1), nil + } else if err != nil { + return "", fmt.Errorf("failed to query last FS message ID: %v", err) + } + return formatFSMsgID(network.ID, entity, t, fi.Size()-1), nil +} + +func (ms *fsMessageStore) Append(network *Network, entity string, msg *irc.Message) (string, error) { + s := formatMessage(msg) + if s == "" { + return "", nil + } + + var t time.Time + if tag, ok := msg.Tags["time"]; ok { + var err error + t, err = time.Parse(serverTimeLayout, string(tag)) + if err != nil { + return "", fmt.Errorf("failed to parse message time tag: %v", err) + } + t = t.In(time.Local) + } else { + t = time.Now() + } + + f := ms.files[entity] + + // TODO: handle non-monotonic clock behaviour + path := ms.logPath(network, entity, t) + if f == nil || f.Name() != path { + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0750); err != nil { + return "", fmt.Errorf("failed to create message logs directory %q: %v", dir, err) + } + + ff, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0640) + if err != nil { + return "", fmt.Errorf("failed to open message log file %q: %v", path, err) + } + + if f != nil { + f.Close() + } + f = &fsMessageStoreFile{File: ff} + ms.files[entity] = f + } + + f.lastUse = time.Now() + + if len(ms.files) > fsMessageStoreMaxFiles { + entities := make([]string, 0, len(ms.files)) + for name := range ms.files { + entities = append(entities, name) + } + sort.Slice(entities, func(i, j int) bool { + a, b := entities[i], entities[j] + return ms.files[a].lastUse.Before(ms.files[b].lastUse) + }) + entities = entities[0 : len(entities)-fsMessageStoreMaxFiles] + for _, name := range entities { + ms.files[name].Close() + delete(ms.files, name) + } + } + + msgID, err := nextFSMsgID(network, entity, t, f.File) + if err != nil { + return "", fmt.Errorf("failed to generate message ID: %v", err) + } + + _, err = fmt.Fprintf(f, "[%02d:%02d:%02d] %s\n", t.Hour(), t.Minute(), t.Second(), s) + if err != nil { + return "", fmt.Errorf("failed to log message to %q: %v", f.Name(), err) + } + + return msgID, nil +} + +func (ms *fsMessageStore) Close() error { + var closeErr error + for _, f := range ms.files { + if err := f.Close(); err != nil { + closeErr = fmt.Errorf("failed to close message store: %v", err) + } + } + return closeErr +} + +// formatMessage formats a message log line. It assumes a well-formed IRC +// message. +func formatMessage(msg *irc.Message) string { + switch strings.ToUpper(msg.Command) { + case "NICK": + return fmt.Sprintf("*** %s is now known as %s", msg.Prefix.Name, msg.Params[0]) + case "JOIN": + return fmt.Sprintf("*** Joins: %s (%s@%s)", msg.Prefix.Name, msg.Prefix.User, msg.Prefix.Host) + case "PART": + var reason string + if len(msg.Params) > 1 { + reason = msg.Params[1] + } + return fmt.Sprintf("*** Parts: %s (%s@%s) (%s)", msg.Prefix.Name, msg.Prefix.User, msg.Prefix.Host, reason) + case "KICK": + nick := msg.Params[1] + var reason string + if len(msg.Params) > 2 { + reason = msg.Params[2] + } + return fmt.Sprintf("*** %s was kicked by %s (%s)", nick, msg.Prefix.Name, reason) + case "QUIT": + var reason string + if len(msg.Params) > 0 { + reason = msg.Params[0] + } + return fmt.Sprintf("*** Quits: %s (%s@%s) (%s)", msg.Prefix.Name, msg.Prefix.User, msg.Prefix.Host, reason) + case "TOPIC": + var topic string + if len(msg.Params) > 1 { + topic = msg.Params[1] + } + return fmt.Sprintf("*** %s changes topic to '%s'", msg.Prefix.Name, topic) + case "MODE": + return fmt.Sprintf("*** %s sets mode: %s", msg.Prefix.Name, strings.Join(msg.Params[1:], " ")) + case "NOTICE": + return fmt.Sprintf("-%s- %s", msg.Prefix.Name, msg.Params[1]) + case "PRIVMSG": + if cmd, params, ok := parseCTCPMessage(msg); ok && cmd == "ACTION" { + return fmt.Sprintf("* %s %s", msg.Prefix.Name, params) + } else { + return fmt.Sprintf("<%s> %s", msg.Prefix.Name, msg.Params[1]) + } + default: + return "" + } +} + +func (ms *fsMessageStore) parseMessage(line string, network *Network, entity string, ref time.Time, events bool) (*irc.Message, time.Time, error) { + var hour, minute, second int + _, err := fmt.Sscanf(line, "[%02d:%02d:%02d] ", &hour, &minute, &second) + if err != nil { + return nil, time.Time{}, fmt.Errorf("malformed timestamp prefix: %v", err) + } + line = line[11:] + + var cmd string + var prefix *irc.Prefix + var params []string + if events && strings.HasPrefix(line, "*** ") { + parts := strings.SplitN(line[4:], " ", 2) + if len(parts) != 2 { + return nil, time.Time{}, nil + } + switch parts[0] { + case "Joins:", "Parts:", "Quits:": + args := strings.SplitN(parts[1], " ", 3) + if len(args) < 2 { + return nil, time.Time{}, nil + } + nick := args[0] + mask := strings.TrimSuffix(strings.TrimPrefix(args[1], "("), ")") + maskParts := strings.SplitN(mask, "@", 2) + if len(maskParts) != 2 { + return nil, time.Time{}, nil + } + prefix = &irc.Prefix{ + Name: nick, + User: maskParts[0], + Host: maskParts[1], + } + var reason string + if len(args) > 2 { + reason = strings.TrimSuffix(strings.TrimPrefix(args[2], "("), ")") + } + switch parts[0] { + case "Joins:": + cmd = "JOIN" + params = []string{entity} + case "Parts:": + cmd = "PART" + if reason != "" { + params = []string{entity, reason} + } else { + params = []string{entity} + } + case "Quits:": + cmd = "QUIT" + if reason != "" { + params = []string{reason} + } + } + default: + nick := parts[0] + rem := parts[1] + if r := strings.TrimPrefix(rem, "is now known as "); r != rem { + cmd = "NICK" + prefix = &irc.Prefix{ + Name: nick, + } + params = []string{r} + } else if r := strings.TrimPrefix(rem, "was kicked by "); r != rem { + args := strings.SplitN(r, " ", 2) + if len(args) != 2 { + return nil, time.Time{}, nil + } + cmd = "KICK" + prefix = &irc.Prefix{ + Name: args[0], + } + reason := strings.TrimSuffix(strings.TrimPrefix(args[1], "("), ")") + params = []string{entity, nick} + if reason != "" { + params = append(params, reason) + } + } else if r := strings.TrimPrefix(rem, "changes topic to "); r != rem { + cmd = "TOPIC" + prefix = &irc.Prefix{ + Name: nick, + } + topic := strings.TrimSuffix(strings.TrimPrefix(r, "'"), "'") + params = []string{entity, topic} + } else if r := strings.TrimPrefix(rem, "sets mode: "); r != rem { + cmd = "MODE" + prefix = &irc.Prefix{ + Name: nick, + } + params = append([]string{entity}, strings.Split(r, " ")...) + } else { + return nil, time.Time{}, nil + } + } + } else { + var sender, text string + if strings.HasPrefix(line, "<") { + cmd = "PRIVMSG" + parts := strings.SplitN(line[1:], "> ", 2) + if len(parts) != 2 { + return nil, time.Time{}, nil + } + sender, text = parts[0], parts[1] + } else if strings.HasPrefix(line, "-") { + cmd = "NOTICE" + parts := strings.SplitN(line[1:], "- ", 2) + if len(parts) != 2 { + return nil, time.Time{}, nil + } + sender, text = parts[0], parts[1] + } else if strings.HasPrefix(line, "* ") { + cmd = "PRIVMSG" + parts := strings.SplitN(line[2:], " ", 2) + if len(parts) != 2 { + return nil, time.Time{}, nil + } + sender, text = parts[0], "\x01ACTION "+parts[1]+"\x01" + } else { + return nil, time.Time{}, nil + } + + prefix = &irc.Prefix{Name: sender} + if entity == sender { + // This is a direct message from a user to us. We don't store own + // our nickname in the logs, so grab it from the network settings. + // Not very accurate since this may not match our nick at the time + // the message was received, but we can't do a lot better. + entity = GetNick(ms.user, network) + } + params = []string{entity, text} + } + + year, month, day := ref.Date() + t := time.Date(year, month, day, hour, minute, second, 0, time.Local) + + msg := &irc.Message{ + Tags: map[string]irc.TagValue{ + "time": irc.TagValue(formatServerTime(t)), + }, + Prefix: prefix, + Command: cmd, + Params: params, + } + return msg, t, nil +} + +func (ms *fsMessageStore) parseMessagesBefore(network *Network, entity string, ref time.Time, end time.Time, events bool, limit int, afterOffset int64) ([]*irc.Message, error) { + path := ms.logPath(network, entity, ref) + f, err := os.Open(path) + if err != nil { + if os.IsNotExist(err) { + return nil, nil + } + return nil, fmt.Errorf("failed to parse messages before ref: %v", err) + } + defer f.Close() + + historyRing := make([]*irc.Message, limit) + cur := 0 + + sc := bufio.NewScanner(f) + + if afterOffset >= 0 { + if _, err := f.Seek(afterOffset, io.SeekStart); err != nil { + return nil, nil + } + sc.Scan() // skip till next newline + } + + for sc.Scan() { + msg, t, err := ms.parseMessage(sc.Text(), network, entity, ref, events) + if err != nil { + return nil, err + } else if msg == nil || !t.After(end) { + continue + } else if !t.Before(ref) { + break + } + + historyRing[cur%limit] = msg + cur++ + } + if sc.Err() != nil { + return nil, fmt.Errorf("failed to parse messages before ref: scanner error: %v", sc.Err()) + } + + n := limit + if cur < limit { + n = cur + } + start := (cur - n + limit) % limit + + if start+n <= limit { // ring doesnt wrap + return historyRing[start : start+n], nil + } else { // ring wraps + history := make([]*irc.Message, n) + r := copy(history, historyRing[start:]) + copy(history[r:], historyRing[:n-r]) + return history, nil + } +} + +func (ms *fsMessageStore) parseMessagesAfter(network *Network, entity string, ref time.Time, end time.Time, events bool, limit int) ([]*irc.Message, error) { + path := ms.logPath(network, entity, ref) + f, err := os.Open(path) + if err != nil { + if os.IsNotExist(err) { + return nil, nil + } + return nil, fmt.Errorf("failed to parse messages after ref: %v", err) + } + defer f.Close() + + var history []*irc.Message + sc := bufio.NewScanner(f) + for sc.Scan() && len(history) < limit { + msg, t, err := ms.parseMessage(sc.Text(), network, entity, ref, events) + if err != nil { + return nil, err + } else if msg == nil || !t.After(ref) { + continue + } else if !t.Before(end) { + break + } + + history = append(history, msg) + } + if sc.Err() != nil { + return nil, fmt.Errorf("failed to parse messages after ref: scanner error: %v", sc.Err()) + } + + return history, nil +} + +func (ms *fsMessageStore) LoadBeforeTime(ctx context.Context, network *Network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) { + start = start.In(time.Local) + end = end.In(time.Local) + history := make([]*irc.Message, limit) + remaining := limit + tries := 0 + for remaining > 0 && tries < fsMessageStoreMaxTries && end.Before(start) { + buf, err := ms.parseMessagesBefore(network, entity, start, end, events, remaining, -1) + if err != nil { + return nil, err + } + if len(buf) == 0 { + tries++ + } else { + tries = 0 + } + copy(history[remaining-len(buf):], buf) + remaining -= len(buf) + year, month, day := start.Date() + start = time.Date(year, month, day, 0, 0, 0, 0, start.Location()).Add(-1) + + if err := ctx.Err(); err != nil { + return nil, err + } + } + + return history[remaining:], nil +} + +func (ms *fsMessageStore) LoadAfterTime(ctx context.Context, network *Network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) { + start = start.In(time.Local) + end = end.In(time.Local) + var history []*irc.Message + remaining := limit + tries := 0 + for remaining > 0 && tries < fsMessageStoreMaxTries && start.Before(end) { + buf, err := ms.parseMessagesAfter(network, entity, start, end, events, remaining) + if err != nil { + return nil, err + } + if len(buf) == 0 { + tries++ + } else { + tries = 0 + } + history = append(history, buf...) + remaining -= len(buf) + year, month, day := start.Date() + start = time.Date(year, month, day+1, 0, 0, 0, 0, start.Location()) + + if err := ctx.Err(); err != nil { + return nil, err + } + } + return history, nil +} + +func (ms *fsMessageStore) LoadLatestID(ctx context.Context, network *Network, entity, id string, limit int) ([]*irc.Message, error) { + var afterTime time.Time + var afterOffset int64 + if id != "" { + var idNet int64 + var idEntity string + var err error + idNet, idEntity, afterTime, afterOffset, err = parseFSMsgID(id) + if err != nil { + return nil, err + } + if idNet != network.ID || idEntity != entity { + return nil, fmt.Errorf("cannot find message ID: message ID doesn't match network/entity") + } + } + + history := make([]*irc.Message, limit) + t := time.Now() + remaining := limit + tries := 0 + for remaining > 0 && tries < fsMessageStoreMaxTries && !truncateDay(t).Before(afterTime) { + var offset int64 = -1 + if afterOffset >= 0 && truncateDay(t).Equal(afterTime) { + offset = afterOffset + } + + buf, err := ms.parseMessagesBefore(network, entity, t, time.Time{}, false, remaining, offset) + if err != nil { + return nil, err + } + if len(buf) == 0 { + tries++ + } else { + tries = 0 + } + copy(history[remaining-len(buf):], buf) + remaining -= len(buf) + year, month, day := t.Date() + t = time.Date(year, month, day, 0, 0, 0, 0, t.Location()).Add(-1) + + if err := ctx.Err(); err != nil { + return nil, err + } + } + + return history[remaining:], nil +} + +func (ms *fsMessageStore) ListTargets(ctx context.Context, network *Network, start, end time.Time, limit int, events bool) ([]chatHistoryTarget, error) { + start = start.In(time.Local) + end = end.In(time.Local) + rootPath := filepath.Join(ms.root, escapeFilename(network.GetName())) + root, err := os.Open(rootPath) + if os.IsNotExist(err) { + return nil, nil + } else if err != nil { + return nil, err + } + + // The returned targets are escaped, and there is no way to un-escape + // TODO: switch to ReadDir (Go 1.16+) + targetNames, err := root.Readdirnames(0) + root.Close() + if err != nil { + return nil, err + } + + var targets []chatHistoryTarget + for _, target := range targetNames { + // target is already escaped here + targetPath := filepath.Join(rootPath, target) + targetDir, err := os.Open(targetPath) + if err != nil { + return nil, err + } + + entries, err := targetDir.Readdir(0) + targetDir.Close() + if err != nil { + return nil, err + } + + // We use mtime here, which may give imprecise or incorrect results + var t time.Time + for _, entry := range entries { + if entry.ModTime().After(t) { + t = entry.ModTime() + } + } + + // The timestamps we get from logs have second granularity + t = truncateSecond(t) + + // Filter out targets that don't fullfil the time bounds + if !isTimeBetween(t, start, end) { + continue + } + + targets = append(targets, chatHistoryTarget{ + Name: target, + LatestMessage: t, + }) + + if err := ctx.Err(); err != nil { + return nil, err + } + } + + // Sort targets by latest message time, backwards or forwards depending on + // the order of the time bounds + sort.Slice(targets, func(i, j int) bool { + t1, t2 := targets[i].LatestMessage, targets[j].LatestMessage + if start.Before(end) { + return t1.Before(t2) + } else { + return !t1.Before(t2) + } + }) + + // Truncate the result if necessary + if len(targets) > limit { + targets = targets[:limit] + } + + return targets, nil +} + +func (ms *fsMessageStore) RenameNetwork(oldNet, newNet *Network) error { + oldDir := filepath.Join(ms.root, escapeFilename(oldNet.GetName())) + newDir := filepath.Join(ms.root, escapeFilename(newNet.GetName())) + // Avoid loosing data by overwriting an existing directory + if _, err := os.Stat(newDir); err == nil { + return fmt.Errorf("destination %q already exists", newDir) + } + return os.Rename(oldDir, newDir) +} + +func truncateDay(t time.Time) time.Time { + year, month, day := t.Date() + return time.Date(year, month, day, 0, 0, 0, 0, t.Location()) +} + +func truncateSecond(t time.Time) time.Time { + year, month, day := t.Date() + return time.Date(year, month, day, t.Hour(), t.Minute(), t.Second(), 0, t.Location()) +} + +func isTimeBetween(t, start, end time.Time) bool { + if end.Before(start) { + end, start = start, end + } + return start.Before(t) && t.Before(end) +} diff --git a/trunk/msgstore_memory.go b/trunk/msgstore_memory.go new file mode 100644 index 0000000..25ab953 --- /dev/null +++ b/trunk/msgstore_memory.go @@ -0,0 +1,160 @@ +package suika + +import ( + "context" + "fmt" + "time" + + "git.sr.ht/~sircmpwn/go-bare" + "gopkg.in/irc.v3" +) + +const messageRingBufferCap = 4096 + +type memoryMsgID struct { + Seq bare.Uint +} + +func (memoryMsgID) msgIDType() msgIDType { + return msgIDMemory +} + +func parseMemoryMsgID(s string) (netID int64, entity string, seq uint64, err error) { + var id memoryMsgID + netID, entity, err = parseMsgID(s, &id) + if err != nil { + return 0, "", 0, err + } + return netID, entity, uint64(id.Seq), nil +} + +func formatMemoryMsgID(netID int64, entity string, seq uint64) string { + id := memoryMsgID{bare.Uint(seq)} + return formatMsgID(netID, entity, &id) +} + +type ringBufferKey struct { + networkID int64 + entity string +} + +type memoryMessageStore struct { + buffers map[ringBufferKey]*messageRingBuffer +} + +var _ messageStore = (*memoryMessageStore)(nil) + +func newMemoryMessageStore() *memoryMessageStore { + return &memoryMessageStore{ + buffers: make(map[ringBufferKey]*messageRingBuffer), + } +} + +func (ms *memoryMessageStore) Close() error { + ms.buffers = nil + return nil +} + +func (ms *memoryMessageStore) get(network *Network, entity string) *messageRingBuffer { + k := ringBufferKey{networkID: network.ID, entity: entity} + if rb, ok := ms.buffers[k]; ok { + return rb + } + rb := newMessageRingBuffer(messageRingBufferCap) + ms.buffers[k] = rb + return rb +} + +func (ms *memoryMessageStore) LastMsgID(network *Network, entity string, t time.Time) (string, error) { + var seq uint64 + k := ringBufferKey{networkID: network.ID, entity: entity} + if rb, ok := ms.buffers[k]; ok { + seq = rb.cur + } + return formatMemoryMsgID(network.ID, entity, seq), nil +} + +func (ms *memoryMessageStore) Append(network *Network, entity string, msg *irc.Message) (string, error) { + switch msg.Command { + case "PRIVMSG", "NOTICE": + // Only append these messages, because LoadLatestID shouldn't return + // other kinds of message. + default: + return "", nil + } + + k := ringBufferKey{networkID: network.ID, entity: entity} + rb, ok := ms.buffers[k] + if !ok { + rb = newMessageRingBuffer(messageRingBufferCap) + ms.buffers[k] = rb + } + + seq := rb.Append(msg) + return formatMemoryMsgID(network.ID, entity, seq), nil +} + +func (ms *memoryMessageStore) LoadLatestID(ctx context.Context, network *Network, entity, id string, limit int) ([]*irc.Message, error) { + _, _, seq, err := parseMemoryMsgID(id) + if err != nil { + return nil, err + } + + k := ringBufferKey{networkID: network.ID, entity: entity} + rb, ok := ms.buffers[k] + if !ok { + return nil, nil + } + + return rb.LoadLatestSeq(seq, limit) +} + +type messageRingBuffer struct { + buf []*irc.Message + cur uint64 +} + +func newMessageRingBuffer(capacity int) *messageRingBuffer { + return &messageRingBuffer{ + buf: make([]*irc.Message, capacity), + cur: 1, + } +} + +func (rb *messageRingBuffer) cap() uint64 { + return uint64(len(rb.buf)) +} + +func (rb *messageRingBuffer) Append(msg *irc.Message) uint64 { + seq := rb.cur + i := int(seq % rb.cap()) + rb.buf[i] = msg + rb.cur++ + return seq +} + +func (rb *messageRingBuffer) LoadLatestSeq(seq uint64, limit int) ([]*irc.Message, error) { + if seq > rb.cur { + return nil, fmt.Errorf("loading messages from sequence number (%v) greater than current (%v)", seq, rb.cur) + } else if seq == rb.cur { + return nil, nil + } + + // The query excludes the message with the sequence number seq + diff := rb.cur - seq - 1 + if diff > rb.cap() { + // We dropped diff - cap entries + diff = rb.cap() + } + if int(diff) > limit { + diff = uint64(limit) + } + + l := make([]*irc.Message, int(diff)) + for i := 0; i < int(diff); i++ { + j := int((rb.cur - diff + uint64(i)) % rb.cap()) + l[i] = rb.buf[j] + } + + return l, nil +} diff --git a/trunk/net_go113.go b/trunk/net_go113.go new file mode 100644 index 0000000..24bde7f --- /dev/null +++ b/trunk/net_go113.go @@ -0,0 +1,12 @@ +//go:build !go1.16 +// +build !go1.16 + +package suika + +import ( + "strings" +) + +func isErrClosed(err error) bool { + return err != nil && strings.Contains(err.Error(), "use of closed network connection") +} diff --git a/trunk/net_go116.go b/trunk/net_go116.go new file mode 100644 index 0000000..ef1256a --- /dev/null +++ b/trunk/net_go116.go @@ -0,0 +1,13 @@ +//go:build go1.16 +// +build go1.16 + +package suika + +import ( + "errors" + "net" +) + +func isErrClosed(err error) bool { + return errors.Is(err, net.ErrClosed) +} diff --git a/trunk/rate.go b/trunk/rate.go new file mode 100644 index 0000000..a14bce2 --- /dev/null +++ b/trunk/rate.go @@ -0,0 +1,40 @@ +package suika + +import ( + "math/rand" + "time" +) + +// backoffer implements a simple exponential backoff. +type backoffer struct { + min, max, jitter time.Duration + n int64 +} + +func newBackoffer(min, max, jitter time.Duration) *backoffer { + return &backoffer{min: min, max: max, jitter: jitter} +} + +func (b *backoffer) Reset() { + b.n = 0 +} + +func (b *backoffer) Next() time.Duration { + if b.n == 0 { + b.n = 1 + return 0 + } + + d := time.Duration(b.n) * b.min + if d > b.max { + d = b.max + } else { + b.n *= 2 + } + + if b.jitter != 0 { + d += time.Duration(rand.Int63n(int64(b.jitter))) + } + + return d +} diff --git a/trunk/rc.d/freebsd-rc.d b/trunk/rc.d/freebsd-rc.d new file mode 100755 index 0000000..0fed078 --- /dev/null +++ b/trunk/rc.d/freebsd-rc.d @@ -0,0 +1,29 @@ +#!/bin/sh +# $TheSupernovaDuo$ +# vim: ft=sh + +# PROVIDE: suika +# REQUIRE: DAEMON +# BEFORE: LOGIN +# KEYWORD: shutdown + +. /etc/rc.subr + +name="suika" +desc="A drunk IRC bouncer" +rcvar="suika_enable" + +: ${suika_user="ircd"} + +command="%%PREFIX%%/bin/suika" +pidfile="/var/run/suika.pid" +required_files="%%PREFIX%%/etc/suika/config" + +start_cmd="suika_start" + +suika_start() { + /usr/sbin/daemon -f -p ${pidfile} -u ${suika_user} -l daemon ${command} --config ${required_files} +} + +load_rc_config "$name" +run_rc_command "$1" diff --git a/trunk/rc.d/immortal.yml b/trunk/rc.d/immortal.yml new file mode 100644 index 0000000..c7a8090 --- /dev/null +++ b/trunk/rc.d/immortal.yml @@ -0,0 +1,3 @@ +# $TheSupernovaDuo$ +cmd: %%PREFIX%%/bin/suika --config %%PREFIX%%/etc/suika/config +user: ircd diff --git a/trunk/rc.d/netbsd-rc.d b/trunk/rc.d/netbsd-rc.d new file mode 100644 index 0000000..3fef470 --- /dev/null +++ b/trunk/rc.d/netbsd-rc.d @@ -0,0 +1,28 @@ +#!/bin/sh +# $TheSupernovaDuo$ +# vim: ft=sh + +# PROVIDE: suika +# REQUIRE: DAEMON +# BEFORE: LOGIN +# KEYWORD: shutdown + +. /etc/rc.subr + +name="suika" +rcvar="${name}" +command="%%PREFIX/bin/${name}" +command_args="--config %%PREFIX%%/etc/suika/config" +pidfile="/var/run/${name}.pid" +start_cmd="${name}_start" + +suika_start() { + printf "Starting %s..." "${name}" + ${command} ${command_args} + pgrep -n ${name} > ${pidfile} +} + +load_rc_config ${name} +run_rc_command "$1" + + diff --git a/trunk/rc.d/openbsd-rc.d b/trunk/rc.d/openbsd-rc.d new file mode 100644 index 0000000..f0bf77c --- /dev/null +++ b/trunk/rc.d/openbsd-rc.d @@ -0,0 +1,12 @@ +#!/bin/ksh +# $TheSupernovaDuo$ +# vim: ft=sh + +daemon="%%PREFIX%%/bin/suika" +daemon_args="--config %%PREFIX%%/etc/suika/config" + +. /etc/rc.d/rc.subr + +rc_bg=YES + +rc_cmd "$1" diff --git a/trunk/rc.d/suika.service b/trunk/rc.d/suika.service new file mode 100644 index 0000000..d8d53b3 --- /dev/null +++ b/trunk/rc.d/suika.service @@ -0,0 +1,16 @@ +# $TheSupernovaDuo$ +# vim: ft=confini +[Unit] +Description=A drunk IRC bouncer +After=network.target +Wants=network.target +StartLimitBurst=5 +StartLimitIntervalSec=1 +[Service] +Type=simple +Restart=on-abnormal +RestartSec=1 +User=suika +ExecStart=%%PREFIX%%/bin/suika --config %%PREFIX%%/etc/suika/config +[Install] +WantedBy=multi-user.target diff --git a/trunk/server.go b/trunk/server.go new file mode 100644 index 0000000..05c55fd --- /dev/null +++ b/trunk/server.go @@ -0,0 +1,330 @@ +package suika + +import ( + "context" + "errors" + "fmt" + "io" + "log" + "net" + "runtime/debug" + "sync" + "sync/atomic" + "time" + + "gopkg.in/irc.v3" +) + +// TODO: make configurable +var ( + retryConnectMinDelay = time.Minute + retryConnectMaxDelay = 10 * time.Minute + retryConnectJitter = time.Minute + connectTimeout = 15 * time.Second + writeTimeout = 10 * time.Second + upstreamMessageDelay = 2 * time.Second + upstreamMessageBurst = 10 + backlogTimeout = 10 * time.Second + handleDownstreamMessageTimeout = 10 * time.Second + downstreamRegisterTimeout = 30 * time.Second + chatHistoryLimit = 1000 + backlogLimit = 4000 +) + +type Logger interface { + Printf(format string, v ...interface{}) + Debugf(format string, v ...interface{}) +} + +type logger struct { + *log.Logger + debug bool +} + +func (l logger) Debugf(format string, v ...interface{}) { + if !l.debug { + return + } + l.Logger.Printf(format, v...) +} + +func NewLogger(out io.Writer, debug bool) Logger { + return logger{ + Logger: log.New(log.Writer(), "", log.LstdFlags), + debug: debug, + } +} + +type prefixLogger struct { + logger Logger + prefix string +} + +var _ Logger = (*prefixLogger)(nil) + +func (l *prefixLogger) Printf(format string, v ...interface{}) { + v = append([]interface{}{l.prefix}, v...) + l.logger.Printf("%v"+format, v...) +} + +func (l *prefixLogger) Debugf(format string, v ...interface{}) { + v = append([]interface{}{l.prefix}, v...) + l.logger.Debugf("%v"+format, v...) +} + +type int64Gauge struct { + v int64 // atomic +} + +func (g *int64Gauge) Add(delta int64) { + atomic.AddInt64(&g.v, delta) +} + +func (g *int64Gauge) Value() int64 { + return atomic.LoadInt64(&g.v) +} + +func (g *int64Gauge) Float64() float64 { + return float64(g.Value()) +} + +type retryListener struct { + net.Listener + Logger Logger + + delay time.Duration +} + +func (ln *retryListener) Accept() (net.Conn, error) { + for { + conn, err := ln.Listener.Accept() + if ne, ok := err.(net.Error); ok && ne.Temporary() { + if ln.delay == 0 { + ln.delay = 5 * time.Millisecond + } else { + ln.delay *= 2 + } + if max := 1 * time.Second; ln.delay > max { + ln.delay = max + } + if ln.Logger != nil { + ln.Logger.Printf("accept error (retrying in %v): %v", ln.delay, err) + } + time.Sleep(ln.delay) + } else { + ln.delay = 0 + return conn, err + } + } +} + +type Config struct { + Hostname string + Title string + LogPath string + MaxUserNetworks int + MultiUpstream bool + MOTD string + UpstreamUserIPs []*net.IPNet +} + +type Server struct { + Logger Logger + + config atomic.Value // *Config + db Database + stopWG sync.WaitGroup + + lock sync.Mutex + listeners map[net.Listener]struct{} + users map[string]*user +} + +func NewServer(db Database) *Server { + srv := &Server{ + Logger: NewLogger(log.Writer(), true), + db: db, + listeners: make(map[net.Listener]struct{}), + users: make(map[string]*user), + } + srv.config.Store(&Config{ + Hostname: "localhost", + MaxUserNetworks: -1, + MultiUpstream: true, + }) + return srv +} + +func (s *Server) prefix() *irc.Prefix { + return &irc.Prefix{Name: s.Config().Hostname} +} + +func (s *Server) Config() *Config { + return s.config.Load().(*Config) +} + +func (s *Server) SetConfig(cfg *Config) { + s.config.Store(cfg) +} + +func (s *Server) Start() error { + users, err := s.db.ListUsers(context.TODO()) + if err != nil { + return err + } + + s.lock.Lock() + for i := range users { + s.addUserLocked(&users[i]) + } + s.lock.Unlock() + + return nil +} + +func (s *Server) Shutdown() { + s.lock.Lock() + for ln := range s.listeners { + if err := ln.Close(); err != nil { + s.Logger.Printf("failed to stop listener: %v", err) + } + } + for _, u := range s.users { + u.events <- eventStop{} + } + s.lock.Unlock() + + s.stopWG.Wait() + + if err := s.db.Close(); err != nil { + s.Logger.Printf("failed to close DB: %v", err) + } +} + +func (s *Server) createUser(ctx context.Context, user *User) (*user, error) { + s.lock.Lock() + defer s.lock.Unlock() + + if _, ok := s.users[user.Username]; ok { + return nil, fmt.Errorf("user %q already exists", user.Username) + } + + err := s.db.StoreUser(ctx, user) + if err != nil { + return nil, fmt.Errorf("could not create user in db: %v", err) + } + + return s.addUserLocked(user), nil +} + +func (s *Server) forEachUser(f func(*user)) { + s.lock.Lock() + for _, u := range s.users { + f(u) + } + s.lock.Unlock() +} + +func (s *Server) getUser(name string) *user { + s.lock.Lock() + u := s.users[name] + s.lock.Unlock() + return u +} + +func (s *Server) addUserLocked(user *User) *user { + s.Logger.Printf("starting bouncer for user %q", user.Username) + u := newUser(s, user) + s.users[u.Username] = u + + s.stopWG.Add(1) + + go func() { + defer func() { + if err := recover(); err != nil { + s.Logger.Printf("panic serving user %q: %v\n%v", user.Username, err, debug.Stack()) + } + + s.lock.Lock() + delete(s.users, u.Username) + s.lock.Unlock() + + s.stopWG.Done() + }() + + u.run() + }() + + return u +} + +var lastDownstreamID uint64 = 0 + +func (s *Server) handle(ic ircConn) { + defer func() { + if err := recover(); err != nil { + s.Logger.Printf("panic serving downstream %q: %v\n%v", ic.RemoteAddr(), err, debug.Stack()) + } + }() + + id := atomic.AddUint64(&lastDownstreamID, 1) + dc := newDownstreamConn(s, ic, id) + if err := dc.runUntilRegistered(); err != nil { + if !errors.Is(err, io.EOF) { + dc.logger.Printf("%v", err) + } + } else { + dc.user.events <- eventDownstreamConnected{dc} + if err := dc.readMessages(dc.user.events); err != nil { + dc.logger.Printf("%v", err) + } + dc.user.events <- eventDownstreamDisconnected{dc} + } + dc.Close() +} + +func (s *Server) Serve(ln net.Listener) error { + ln = &retryListener{ + Listener: ln, + Logger: &prefixLogger{logger: s.Logger, prefix: fmt.Sprintf("listener %v: ", ln.Addr())}, + } + + s.lock.Lock() + s.listeners[ln] = struct{}{} + s.lock.Unlock() + + s.stopWG.Add(1) + + defer func() { + s.lock.Lock() + delete(s.listeners, ln) + s.lock.Unlock() + + s.stopWG.Done() + }() + + for { + conn, err := ln.Accept() + if isErrClosed(err) { + return nil + } else if err != nil { + return fmt.Errorf("failed to accept connection: %v", err) + } + + go s.handle(newNetIRCConn(conn)) + } +} + +type ServerStats struct { + Users int + Downstreams int64 + Upstreams int64 +} + +func (s *Server) Stats() *ServerStats { + var stats ServerStats + s.lock.Lock() + stats.Users = len(s.users) + s.lock.Unlock() + return &stats +} diff --git a/trunk/server_test.go b/trunk/server_test.go new file mode 100644 index 0000000..cf0adca --- /dev/null +++ b/trunk/server_test.go @@ -0,0 +1,207 @@ +package suika + +import ( + "context" + "net" + "testing" + + "golang.org/x/crypto/bcrypt" + "gopkg.in/irc.v3" +) + +var testServerPrefix = &irc.Prefix{Name: "suika-test-server"} + +const ( + testUsername = "suika-test-user" + testPassword = testUsername +) + +func createTempSqliteDB(t *testing.T) Database { + db, err := OpenDB("sqlite3", ":memory:") + if err != nil { + t.Fatalf("failed to create temporary SQLite database: %v", err) + } + // :memory: will open a separate database for each new connection. Make + // sure the sql package only uses a single connection. An alternative + // solution is to use "file::memory:?cache=shared". + db.(*SqliteDB).db.SetMaxOpenConns(1) + return db +} + +func createTempPostgresDB(t *testing.T) Database { + db := &PostgresDB{db: openTempPostgresDB(t)} + if err := db.upgrade(); err != nil { + t.Fatalf("failed to upgrade PostgreSQL database: %v", err) + } + + return db +} + +func createTestUser(t *testing.T, db Database) *User { + hashed, err := bcrypt.GenerateFromPassword([]byte(testPassword), bcrypt.DefaultCost) + if err != nil { + t.Fatalf("failed to generate bcrypt hash: %v", err) + } + + record := &User{Username: testUsername, Password: string(hashed)} + if err := db.StoreUser(context.Background(), record); err != nil { + t.Fatalf("failed to store test user: %v", err) + } + + return record +} + +func createTestDownstream(t *testing.T, srv *Server) ircConn { + c1, c2 := net.Pipe() + go srv.handle(newNetIRCConn(c1)) + return newNetIRCConn(c2) +} + +func createTestUpstream(t *testing.T, db Database, user *User) (*Network, net.Listener) { + ln, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("failed to create TCP listener: %v", err) + } + + network := &Network{ + Name: "testnet", + Addr: "irc://" + ln.Addr().String(), + Nick: user.Username, + Enabled: true, + } + if err := db.StoreNetwork(context.Background(), user.ID, network); err != nil { + t.Fatalf("failed to store test network: %v", err) + } + + return network, ln +} + +func mustAccept(t *testing.T, ln net.Listener) ircConn { + c, err := ln.Accept() + if err != nil { + t.Fatalf("failed accepting connection: %v", err) + } + return newNetIRCConn(c) +} + +func expectMessage(t *testing.T, c ircConn, cmd string) *irc.Message { + msg, err := c.ReadMessage() + if err != nil { + t.Fatalf("failed to read IRC message (want %q): %v", cmd, err) + } + if msg.Command != cmd { + t.Fatalf("invalid message received: want %q, got: %v", cmd, msg) + } + return msg +} + +func registerDownstreamConn(t *testing.T, c ircConn, network *Network) { + c.WriteMessage(&irc.Message{ + Command: "PASS", + Params: []string{testPassword}, + }) + c.WriteMessage(&irc.Message{ + Command: "NICK", + Params: []string{testUsername}, + }) + c.WriteMessage(&irc.Message{ + Command: "USER", + Params: []string{testUsername + "/" + network.Name, "0", "*", testUsername}, + }) + + expectMessage(t, c, irc.RPL_WELCOME) +} + +func registerUpstreamConn(t *testing.T, c ircConn) { + msg := expectMessage(t, c, "CAP") + if msg.Params[0] != "LS" { + t.Fatalf("invalid CAP LS: got: %v", msg) + } + msg = expectMessage(t, c, "NICK") + nick := msg.Params[0] + if nick != testUsername { + t.Fatalf("invalid NICK: want %q, got: %v", testUsername, msg) + } + expectMessage(t, c, "USER") + + c.WriteMessage(&irc.Message{ + Prefix: testServerPrefix, + Command: irc.RPL_WELCOME, + Params: []string{nick, "Welcome!"}, + }) + c.WriteMessage(&irc.Message{ + Prefix: testServerPrefix, + Command: irc.RPL_YOURHOST, + Params: []string{nick, "Your host is suika-test-server"}, + }) + c.WriteMessage(&irc.Message{ + Prefix: testServerPrefix, + Command: irc.RPL_CREATED, + Params: []string{nick, "Who cares when the server was created?"}, + }) + c.WriteMessage(&irc.Message{ + Prefix: testServerPrefix, + Command: irc.RPL_MYINFO, + Params: []string{nick, testServerPrefix.Name, "suika", "aiwroO", "OovaimnqpsrtklbeI"}, + }) + c.WriteMessage(&irc.Message{ + Prefix: testServerPrefix, + Command: irc.ERR_NOMOTD, + Params: []string{nick, "No MOTD"}, + }) +} + +func testServer(t *testing.T, db Database) { + user := createTestUser(t, db) + network, upstream := createTestUpstream(t, db, user) + defer upstream.Close() + + srv := NewServer(db) + if err := srv.Start(); err != nil { + t.Fatalf("failed to start server: %v", err) + } + defer srv.Shutdown() + + uc := mustAccept(t, upstream) + defer uc.Close() + registerUpstreamConn(t, uc) + + dc := createTestDownstream(t, srv) + defer dc.Close() + registerDownstreamConn(t, dc, network) + + noticeText := "This is a very important server notice." + uc.WriteMessage(&irc.Message{ + Prefix: testServerPrefix, + Command: "NOTICE", + Params: []string{testUsername, noticeText}, + }) + + var msg *irc.Message + for { + var err error + msg, err = dc.ReadMessage() + if err != nil { + t.Fatalf("failed to read IRC message: %v", err) + } + if msg.Command == "NOTICE" { + break + } + } + + if msg.Params[1] != noticeText { + t.Fatalf("invalid NOTICE text: want %q, got: %v", noticeText, msg) + } +} + +func TestServer(t *testing.T) { + t.Run("sqlite", func(t *testing.T) { + db := createTempSqliteDB(t) + testServer(t, db) + }) + + t.Run("postgres", func(t *testing.T) { + db := createTempPostgresDB(t) + testServer(t, db) + }) +} diff --git a/trunk/service.go b/trunk/service.go new file mode 100644 index 0000000..07226ef --- /dev/null +++ b/trunk/service.go @@ -0,0 +1,1150 @@ +package suika + +import ( + "context" + "crypto/sha1" + "crypto/sha256" + "crypto/sha512" + "encoding/hex" + "flag" + "fmt" + "io/ioutil" + "sort" + "strconv" + "strings" + "time" + "unicode" + + "golang.org/x/crypto/bcrypt" + "gopkg.in/irc.v3" +) + +const ( + serviceNick = "BouncerServ" + serviceNickCM = "bouncerserv" + serviceRealname = "suika bouncer service" +) + +// maxRSABits is the maximum number of RSA key bits used when generating a new +// private key. +const maxRSABits = 8192 + +var servicePrefix = &irc.Prefix{ + Name: serviceNick, + User: serviceNick, + Host: serviceNick, +} + +type serviceCommandSet map[string]*serviceCommand + +type serviceCommand struct { + usage string + desc string + handle func(ctx context.Context, dc *downstreamConn, params []string) error + children serviceCommandSet + admin bool +} + +func sendServiceNOTICE(dc *downstreamConn, text string) { + dc.SendMessage(&irc.Message{ + Prefix: servicePrefix, + Command: "NOTICE", + Params: []string{dc.nick, text}, + }) +} + +func sendServicePRIVMSG(dc *downstreamConn, text string) { + dc.SendMessage(&irc.Message{ + Prefix: servicePrefix, + Command: "PRIVMSG", + Params: []string{dc.nick, text}, + }) +} + +func splitWords(s string) ([]string, error) { + var words []string + var lastWord strings.Builder + escape := false + prev := ' ' + wordDelim := ' ' + + for _, r := range s { + if escape { + // last char was a backslash, write the byte as-is. + lastWord.WriteRune(r) + escape = false + } else if r == '\\' { + escape = true + } else if wordDelim == ' ' && unicode.IsSpace(r) { + // end of last word + if !unicode.IsSpace(prev) { + words = append(words, lastWord.String()) + lastWord.Reset() + } + } else if r == wordDelim { + // wordDelim is either " or ', switch back to + // space-delimited words. + wordDelim = ' ' + } else if r == '"' || r == '\'' { + if wordDelim == ' ' { + // start of (double-)quoted word + wordDelim = r + } else { + // either wordDelim is " and r is ' or vice-versa + lastWord.WriteRune(r) + } + } else { + lastWord.WriteRune(r) + } + + prev = r + } + + if !unicode.IsSpace(prev) { + words = append(words, lastWord.String()) + } + + if wordDelim != ' ' { + return nil, fmt.Errorf("unterminated quoted string") + } + if escape { + return nil, fmt.Errorf("unterminated backslash sequence") + } + + return words, nil +} + +func handleServicePRIVMSG(ctx context.Context, dc *downstreamConn, text string) { + words, err := splitWords(text) + if err != nil { + sendServicePRIVMSG(dc, fmt.Sprintf(`error: failed to parse command: %v`, err)) + return + } + + cmd, params, err := serviceCommands.Get(words) + if err != nil { + sendServicePRIVMSG(dc, fmt.Sprintf(`error: %v (type "help" for a list of commands)`, err)) + return + } + if cmd.admin && !dc.user.Admin { + sendServicePRIVMSG(dc, "error: you must be an admin to use this command") + return + } + + if cmd.handle == nil { + if len(cmd.children) > 0 { + var l []string + appendServiceCommandSetHelp(cmd.children, words, dc.user.Admin, &l) + sendServicePRIVMSG(dc, "available commands: "+strings.Join(l, ", ")) + } else { + // Pretend the command does not exist if it has neither children nor handler. + // This is obviously a bug but it is better to not die anyway. + dc.logger.Printf("command without handler and subcommands invoked:", words[0]) + sendServicePRIVMSG(dc, fmt.Sprintf("command %q not found", words[0])) + } + return + } + + if err := cmd.handle(ctx, dc, params); err != nil { + sendServicePRIVMSG(dc, fmt.Sprintf("error: %v", err)) + } +} + +func (cmds serviceCommandSet) Get(params []string) (*serviceCommand, []string, error) { + if len(params) == 0 { + return nil, nil, fmt.Errorf("no command specified") + } + + name := params[0] + params = params[1:] + + cmd, ok := cmds[name] + if !ok { + for k := range cmds { + if !strings.HasPrefix(k, name) { + continue + } + if cmd != nil { + return nil, params, fmt.Errorf("command %q is ambiguous", name) + } + cmd = cmds[k] + } + } + if cmd == nil { + return nil, params, fmt.Errorf("command %q not found", name) + } + + if len(params) == 0 || len(cmd.children) == 0 { + return cmd, params, nil + } + return cmd.children.Get(params) +} + +func (cmds serviceCommandSet) Names() []string { + l := make([]string, 0, len(cmds)) + for name := range cmds { + l = append(l, name) + } + sort.Strings(l) + return l +} + +var serviceCommands serviceCommandSet + +func init() { + serviceCommands = serviceCommandSet{ + "help": { + usage: "[command]", + desc: "print help message", + handle: handleServiceHelp, + }, + "network": { + children: serviceCommandSet{ + "create": { + usage: "-addr [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick] [-enabled enabled] [-connect-command command]...", + desc: "add a new network", + handle: handleServiceNetworkCreate, + }, + "status": { + desc: "show a list of saved networks and their current status", + handle: handleServiceNetworkStatus, + }, + "update": { + usage: "[name] [-addr addr] [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick] [-enabled enabled] [-connect-command command]...", + desc: "update a network", + handle: handleServiceNetworkUpdate, + }, + "delete": { + usage: "[name]", + desc: "delete a network", + handle: handleServiceNetworkDelete, + }, + "quote": { + usage: "[name] ", + desc: "send a raw line to a network", + handle: handleServiceNetworkQuote, + }, + }, + }, + "certfp": { + children: serviceCommandSet{ + "generate": { + usage: "[-key-type rsa|ecdsa|ed25519] [-bits N] [-network name]", + desc: "generate a new self-signed certificate, defaults to using RSA-3072 key", + handle: handleServiceCertFPGenerate, + }, + "fingerprint": { + usage: "[-network name]", + desc: "show fingerprints of certificate", + handle: handleServiceCertFPFingerprints, + }, + }, + }, + "sasl": { + children: serviceCommandSet{ + "status": { + usage: "[-network name]", + desc: "show SASL status", + handle: handleServiceSASLStatus, + }, + "set-plain": { + usage: "[-network name] ", + desc: "set SASL PLAIN credentials", + handle: handleServiceSASLSetPlain, + }, + "reset": { + usage: "[-network name]", + desc: "disable SASL authentication and remove stored credentials", + handle: handleServiceSASLReset, + }, + }, + }, + "user": { + children: serviceCommandSet{ + "create": { + usage: "-username -password [-realname ] [-admin]", + desc: "create a new suika user", + handle: handleUserCreate, + admin: true, + }, + "update": { + usage: "[-password ] [-realname ]", + desc: "update the current user", + handle: handleUserUpdate, + }, + "delete": { + usage: "", + desc: "delete a user", + handle: handleUserDelete, + admin: true, + }, + }, + }, + "channel": { + children: serviceCommandSet{ + "status": { + usage: "[-network name]", + desc: "show a list of saved channels and their current status", + handle: handleServiceChannelStatus, + }, + "update": { + usage: " [-relay-detached ] [-reattach-on ] [-detach-after ] [-detach-on ]", + desc: "update a channel", + handle: handleServiceChannelUpdate, + }, + }, + }, + "server": { + children: serviceCommandSet{ + "status": { + desc: "show server statistics", + handle: handleServiceServerStatus, + admin: true, + }, + "notice": { + desc: "broadcast a notice to all connected bouncer users", + handle: handleServiceServerNotice, + admin: true, + }, + }, + admin: true, + }, + } +} + +func appendServiceCommandSetHelp(cmds serviceCommandSet, prefix []string, admin bool, l *[]string) { + for _, name := range cmds.Names() { + cmd := cmds[name] + if cmd.admin && !admin { + continue + } + words := append(prefix, name) + if len(cmd.children) == 0 { + s := strings.Join(words, " ") + *l = append(*l, s) + } else { + appendServiceCommandSetHelp(cmd.children, words, admin, l) + } + } +} + +func handleServiceHelp(ctx context.Context, dc *downstreamConn, params []string) error { + if len(params) > 0 { + cmd, rest, err := serviceCommands.Get(params) + if err != nil { + return err + } + words := params[:len(params)-len(rest)] + + if len(cmd.children) > 0 { + var l []string + appendServiceCommandSetHelp(cmd.children, words, dc.user.Admin, &l) + sendServicePRIVMSG(dc, "available commands: "+strings.Join(l, ", ")) + } else { + text := strings.Join(words, " ") + if cmd.usage != "" { + text += " " + cmd.usage + } + text += ": " + cmd.desc + + sendServicePRIVMSG(dc, text) + } + } else { + var l []string + appendServiceCommandSetHelp(serviceCommands, nil, dc.user.Admin, &l) + sendServicePRIVMSG(dc, "available commands: "+strings.Join(l, ", ")) + } + return nil +} + +func newFlagSet() *flag.FlagSet { + fs := flag.NewFlagSet("", flag.ContinueOnError) + fs.SetOutput(ioutil.Discard) + return fs +} + +type stringSliceFlag []string + +func (v *stringSliceFlag) String() string { + return fmt.Sprint([]string(*v)) +} + +func (v *stringSliceFlag) Set(s string) error { + *v = append(*v, s) + return nil +} + +// stringPtrFlag is a flag value populating a string pointer. This allows to +// disambiguate between a flag that hasn't been set and a flag that has been +// set to an empty string. +type stringPtrFlag struct { + ptr **string +} + +func (f stringPtrFlag) String() string { + if f.ptr == nil || *f.ptr == nil { + return "" + } + return **f.ptr +} + +func (f stringPtrFlag) Set(s string) error { + *f.ptr = &s + return nil +} + +type boolPtrFlag struct { + ptr **bool +} + +func (f boolPtrFlag) String() string { + if f.ptr == nil || *f.ptr == nil { + return "" + } + return strconv.FormatBool(**f.ptr) +} + +func (f boolPtrFlag) Set(s string) error { + v, err := strconv.ParseBool(s) + if err != nil { + return err + } + *f.ptr = &v + return nil +} + +func getNetworkFromArg(dc *downstreamConn, params []string) (*network, []string, error) { + name, params := popArg(params) + if name == "" { + if dc.network == nil { + return nil, params, fmt.Errorf("no network selected, a name argument is required") + } + return dc.network, params, nil + } else { + net := dc.user.getNetwork(name) + if net == nil { + return nil, params, fmt.Errorf("unknown network %q", name) + } + return net, params, nil + } +} + +type networkFlagSet struct { + *flag.FlagSet + Addr, Name, Nick, Username, Pass, Realname *string + Enabled *bool + ConnectCommands []string +} + +func newNetworkFlagSet() *networkFlagSet { + fs := &networkFlagSet{FlagSet: newFlagSet()} + fs.Var(stringPtrFlag{&fs.Addr}, "addr", "") + fs.Var(stringPtrFlag{&fs.Name}, "name", "") + fs.Var(stringPtrFlag{&fs.Nick}, "nick", "") + fs.Var(stringPtrFlag{&fs.Username}, "username", "") + fs.Var(stringPtrFlag{&fs.Pass}, "pass", "") + fs.Var(stringPtrFlag{&fs.Realname}, "realname", "") + fs.Var(boolPtrFlag{&fs.Enabled}, "enabled", "") + fs.Var((*stringSliceFlag)(&fs.ConnectCommands), "connect-command", "") + return fs +} + +func (fs *networkFlagSet) update(network *Network) error { + if fs.Addr != nil { + if addrParts := strings.SplitN(*fs.Addr, "://", 2); len(addrParts) == 2 { + scheme := addrParts[0] + switch scheme { + case "ircs", "irc", "unix": + default: + return fmt.Errorf("unknown scheme %q (supported schemes: ircs, irc, unix)", scheme) + } + } + network.Addr = *fs.Addr + } + if fs.Name != nil { + network.Name = *fs.Name + } + if fs.Nick != nil { + network.Nick = *fs.Nick + } + if fs.Username != nil { + network.Username = *fs.Username + } + if fs.Pass != nil { + network.Pass = *fs.Pass + } + if fs.Realname != nil { + network.Realname = *fs.Realname + } + if fs.Enabled != nil { + network.Enabled = *fs.Enabled + } + if fs.ConnectCommands != nil { + if len(fs.ConnectCommands) == 1 && fs.ConnectCommands[0] == "" { + network.ConnectCommands = nil + } else { + for _, command := range fs.ConnectCommands { + _, err := irc.ParseMessage(command) + if err != nil { + return fmt.Errorf("flag -connect-command must be a valid raw irc command string: %q: %v", command, err) + } + } + network.ConnectCommands = fs.ConnectCommands + } + } + return nil +} + +func handleServiceNetworkCreate(ctx context.Context, dc *downstreamConn, params []string) error { + fs := newNetworkFlagSet() + if err := fs.Parse(params); err != nil { + return err + } + if fs.Addr == nil { + return fmt.Errorf("flag -addr is required") + } + + record := &Network{ + Addr: *fs.Addr, + Enabled: true, + } + if err := fs.update(record); err != nil { + return err + } + + network, err := dc.user.createNetwork(ctx, record) + if err != nil { + return fmt.Errorf("could not create network: %v", err) + } + + sendServicePRIVMSG(dc, fmt.Sprintf("created network %q", network.GetName())) + return nil +} + +func handleServiceNetworkStatus(ctx context.Context, dc *downstreamConn, params []string) error { + n := 0 + for _, net := range dc.user.networks { + var statuses []string + var details string + if uc := net.conn; uc != nil { + if dc.nick != uc.nick { + statuses = append(statuses, "connected as "+uc.nick) + } else { + statuses = append(statuses, "connected") + } + details = fmt.Sprintf("%v channels", uc.channels.Len()) + } else if !net.Enabled { + statuses = append(statuses, "disabled") + } else { + statuses = append(statuses, "disconnected") + if net.lastError != nil { + details = net.lastError.Error() + } + } + + if net == dc.network { + statuses = append(statuses, "current") + } + + name := net.GetName() + if name != net.Addr { + name = fmt.Sprintf("%v (%v)", name, net.Addr) + } + + s := fmt.Sprintf("%v [%v]", name, strings.Join(statuses, ", ")) + if details != "" { + s += ": " + details + } + sendServicePRIVMSG(dc, s) + + n++ + } + + if n == 0 { + sendServicePRIVMSG(dc, `No network configured, add one with "network create".`) + } + + return nil +} + +func handleServiceNetworkUpdate(ctx context.Context, dc *downstreamConn, params []string) error { + net, params, err := getNetworkFromArg(dc, params) + if err != nil { + return err + } + + fs := newNetworkFlagSet() + if err := fs.Parse(params); err != nil { + return err + } + + record := net.Network // copy network record because we'll mutate it + if err := fs.update(&record); err != nil { + return err + } + + network, err := dc.user.updateNetwork(ctx, &record) + if err != nil { + return fmt.Errorf("could not update network: %v", err) + } + + sendServicePRIVMSG(dc, fmt.Sprintf("updated network %q", network.GetName())) + return nil +} + +func handleServiceNetworkDelete(ctx context.Context, dc *downstreamConn, params []string) error { + net, params, err := getNetworkFromArg(dc, params) + if err != nil { + return err + } + + if err := dc.user.deleteNetwork(ctx, net.ID); err != nil { + return err + } + + sendServicePRIVMSG(dc, fmt.Sprintf("deleted network %q", net.GetName())) + return nil +} + +func handleServiceNetworkQuote(ctx context.Context, dc *downstreamConn, params []string) error { + if len(params) != 1 && len(params) != 2 { + return fmt.Errorf("expected one or two arguments") + } + + raw := params[len(params)-1] + params = params[:len(params)-1] + + net, params, err := getNetworkFromArg(dc, params) + if err != nil { + return err + } + + uc := net.conn + if uc == nil { + return fmt.Errorf("network %q is not currently connected", net.GetName()) + } + + m, err := irc.ParseMessage(raw) + if err != nil { + return fmt.Errorf("failed to parse command %q: %v", raw, err) + } + uc.SendMessage(ctx, m) + + sendServicePRIVMSG(dc, fmt.Sprintf("sent command to %q", net.GetName())) + return nil +} + +func sendCertfpFingerprints(dc *downstreamConn, cert []byte) { + sha1Sum := sha1.Sum(cert) + sendServicePRIVMSG(dc, "SHA-1 fingerprint: "+hex.EncodeToString(sha1Sum[:])) + sha256Sum := sha256.Sum256(cert) + sendServicePRIVMSG(dc, "SHA-256 fingerprint: "+hex.EncodeToString(sha256Sum[:])) + sha512Sum := sha512.Sum512(cert) + sendServicePRIVMSG(dc, "SHA-512 fingerprint: "+hex.EncodeToString(sha512Sum[:])) +} + +func getNetworkFromFlag(dc *downstreamConn, name string) (*network, error) { + if name == "" { + if dc.network == nil { + return nil, fmt.Errorf("no network selected, -network is required") + } + return dc.network, nil + } else { + net := dc.user.getNetwork(name) + if net == nil { + return nil, fmt.Errorf("unknown network %q", name) + } + return net, nil + } +} + +func handleServiceCertFPGenerate(ctx context.Context, dc *downstreamConn, params []string) error { + fs := newFlagSet() + netName := fs.String("network", "", "select a network") + keyType := fs.String("key-type", "rsa", "key type to generate (rsa, ecdsa, ed25519)") + bits := fs.Int("bits", 3072, "size of key to generate, meaningful only for RSA") + + if err := fs.Parse(params); err != nil { + return err + } + + if *bits <= 0 || *bits > maxRSABits { + return fmt.Errorf("invalid value for -bits") + } + + net, err := getNetworkFromFlag(dc, *netName) + if err != nil { + return err + } + + privKey, cert, err := generateCertFP(*keyType, *bits) + if err != nil { + return err + } + + net.SASL.External.CertBlob = cert + net.SASL.External.PrivKeyBlob = privKey + net.SASL.Mechanism = "EXTERNAL" + + if err := dc.srv.db.StoreNetwork(ctx, dc.user.ID, &net.Network); err != nil { + return err + } + + sendServicePRIVMSG(dc, "certificate generated") + sendCertfpFingerprints(dc, cert) + return nil +} + +func handleServiceCertFPFingerprints(ctx context.Context, dc *downstreamConn, params []string) error { + fs := newFlagSet() + netName := fs.String("network", "", "select a network") + + if err := fs.Parse(params); err != nil { + return err + } + + net, err := getNetworkFromFlag(dc, *netName) + if err != nil { + return err + } + + if net.SASL.Mechanism != "EXTERNAL" { + return fmt.Errorf("CertFP not set up") + } + + sendCertfpFingerprints(dc, net.SASL.External.CertBlob) + return nil +} + +func handleServiceSASLStatus(ctx context.Context, dc *downstreamConn, params []string) error { + fs := newFlagSet() + netName := fs.String("network", "", "select a network") + + if err := fs.Parse(params); err != nil { + return err + } + + net, err := getNetworkFromFlag(dc, *netName) + if err != nil { + return err + } + + switch net.SASL.Mechanism { + case "PLAIN": + sendServicePRIVMSG(dc, fmt.Sprintf("SASL PLAIN enabled with username %q", net.SASL.Plain.Username)) + case "EXTERNAL": + sendServicePRIVMSG(dc, "SASL EXTERNAL (CertFP) enabled") + case "": + sendServicePRIVMSG(dc, "SASL is disabled") + } + + if uc := net.conn; uc != nil { + if uc.account != "" { + sendServicePRIVMSG(dc, fmt.Sprintf("Authenticated on upstream network with account %q", uc.account)) + } else { + sendServicePRIVMSG(dc, "Unauthenticated on upstream network") + } + } else { + sendServicePRIVMSG(dc, "Disconnected from upstream network") + } + + return nil +} + +func handleServiceSASLSetPlain(ctx context.Context, dc *downstreamConn, params []string) error { + fs := newFlagSet() + netName := fs.String("network", "", "select a network") + + if err := fs.Parse(params); err != nil { + return err + } + + if len(fs.Args()) != 2 { + return fmt.Errorf("expected exactly 2 arguments") + } + + net, err := getNetworkFromFlag(dc, *netName) + if err != nil { + return err + } + + net.SASL.Plain.Username = fs.Arg(0) + net.SASL.Plain.Password = fs.Arg(1) + net.SASL.Mechanism = "PLAIN" + + if err := dc.srv.db.StoreNetwork(ctx, dc.user.ID, &net.Network); err != nil { + return err + } + + sendServicePRIVMSG(dc, "credentials saved") + return nil +} + +func handleServiceSASLReset(ctx context.Context, dc *downstreamConn, params []string) error { + fs := newFlagSet() + netName := fs.String("network", "", "select a network") + + if err := fs.Parse(params); err != nil { + return err + } + + net, err := getNetworkFromFlag(dc, *netName) + if err != nil { + return err + } + + net.SASL.Plain.Username = "" + net.SASL.Plain.Password = "" + net.SASL.External.CertBlob = nil + net.SASL.External.PrivKeyBlob = nil + net.SASL.Mechanism = "" + + if err := dc.srv.db.StoreNetwork(ctx, dc.user.ID, &net.Network); err != nil { + return err + } + + sendServicePRIVMSG(dc, "credentials reset") + return nil +} + +func handleUserCreate(ctx context.Context, dc *downstreamConn, params []string) error { + fs := newFlagSet() + username := fs.String("username", "", "") + password := fs.String("password", "", "") + realname := fs.String("realname", "", "") + admin := fs.Bool("admin", false, "") + + if err := fs.Parse(params); err != nil { + return err + } + if *username == "" { + return fmt.Errorf("flag -username is required") + } + if *password == "" { + return fmt.Errorf("flag -password is required") + } + + hashed, err := bcrypt.GenerateFromPassword([]byte(*password), bcrypt.DefaultCost) + if err != nil { + return fmt.Errorf("failed to hash password: %v", err) + } + + user := &User{ + Username: *username, + Password: string(hashed), + Realname: *realname, + Admin: *admin, + } + if _, err := dc.srv.createUser(ctx, user); err != nil { + return fmt.Errorf("could not create user: %v", err) + } + + sendServicePRIVMSG(dc, fmt.Sprintf("created user %q", *username)) + return nil +} + +func popArg(params []string) (string, []string) { + if len(params) > 0 && !strings.HasPrefix(params[0], "-") { + return params[0], params[1:] + } + return "", params +} + +func handleUserUpdate(ctx context.Context, dc *downstreamConn, params []string) error { + var password, realname *string + var admin *bool + fs := newFlagSet() + fs.Var(stringPtrFlag{&password}, "password", "") + fs.Var(stringPtrFlag{&realname}, "realname", "") + fs.Var(boolPtrFlag{&admin}, "admin", "") + + username, params := popArg(params) + if err := fs.Parse(params); err != nil { + return err + } + if len(fs.Args()) > 0 { + return fmt.Errorf("unexpected argument") + } + + var hashed *string + if password != nil { + hashedBytes, err := bcrypt.GenerateFromPassword([]byte(*password), bcrypt.DefaultCost) + if err != nil { + return fmt.Errorf("failed to hash password: %v", err) + } + hashedStr := string(hashedBytes) + hashed = &hashedStr + } + + if username != "" && username != dc.user.Username { + if !dc.user.Admin { + return fmt.Errorf("you must be an admin to update other users") + } + if realname != nil { + return fmt.Errorf("cannot update -realname of other user") + } + + u := dc.srv.getUser(username) + if u == nil { + return fmt.Errorf("unknown username %q", username) + } + + done := make(chan error, 1) + event := eventUserUpdate{ + password: hashed, + admin: admin, + done: done, + } + select { + case <-ctx.Done(): + return ctx.Err() + case u.events <- event: + } + // TODO: send context to the other side + if err := <-done; err != nil { + return err + } + + sendServicePRIVMSG(dc, fmt.Sprintf("updated user %q", username)) + } else { + // copy the user record because we'll mutate it + record := dc.user.User + + if hashed != nil { + record.Password = *hashed + } + if realname != nil { + record.Realname = *realname + } + if admin != nil { + return fmt.Errorf("cannot update -admin of own user") + } + + if err := dc.user.updateUser(ctx, &record); err != nil { + return err + } + + sendServicePRIVMSG(dc, fmt.Sprintf("updated user %q", dc.user.Username)) + } + + return nil +} + +func handleUserDelete(ctx context.Context, dc *downstreamConn, params []string) error { + if len(params) != 1 { + return fmt.Errorf("expected exactly one argument") + } + username := params[0] + + u := dc.srv.getUser(username) + if u == nil { + return fmt.Errorf("unknown username %q", username) + } + + u.stop() + + if err := dc.srv.db.DeleteUser(ctx, u.ID); err != nil { + return fmt.Errorf("failed to delete user: %v", err) + } + + sendServicePRIVMSG(dc, fmt.Sprintf("deleted user %q", username)) + return nil +} + +func handleServiceChannelStatus(ctx context.Context, dc *downstreamConn, params []string) error { + var defaultNetworkName string + if dc.network != nil { + defaultNetworkName = dc.network.GetName() + } + + fs := newFlagSet() + networkName := fs.String("network", defaultNetworkName, "") + + if err := fs.Parse(params); err != nil { + return err + } + + n := 0 + + sendNetwork := func(net *network) { + var channels []*Channel + for _, entry := range net.channels.innerMap { + channels = append(channels, entry.value.(*Channel)) + } + + sort.Slice(channels, func(i, j int) bool { + return strings.ReplaceAll(channels[i].Name, "#", "") < + strings.ReplaceAll(channels[j].Name, "#", "") + }) + + for _, ch := range channels { + var uch *upstreamChannel + if net.conn != nil { + uch = net.conn.channels.Value(ch.Name) + } + + name := ch.Name + if *networkName == "" { + name += "/" + net.GetName() + } + + var status string + if uch != nil { + status = "joined" + } else if net.conn != nil { + status = "parted" + } else { + status = "disconnected" + } + + if ch.Detached { + status += ", detached" + } + + s := fmt.Sprintf("%v [%v]", name, status) + sendServicePRIVMSG(dc, s) + + n++ + } + } + + if *networkName == "" { + for _, net := range dc.user.networks { + sendNetwork(net) + } + } else { + net := dc.user.getNetwork(*networkName) + if net == nil { + return fmt.Errorf("unknown network %q", *networkName) + } + sendNetwork(net) + } + + if n == 0 { + sendServicePRIVMSG(dc, "No channel configured.") + } + + return nil +} + +type channelFlagSet struct { + *flag.FlagSet + RelayDetached, ReattachOn, DetachAfter, DetachOn *string +} + +func newChannelFlagSet() *channelFlagSet { + fs := &channelFlagSet{FlagSet: newFlagSet()} + fs.Var(stringPtrFlag{&fs.RelayDetached}, "relay-detached", "") + fs.Var(stringPtrFlag{&fs.ReattachOn}, "reattach-on", "") + fs.Var(stringPtrFlag{&fs.DetachAfter}, "detach-after", "") + fs.Var(stringPtrFlag{&fs.DetachOn}, "detach-on", "") + return fs +} + +func (fs *channelFlagSet) update(channel *Channel) error { + if fs.RelayDetached != nil { + filter, err := parseFilter(*fs.RelayDetached) + if err != nil { + return err + } + channel.RelayDetached = filter + } + if fs.ReattachOn != nil { + filter, err := parseFilter(*fs.ReattachOn) + if err != nil { + return err + } + channel.ReattachOn = filter + } + if fs.DetachAfter != nil { + dur, err := time.ParseDuration(*fs.DetachAfter) + if err != nil || dur < 0 { + return fmt.Errorf("unknown duration for -detach-after %q (duration format: 0, 300s, 22h30m, ...)", *fs.DetachAfter) + } + channel.DetachAfter = dur + } + if fs.DetachOn != nil { + filter, err := parseFilter(*fs.DetachOn) + if err != nil { + return err + } + channel.DetachOn = filter + } + return nil +} + +func handleServiceChannelUpdate(ctx context.Context, dc *downstreamConn, params []string) error { + if len(params) < 1 { + return fmt.Errorf("expected at least one argument") + } + name := params[0] + + fs := newChannelFlagSet() + if err := fs.Parse(params[1:]); err != nil { + return err + } + + uc, upstreamName, err := dc.unmarshalEntity(name) + if err != nil { + return fmt.Errorf("unknown channel %q", name) + } + + ch := uc.network.channels.Value(upstreamName) + if ch == nil { + return fmt.Errorf("unknown channel %q", name) + } + + if err := fs.update(ch); err != nil { + return err + } + + uc.updateChannelAutoDetach(upstreamName) + + if err := dc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil { + return fmt.Errorf("failed to update channel: %v", err) + } + + sendServicePRIVMSG(dc, fmt.Sprintf("updated channel %q", name)) + return nil +} +func handleServiceServerStatus(ctx context.Context, dc *downstreamConn, params []string) error { + dbStats, err := dc.user.srv.db.Stats(ctx) + if err != nil { + return err + } + serverStats := dc.user.srv.Stats() + sendServicePRIVMSG(dc, fmt.Sprintf("%v/%v users, %v downstreams, %v upstreams, %v networks, %v channels", serverStats.Users, dbStats.Users, serverStats.Downstreams, serverStats.Upstreams, dbStats.Networks, dbStats.Channels)) + return nil +} + +func handleServiceServerNotice(ctx context.Context, dc *downstreamConn, params []string) error { + if len(params) != 1 { + return fmt.Errorf("expected exactly one argument") + } + text := params[0] + + dc.logger.Printf("broadcasting bouncer-wide NOTICE: %v", text) + + broadcastMsg := &irc.Message{ + Prefix: servicePrefix, + Command: "NOTICE", + Params: []string{"$" + dc.srv.Config().Hostname, text}, + } + var err error + sent := 0 + total := 0 + dc.srv.forEachUser(func(u *user) { + total++ + select { + case <-ctx.Done(): + err = ctx.Err() + case u.events <- eventBroadcast{broadcastMsg}: + sent++ + } + }) + + dc.logger.Printf("broadcast bouncer-wide NOTICE to %v/%v downstreams", sent, total) + sendServicePRIVMSG(dc, fmt.Sprintf("sent to %v/%v downstream connections", sent, total)) + + return err +} diff --git a/trunk/service_test.go b/trunk/service_test.go new file mode 100644 index 0000000..1f3fe6b --- /dev/null +++ b/trunk/service_test.go @@ -0,0 +1,54 @@ +package suika + +import ( + "testing" +) + +func assertSplit(t *testing.T, input string, expected []string) { + actual, err := splitWords(input) + if err != nil { + t.Errorf("%q: %v", input, err) + return + } + if len(actual) != len(expected) { + t.Errorf("%q: expected %d words, got %d\nexpected: %v\ngot: %v", input, len(expected), len(actual), expected, actual) + return + } + for i := 0; i < len(actual); i++ { + if actual[i] != expected[i] { + t.Errorf("%q: expected word #%d to be %q, got %q\nexpected: %v\ngot: %v", input, i, expected[i], actual[i], expected, actual) + } + } +} + +func TestSplit(t *testing.T) { + assertSplit(t, " ch 'up' #suika 'relay'-det\"ache\"d message ", []string{ + "ch", + "up", + "#suika", + "relay-detached", + "message", + }) + assertSplit(t, "net update \\\"free\\\"node -pass 'political \"stance\" desu!' -realname '' -nick lee", []string{ + "net", + "update", + "\"free\"node", + "-pass", + "political \"stance\" desu!", + "-realname", + "", + "-nick", + "lee", + }) + assertSplit(t, "Omedeto,\\ Yui! ''", []string{ + "Omedeto, Yui!", + "", + }) + + if _, err := splitWords("end of 'file"); err == nil { + t.Errorf("expected error on unterminated single quote") + } + if _, err := splitWords("end of backquote \\"); err == nil { + t.Errorf("expected error on unterminated backquote sequence") + } +} diff --git a/trunk/suika_psql_schema.sql b/trunk/suika_psql_schema.sql new file mode 100644 index 0000000..d148b4d --- /dev/null +++ b/trunk/suika_psql_schema.sql @@ -0,0 +1,58 @@ +CREATE TABLE IF NOT EXISTS "User" ( + id SERIAL PRIMARY KEY, + username VARCHAR(255) NOT NULL UNIQUE, + password VARCHAR(255), + admin BOOLEAN NOT NULL DEFAULT FALSE, + realname VARCHAR(255) +); + +CREATE TYPE sasl_mechanism AS ENUM ('PLAIN', 'EXTERNAL'); + +CREATE TABLE IF NOT EXISTS "Network" ( + id SERIAL PRIMARY KEY, + name VARCHAR(255), + "user" INTEGER NOT NULL REFERENCES "User"(id) ON DELETE CASCADE, + addr VARCHAR(255) NOT NULL, + nick VARCHAR(255), + username VARCHAR(255), + realname VARCHAR(255), + pass VARCHAR(255), + connect_commands VARCHAR(1023), + sasl_mechanism sasl_mechanism, + sasl_plain_username VARCHAR(255), + sasl_plain_password VARCHAR(255), + sasl_external_cert BYTEA, + sasl_external_key BYTEA, + enabled BOOLEAN NOT NULL DEFAULT TRUE, + UNIQUE("user", addr, nick), + UNIQUE("user", name) +); +CREATE TABLE IF NOT EXISTS "Channel" ( + id SERIAL PRIMARY KEY, + network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE, + name VARCHAR(255) NOT NULL, + key VARCHAR(255), + detached BOOLEAN NOT NULL DEFAULT FALSE, + detached_internal_msgid VARCHAR(255), + relay_detached INTEGER NOT NULL DEFAULT 0, + reattach_on INTEGER NOT NULL DEFAULT 0, + detach_after INTEGER NOT NULL DEFAULT 0, + detach_on INTEGER NOT NULL DEFAULT 0, + UNIQUE(network, name) +); +CREATE TABLE IF NOT EXISTS "DeliveryReceipt" ( + id SERIAL PRIMARY KEY, + network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE, + target VARCHAR(255) NOT NULL, + client VARCHAR(255) NOT NULL DEFAULT '', + internal_msgid VARCHAR(255) NOT NULL, + UNIQUE(network, target, client) +); +CREATE TABLE IF NOT EXISTS "ReadReceipt" ( + id SERIAL PRIMARY KEY, + network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE, + target VARCHAR(255) NOT NULL, + timestamp TIMESTAMP WITH TIME ZONE NOT NULL, + UNIQUE(network, target) +); + diff --git a/trunk/suika_sqlite_schema.sql b/trunk/suika_sqlite_schema.sql new file mode 100644 index 0000000..517d06f --- /dev/null +++ b/trunk/suika_sqlite_schema.sql @@ -0,0 +1,61 @@ +CREATE TABLE IF NOT EXISTS User ( + id INTEGER PRIMARY KEY, + username TEXT NOT NULL UNIQUE, + password TEXT, + admin INTEGER NOT NULL DEFAULT 0, + realname TEXT +); +CREATE TABLE IF NOT EXISTS Network ( + id INTEGER PRIMARY KEY, + name TEXT, + user INTEGER NOT NULL, + addr TEXT NOT NULL, + nick TEXT, + username TEXT, + realname TEXT, + pass TEXT, + connect_commands TEXT, + sasl_mechanism TEXT, + sasl_plain_username TEXT, + sasl_plain_password TEXT, + sasl_external_cert BLOB, + sasl_external_key BLOB, + enabled INTEGER NOT NULL DEFAULT 1, + FOREIGN KEY(user) REFERENCES User(id), + UNIQUE(user, addr, nick), + UNIQUE(user, name) +); +CREATE TABLE IF NOT EXISTS Channel ( + id INTEGER PRIMARY KEY, + network INTEGER NOT NULL, + name TEXT NOT NULL, + key TEXT, + detached INTEGER NOT NULL DEFAULT 0, + detached_internal_msgid TEXT, + relay_detached INTEGER NOT NULL DEFAULT 0, + reattach_on INTEGER NOT NULL DEFAULT 0, + detach_after INTEGER NOT NULL DEFAULT 0, + detach_on INTEGER NOT NULL DEFAULT 0, + FOREIGN KEY(network) REFERENCES Network(id), + UNIQUE(network, name) +); + +CREATE TABLE IF NOT EXISTS DeliveryReceipt ( + id INTEGER PRIMARY KEY, + network INTEGER NOT NULL, + target TEXT NOT NULL, + client TEXT, + internal_msgid TEXT NOT NULL, + FOREIGN KEY(network) REFERENCES Network(id), + UNIQUE(network, target, client) +); + +CREATE TABLE IF NOT EXISTS ReadReceipt ( + id INTEGER PRIMARY KEY, + network INTEGER NOT NULL, + target TEXT NOT NULL, + timestamp TEXT NOT NULL, + FOREIGN KEY(network) REFERENCES Network(id), + UNIQUE(network, target) +); + diff --git a/trunk/upstream.go b/trunk/upstream.go new file mode 100644 index 0000000..285dca4 --- /dev/null +++ b/trunk/upstream.go @@ -0,0 +1,2182 @@ +package suika + +import ( + "context" + "crypto" + "crypto/sha256" + "crypto/tls" + "crypto/x509" + "encoding/base64" + "errors" + "fmt" + "io" + "net" + "strconv" + "strings" + "time" + + "github.com/emersion/go-sasl" + "gopkg.in/irc.v3" +) + +// permanentUpstreamCaps is the static list of upstream capabilities always +// requested when supported. +var permanentUpstreamCaps = map[string]bool{ + "account-notify": true, + "account-tag": true, + "away-notify": true, + "batch": true, + "extended-join": true, + "invite-notify": true, + "labeled-response": true, + "message-tags": true, + "multi-prefix": true, + "sasl": true, + "server-time": true, + "setname": true, + + "draft/account-registration": true, + "draft/extended-monitor": true, +} + +type registrationError struct { + *irc.Message +} + +func (err registrationError) Error() string { + return fmt.Sprintf("registration error (%v): %v", err.Command, err.Reason()) +} + +func (err registrationError) Reason() string { + if len(err.Params) > 0 { + return err.Params[len(err.Params)-1] + } + return err.Command +} + +func (err registrationError) Temporary() bool { + // Only return false if we're 100% sure that fixing the error requires a + // network configuration change + switch err.Command { + case irc.ERR_PASSWDMISMATCH, irc.ERR_ERRONEUSNICKNAME: + return false + case "FAIL": + return err.Params[1] != "ACCOUNT_REQUIRED" + default: + return true + } +} + +type upstreamChannel struct { + Name string + conn *upstreamConn + Topic string + TopicWho *irc.Prefix + TopicTime time.Time + Status channelStatus + modes channelModes + creationTime string + Members membershipsCasemapMap + complete bool + detachTimer *time.Timer +} + +func (uc *upstreamChannel) updateAutoDetach(dur time.Duration) { + if uc.detachTimer != nil { + uc.detachTimer.Stop() + uc.detachTimer = nil + } + + if dur == 0 { + return + } + + uc.detachTimer = time.AfterFunc(dur, func() { + uc.conn.network.user.events <- eventChannelDetach{ + uc: uc.conn, + name: uc.Name, + } + }) +} + +type pendingUpstreamCommand struct { + downstreamID uint64 + msg *irc.Message +} + +type upstreamConn struct { + conn + + network *network + user *user + + serverName string + availableUserModes string + availableChannelModes map[byte]channelModeType + availableChannelTypes string + availableMemberships []membership + isupport map[string]*string + + registered bool + nick string + nickCM string + username string + realname string + modes userModes + channels upstreamChannelCasemapMap + supportedCaps map[string]string + caps map[string]bool + batches map[string]batch + away bool + account string + nextLabelID uint64 + monitored monitorCasemapMap + + saslClient sasl.Client + saslStarted bool + + casemapIsSet bool + + // Queue of commands in progress, indexed by type. The first entry has been + // sent to the server and is awaiting reply. The following entries have not + // been sent yet. + pendingCmds map[string][]pendingUpstreamCommand + + gotMotd bool +} + +func connectToUpstream(ctx context.Context, network *network) (*upstreamConn, error) { + logger := &prefixLogger{network.user.logger, fmt.Sprintf("upstream %q: ", network.GetName())} + + dialer := net.Dialer{Timeout: connectTimeout} + + u, err := network.URL() + if err != nil { + return nil, err + } + + var netConn net.Conn + switch u.Scheme { + case "ircs": + addr := u.Host + host, _, err := net.SplitHostPort(u.Host) + if err != nil { + host = u.Host + addr = u.Host + ":6697" + } + + dialer.LocalAddr, err = network.user.localTCPAddrForHost(ctx, host) + if err != nil { + return nil, fmt.Errorf("failed to pick local IP for remote host %q: %v", host, err) + } + + logger.Printf("connecting to TLS server at address %q", addr) + + tlsConfig := &tls.Config{ServerName: host, NextProtos: []string{"irc"}} + if network.SASL.Mechanism == "EXTERNAL" { + if network.SASL.External.CertBlob == nil { + return nil, fmt.Errorf("missing certificate for authentication") + } + if network.SASL.External.PrivKeyBlob == nil { + return nil, fmt.Errorf("missing private key for authentication") + } + key, err := x509.ParsePKCS8PrivateKey(network.SASL.External.PrivKeyBlob) + if err != nil { + return nil, fmt.Errorf("failed to parse private key: %v", err) + } + tlsConfig.Certificates = []tls.Certificate{ + { + Certificate: [][]byte{network.SASL.External.CertBlob}, + PrivateKey: key.(crypto.PrivateKey), + }, + } + logger.Printf("using TLS client certificate %x", sha256.Sum256(network.SASL.External.CertBlob)) + } + + netConn, err = dialer.DialContext(ctx, "tcp", addr) + if err != nil { + return nil, fmt.Errorf("failed to dial %q: %v", addr, err) + } + + // Don't do the TLS handshake immediately, because we need to register + // the new connection with identd ASAP. + netConn = tls.Client(netConn, tlsConfig) + case "irc": + addr := u.Host + host, _, err := net.SplitHostPort(addr) + if err != nil { + host = u.Host + addr = u.Host + ":6667" + } + + dialer.LocalAddr, err = network.user.localTCPAddrForHost(ctx, host) + if err != nil { + return nil, fmt.Errorf("failed to pick local IP for remote host %q: %v", host, err) + } + + logger.Printf("connecting to plain-text server at address %q", addr) + netConn, err = dialer.DialContext(ctx, "tcp", addr) + if err != nil { + return nil, fmt.Errorf("failed to dial %q: %v", addr, err) + } + case "irc+unix", "unix": + logger.Printf("connecting to Unix socket at path %q", u.Path) + netConn, err = dialer.DialContext(ctx, "unix", u.Path) + if err != nil { + return nil, fmt.Errorf("failed to connect to Unix socket %q: %v", u.Path, err) + } + default: + return nil, fmt.Errorf("failed to dial %q: unknown scheme: %v", network.Addr, u.Scheme) + } + + options := connOptions{ + Logger: logger, + RateLimitDelay: upstreamMessageDelay, + RateLimitBurst: upstreamMessageBurst, + } + + uc := &upstreamConn{ + conn: *newConn(network.user.srv, newNetIRCConn(netConn), &options), + network: network, + user: network.user, + channels: upstreamChannelCasemapMap{newCasemapMap(0)}, + supportedCaps: make(map[string]string), + caps: make(map[string]bool), + batches: make(map[string]batch), + availableChannelTypes: stdChannelTypes, + availableChannelModes: stdChannelModes, + availableMemberships: stdMemberships, + isupport: make(map[string]*string), + pendingCmds: make(map[string][]pendingUpstreamCommand), + monitored: monitorCasemapMap{newCasemapMap(0)}, + } + return uc, nil +} + +func (uc *upstreamConn) forEachDownstream(f func(*downstreamConn)) { + uc.network.forEachDownstream(f) +} + +func (uc *upstreamConn) forEachDownstreamByID(id uint64, f func(*downstreamConn)) { + uc.forEachDownstream(func(dc *downstreamConn) { + if id != 0 && id != dc.id { + return + } + f(dc) + }) +} + +func (uc *upstreamConn) downstreamByID(id uint64) *downstreamConn { + for _, dc := range uc.user.downstreamConns { + if dc.id == id { + return dc + } + } + return nil +} + +func (uc *upstreamConn) getChannel(name string) (*upstreamChannel, error) { + ch := uc.channels.Value(name) + if ch == nil { + return nil, fmt.Errorf("unknown channel %q", name) + } + return ch, nil +} + +func (uc *upstreamConn) isChannel(entity string) bool { + return strings.ContainsRune(uc.availableChannelTypes, rune(entity[0])) +} + +func (uc *upstreamConn) isOurNick(nick string) bool { + return uc.nickCM == uc.network.casemap(nick) +} + +func (uc *upstreamConn) abortPendingCommands() { + for _, l := range uc.pendingCmds { + for _, pendingCmd := range l { + dc := uc.downstreamByID(pendingCmd.downstreamID) + if dc == nil { + continue + } + + switch pendingCmd.msg.Command { + case "LIST": + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_LISTEND, + Params: []string{dc.nick, "Command aborted"}, + }) + case "WHO": + mask := "*" + if len(pendingCmd.msg.Params) > 0 { + mask = pendingCmd.msg.Params[0] + } + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_ENDOFWHO, + Params: []string{dc.nick, mask, "Command aborted"}, + }) + case "AUTHENTICATE": + dc.endSASL(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.ERR_SASLABORTED, + Params: []string{dc.nick, "SASL authentication aborted"}, + }) + case "REGISTER", "VERIFY": + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "FAIL", + Params: []string{pendingCmd.msg.Command, "TEMPORARILY_UNAVAILABLE", pendingCmd.msg.Params[0], "Command aborted"}, + }) + default: + panic(fmt.Errorf("Unsupported pending command %q", pendingCmd.msg.Command)) + } + } + } + + uc.pendingCmds = make(map[string][]pendingUpstreamCommand) +} + +func (uc *upstreamConn) sendNextPendingCommand(cmd string) { + if len(uc.pendingCmds[cmd]) == 0 { + return + } + uc.SendMessage(context.TODO(), uc.pendingCmds[cmd][0].msg) +} + +func (uc *upstreamConn) enqueueCommand(dc *downstreamConn, msg *irc.Message) { + switch msg.Command { + case "LIST", "WHO", "AUTHENTICATE", "REGISTER", "VERIFY": + // Supported + default: + panic(fmt.Errorf("Unsupported pending command %q", msg.Command)) + } + + uc.pendingCmds[msg.Command] = append(uc.pendingCmds[msg.Command], pendingUpstreamCommand{ + downstreamID: dc.id, + msg: msg, + }) + + if len(uc.pendingCmds[msg.Command]) == 1 { + uc.sendNextPendingCommand(msg.Command) + } +} + +func (uc *upstreamConn) currentPendingCommand(cmd string) (*downstreamConn, *irc.Message) { + if len(uc.pendingCmds[cmd]) == 0 { + return nil, nil + } + + pendingCmd := uc.pendingCmds[cmd][0] + return uc.downstreamByID(pendingCmd.downstreamID), pendingCmd.msg +} + +func (uc *upstreamConn) dequeueCommand(cmd string) (*downstreamConn, *irc.Message) { + dc, msg := uc.currentPendingCommand(cmd) + + if len(uc.pendingCmds[cmd]) > 0 { + copy(uc.pendingCmds[cmd], uc.pendingCmds[cmd][1:]) + uc.pendingCmds[cmd] = uc.pendingCmds[cmd][:len(uc.pendingCmds[cmd])-1] + } + + uc.sendNextPendingCommand(cmd) + + return dc, msg +} + +func (uc *upstreamConn) cancelPendingCommandsByDownstreamID(downstreamID uint64) { + for cmd := range uc.pendingCmds { + // We can't cancel the currently running command stored in + // uc.pendingCmds[cmd][0] + for i := len(uc.pendingCmds[cmd]) - 1; i >= 1; i-- { + if uc.pendingCmds[cmd][i].downstreamID == downstreamID { + uc.pendingCmds[cmd] = append(uc.pendingCmds[cmd][:i], uc.pendingCmds[cmd][i+1:]...) + } + } + } +} + +func (uc *upstreamConn) parseMembershipPrefix(s string) (ms *memberships, nick string) { + memberships := make(memberships, 0, 4) + i := 0 + for _, m := range uc.availableMemberships { + if i >= len(s) { + break + } + if s[i] == m.Prefix { + memberships = append(memberships, m) + i++ + } + } + return &memberships, s[i:] +} + +func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) error { + var label string + if l, ok := msg.GetTag("label"); ok { + label = l + delete(msg.Tags, "label") + } + + var msgBatch *batch + if batchName, ok := msg.GetTag("batch"); ok { + b, ok := uc.batches[batchName] + if !ok { + return fmt.Errorf("unexpected batch reference: batch was not defined: %q", batchName) + } + msgBatch = &b + if label == "" { + label = msgBatch.Label + } + delete(msg.Tags, "batch") + } + + var downstreamID uint64 = 0 + if label != "" { + var labelOffset uint64 + n, err := fmt.Sscanf(label, "sd-%d-%d", &downstreamID, &labelOffset) + if err == nil && n < 2 { + err = errors.New("not enough arguments") + } + if err != nil { + return fmt.Errorf("unexpected message label: invalid downstream reference for label %q: %v", label, err) + } + } + + if _, ok := msg.Tags["time"]; !ok { + msg.Tags["time"] = irc.TagValue(formatServerTime(time.Now())) + } + + switch msg.Command { + case "PING": + uc.SendMessage(ctx, &irc.Message{ + Command: "PONG", + Params: msg.Params, + }) + return nil + case "NOTICE", "PRIVMSG", "TAGMSG": + if msg.Prefix == nil { + return fmt.Errorf("expected a prefix") + } + + var entity, text string + if msg.Command != "TAGMSG" { + if err := parseMessageParams(msg, &entity, &text); err != nil { + return err + } + } else { + if err := parseMessageParams(msg, &entity); err != nil { + return err + } + } + + if msg.Prefix.Name == serviceNick { + uc.logger.Printf("skipping %v from suika's service: %v", msg.Command, msg) + break + } + if entity == serviceNick { + uc.logger.Printf("skipping %v to suika's service: %v", msg.Command, msg) + break + } + + if msg.Prefix.User == "" && msg.Prefix.Host == "" { // server message + uc.produce("", msg, nil) + } else { // regular user message + target := entity + if uc.isOurNick(target) { + target = msg.Prefix.Name + } + + ch := uc.network.channels.Value(target) + if ch != nil && msg.Command != "TAGMSG" { + if ch.Detached { + uc.handleDetachedMessage(ctx, ch, msg) + } + + highlight := uc.network.isHighlight(msg) + if ch.DetachOn == FilterMessage || ch.DetachOn == FilterDefault || (ch.DetachOn == FilterHighlight && highlight) { + uc.updateChannelAutoDetach(target) + } + } + + uc.produce(target, msg, nil) + } + case "CAP": + var subCmd string + if err := parseMessageParams(msg, nil, &subCmd); err != nil { + return err + } + subCmd = strings.ToUpper(subCmd) + subParams := msg.Params[2:] + switch subCmd { + case "LS": + if len(subParams) < 1 { + return newNeedMoreParamsError(msg.Command) + } + caps := subParams[len(subParams)-1] + more := len(subParams) >= 2 && msg.Params[len(subParams)-2] == "*" + + uc.handleSupportedCaps(caps) + + if more { + break // wait to receive all capabilities + } + + uc.requestCaps() + + if uc.requestSASL() { + break // we'll send CAP END after authentication is completed + } + + uc.SendMessage(ctx, &irc.Message{ + Command: "CAP", + Params: []string{"END"}, + }) + case "ACK", "NAK": + if len(subParams) < 1 { + return newNeedMoreParamsError(msg.Command) + } + caps := strings.Fields(subParams[0]) + + for _, name := range caps { + if err := uc.handleCapAck(ctx, strings.ToLower(name), subCmd == "ACK"); err != nil { + return err + } + } + + if uc.registered { + uc.forEachDownstream(func(dc *downstreamConn) { + dc.updateSupportedCaps() + }) + } + case "NEW": + if len(subParams) < 1 { + return newNeedMoreParamsError(msg.Command) + } + uc.handleSupportedCaps(subParams[0]) + uc.requestCaps() + case "DEL": + if len(subParams) < 1 { + return newNeedMoreParamsError(msg.Command) + } + caps := strings.Fields(subParams[0]) + + for _, c := range caps { + delete(uc.supportedCaps, c) + delete(uc.caps, c) + } + + if uc.registered { + uc.forEachDownstream(func(dc *downstreamConn) { + dc.updateSupportedCaps() + }) + } + default: + uc.logger.Printf("unhandled message: %v", msg) + } + case "AUTHENTICATE": + if uc.saslClient == nil { + return fmt.Errorf("received unexpected AUTHENTICATE message") + } + + // TODO: if a challenge is 400 bytes long, buffer it + var challengeStr string + if err := parseMessageParams(msg, &challengeStr); err != nil { + uc.SendMessage(ctx, &irc.Message{ + Command: "AUTHENTICATE", + Params: []string{"*"}, + }) + return err + } + + var challenge []byte + if challengeStr != "+" { + var err error + challenge, err = base64.StdEncoding.DecodeString(challengeStr) + if err != nil { + uc.SendMessage(ctx, &irc.Message{ + Command: "AUTHENTICATE", + Params: []string{"*"}, + }) + return err + } + } + + var resp []byte + var err error + if !uc.saslStarted { + _, resp, err = uc.saslClient.Start() + uc.saslStarted = true + } else { + resp, err = uc.saslClient.Next(challenge) + } + if err != nil { + uc.SendMessage(ctx, &irc.Message{ + Command: "AUTHENTICATE", + Params: []string{"*"}, + }) + return err + } + + // <= instead of < because we need to send a final empty response if + // the last chunk is exactly 400 bytes long + for i := 0; i <= len(resp); i += maxSASLLength { + j := i + maxSASLLength + if j > len(resp) { + j = len(resp) + } + + chunk := resp[i:j] + + var respStr = "+" + if len(chunk) != 0 { + respStr = base64.StdEncoding.EncodeToString(chunk) + } + + uc.SendMessage(ctx, &irc.Message{ + Command: "AUTHENTICATE", + Params: []string{respStr}, + }) + } + case irc.RPL_LOGGEDIN: + if err := parseMessageParams(msg, nil, nil, &uc.account); err != nil { + return err + } + uc.logger.Printf("logged in with account %q", uc.account) + uc.forEachDownstream(func(dc *downstreamConn) { + dc.updateAccount() + }) + case irc.RPL_LOGGEDOUT: + uc.account = "" + uc.logger.Printf("logged out") + uc.forEachDownstream(func(dc *downstreamConn) { + dc.updateAccount() + }) + case irc.ERR_NICKLOCKED, irc.RPL_SASLSUCCESS, irc.ERR_SASLFAIL, irc.ERR_SASLTOOLONG, irc.ERR_SASLABORTED: + var info string + if err := parseMessageParams(msg, nil, &info); err != nil { + return err + } + switch msg.Command { + case irc.ERR_NICKLOCKED: + uc.logger.Printf("invalid nick used with SASL authentication: %v", info) + case irc.ERR_SASLFAIL: + uc.logger.Printf("SASL authentication failed: %v", info) + case irc.ERR_SASLTOOLONG: + uc.logger.Printf("SASL message too long: %v", info) + } + + uc.saslClient = nil + uc.saslStarted = false + + if dc, _ := uc.dequeueCommand("AUTHENTICATE"); dc != nil && dc.sasl != nil { + if msg.Command == irc.RPL_SASLSUCCESS { + uc.network.autoSaveSASLPlain(ctx, dc.sasl.plainUsername, dc.sasl.plainPassword) + } + + dc.endSASL(msg) + } + + if !uc.registered { + uc.SendMessage(ctx, &irc.Message{ + Command: "CAP", + Params: []string{"END"}, + }) + } + case "REGISTER", "VERIFY": + if dc, cmd := uc.dequeueCommand(msg.Command); dc != nil { + if msg.Command == "REGISTER" { + var account, password string + if err := parseMessageParams(msg, nil, &account); err != nil { + return err + } + if err := parseMessageParams(cmd, nil, nil, &password); err != nil { + return err + } + uc.network.autoSaveSASLPlain(ctx, account, password) + } + + dc.SendMessage(msg) + } + case irc.RPL_WELCOME: + if err := parseMessageParams(msg, &uc.nick); err != nil { + return err + } + + uc.registered = true + uc.nickCM = uc.network.casemap(uc.nick) + uc.logger.Printf("connection registered with nick %q", uc.nick) + + if uc.network.channels.Len() > 0 { + var channels, keys []string + for _, entry := range uc.network.channels.innerMap { + ch := entry.value.(*Channel) + channels = append(channels, ch.Name) + keys = append(keys, ch.Key) + } + + for _, msg := range join(channels, keys) { + uc.SendMessage(ctx, msg) + } + } + case irc.RPL_MYINFO: + if err := parseMessageParams(msg, nil, &uc.serverName, nil, &uc.availableUserModes, nil); err != nil { + return err + } + case irc.RPL_ISUPPORT: + if err := parseMessageParams(msg, nil, nil); err != nil { + return err + } + + var downstreamIsupport []string + for _, token := range msg.Params[1 : len(msg.Params)-1] { + parameter := token + var negate, hasValue bool + var value string + if strings.HasPrefix(token, "-") { + negate = true + token = token[1:] + } else if i := strings.IndexByte(token, '='); i >= 0 { + parameter = token[:i] + value = token[i+1:] + hasValue = true + } + + if hasValue { + uc.isupport[parameter] = &value + } else if !negate { + uc.isupport[parameter] = nil + } else { + delete(uc.isupport, parameter) + } + + var err error + switch parameter { + case "CASEMAPPING": + casemap, ok := parseCasemappingToken(value) + if !ok { + casemap = casemapRFC1459 + } + uc.network.updateCasemapping(casemap) + uc.nickCM = uc.network.casemap(uc.nick) + uc.casemapIsSet = true + case "CHANMODES": + if !negate { + err = uc.handleChanModes(value) + } else { + uc.availableChannelModes = stdChannelModes + } + case "CHANTYPES": + if !negate { + uc.availableChannelTypes = value + } else { + uc.availableChannelTypes = stdChannelTypes + } + case "PREFIX": + if !negate { + err = uc.handleMemberships(value) + } else { + uc.availableMemberships = stdMemberships + } + } + if err != nil { + return err + } + + if passthroughIsupport[parameter] { + downstreamIsupport = append(downstreamIsupport, token) + } + } + + uc.updateMonitor() + + uc.forEachDownstream(func(dc *downstreamConn) { + if dc.network == nil { + return + } + msgs := generateIsupport(dc.srv.prefix(), dc.nick, downstreamIsupport) + for _, msg := range msgs { + dc.SendMessage(msg) + } + }) + case irc.ERR_NOMOTD, irc.RPL_ENDOFMOTD: + if !uc.casemapIsSet { + // upstream did not send any CASEMAPPING token, thus + // we assume it implements the old RFCs with rfc1459. + uc.casemapIsSet = true + uc.network.updateCasemapping(casemapRFC1459) + uc.nickCM = uc.network.casemap(uc.nick) + } + + if !uc.gotMotd { + // Ignore the initial MOTD upon connection, but forward + // subsequent MOTD messages downstream + uc.gotMotd = true + return nil + } + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + dc.SendMessage(&irc.Message{ + Prefix: uc.srv.prefix(), + Command: msg.Command, + Params: msg.Params, + }) + }) + case "BATCH": + var tag string + if err := parseMessageParams(msg, &tag); err != nil { + return err + } + + if strings.HasPrefix(tag, "+") { + tag = tag[1:] + if _, ok := uc.batches[tag]; ok { + return fmt.Errorf("unexpected BATCH reference tag: batch was already defined: %q", tag) + } + var batchType string + if err := parseMessageParams(msg, nil, &batchType); err != nil { + return err + } + label := label + if label == "" && msgBatch != nil { + label = msgBatch.Label + } + uc.batches[tag] = batch{ + Type: batchType, + Params: msg.Params[2:], + Outer: msgBatch, + Label: label, + } + } else if strings.HasPrefix(tag, "-") { + tag = tag[1:] + if _, ok := uc.batches[tag]; !ok { + return fmt.Errorf("unknown BATCH reference tag: %q", tag) + } + delete(uc.batches, tag) + } else { + return fmt.Errorf("unexpected BATCH reference tag: missing +/- prefix: %q", tag) + } + case "NICK": + if msg.Prefix == nil { + return fmt.Errorf("expected a prefix") + } + + var newNick string + if err := parseMessageParams(msg, &newNick); err != nil { + return err + } + + me := false + if uc.isOurNick(msg.Prefix.Name) { + uc.logger.Printf("changed nick from %q to %q", uc.nick, newNick) + me = true + uc.nick = newNick + uc.nickCM = uc.network.casemap(uc.nick) + } + + for _, entry := range uc.channels.innerMap { + ch := entry.value.(*upstreamChannel) + memberships := ch.Members.Value(msg.Prefix.Name) + if memberships != nil { + ch.Members.Delete(msg.Prefix.Name) + ch.Members.SetValue(newNick, memberships) + uc.appendLog(ch.Name, msg) + } + } + + if !me { + uc.forEachDownstream(func(dc *downstreamConn) { + dc.SendMessage(dc.marshalMessage(msg, uc.network)) + }) + } else { + uc.forEachDownstream(func(dc *downstreamConn) { + dc.updateNick() + }) + uc.updateMonitor() + } + case "SETNAME": + if msg.Prefix == nil { + return fmt.Errorf("expected a prefix") + } + + var newRealname string + if err := parseMessageParams(msg, &newRealname); err != nil { + return err + } + + // TODO: consider appending this message to logs + + if uc.isOurNick(msg.Prefix.Name) { + uc.logger.Printf("changed realname from %q to %q", uc.realname, newRealname) + uc.realname = newRealname + + uc.forEachDownstream(func(dc *downstreamConn) { + dc.updateRealname() + }) + } else { + uc.forEachDownstream(func(dc *downstreamConn) { + dc.SendMessage(dc.marshalMessage(msg, uc.network)) + }) + } + case "JOIN": + if msg.Prefix == nil { + return fmt.Errorf("expected a prefix") + } + + var channels string + if err := parseMessageParams(msg, &channels); err != nil { + return err + } + + for _, ch := range strings.Split(channels, ",") { + if uc.isOurNick(msg.Prefix.Name) { + uc.logger.Printf("joined channel %q", ch) + members := membershipsCasemapMap{newCasemapMap(0)} + members.casemap = uc.network.casemap + uc.channels.SetValue(ch, &upstreamChannel{ + Name: ch, + conn: uc, + Members: members, + }) + uc.updateChannelAutoDetach(ch) + + uc.SendMessage(ctx, &irc.Message{ + Command: "MODE", + Params: []string{ch}, + }) + } else { + ch, err := uc.getChannel(ch) + if err != nil { + return err + } + ch.Members.SetValue(msg.Prefix.Name, &memberships{}) + } + + chMsg := msg.Copy() + chMsg.Params[0] = ch + uc.produce(ch, chMsg, nil) + } + case "PART": + if msg.Prefix == nil { + return fmt.Errorf("expected a prefix") + } + + var channels string + if err := parseMessageParams(msg, &channels); err != nil { + return err + } + + for _, ch := range strings.Split(channels, ",") { + if uc.isOurNick(msg.Prefix.Name) { + uc.logger.Printf("parted channel %q", ch) + uch := uc.channels.Value(ch) + if uch != nil { + uc.channels.Delete(ch) + uch.updateAutoDetach(0) + } + } else { + ch, err := uc.getChannel(ch) + if err != nil { + return err + } + ch.Members.Delete(msg.Prefix.Name) + } + + chMsg := msg.Copy() + chMsg.Params[0] = ch + uc.produce(ch, chMsg, nil) + } + case "KICK": + if msg.Prefix == nil { + return fmt.Errorf("expected a prefix") + } + + var channel, user string + if err := parseMessageParams(msg, &channel, &user); err != nil { + return err + } + + if uc.isOurNick(user) { + uc.logger.Printf("kicked from channel %q by %s", channel, msg.Prefix.Name) + uc.channels.Delete(channel) + } else { + ch, err := uc.getChannel(channel) + if err != nil { + return err + } + ch.Members.Delete(user) + } + + uc.produce(channel, msg, nil) + case "QUIT": + if msg.Prefix == nil { + return fmt.Errorf("expected a prefix") + } + + if uc.isOurNick(msg.Prefix.Name) { + uc.logger.Printf("quit") + } + + for _, entry := range uc.channels.innerMap { + ch := entry.value.(*upstreamChannel) + if ch.Members.Has(msg.Prefix.Name) { + ch.Members.Delete(msg.Prefix.Name) + + uc.appendLog(ch.Name, msg) + } + } + + if msg.Prefix.Name != uc.nick { + uc.forEachDownstream(func(dc *downstreamConn) { + dc.SendMessage(dc.marshalMessage(msg, uc.network)) + }) + } + case irc.RPL_TOPIC, irc.RPL_NOTOPIC: + var name, topic string + if err := parseMessageParams(msg, nil, &name, &topic); err != nil { + return err + } + ch, err := uc.getChannel(name) + if err != nil { + return err + } + if msg.Command == irc.RPL_TOPIC { + ch.Topic = topic + } else { + ch.Topic = "" + } + case "TOPIC": + if msg.Prefix == nil { + return fmt.Errorf("expected a prefix") + } + + var name string + if err := parseMessageParams(msg, &name); err != nil { + return err + } + ch, err := uc.getChannel(name) + if err != nil { + return err + } + if len(msg.Params) > 1 { + ch.Topic = msg.Params[1] + ch.TopicWho = msg.Prefix.Copy() + ch.TopicTime = time.Now() // TODO use msg.Tags["time"] + } else { + ch.Topic = "" + } + uc.produce(ch.Name, msg, nil) + case "MODE": + var name, modeStr string + if err := parseMessageParams(msg, &name, &modeStr); err != nil { + return err + } + + if !uc.isChannel(name) { // user mode change + if name != uc.nick { + return fmt.Errorf("received MODE message for unknown nick %q", name) + } + + if err := uc.modes.Apply(modeStr); err != nil { + return err + } + + uc.forEachDownstream(func(dc *downstreamConn) { + if dc.upstream() == nil { + return + } + + dc.SendMessage(msg) + }) + } else { // channel mode change + ch, err := uc.getChannel(name) + if err != nil { + return err + } + + needMarshaling, err := applyChannelModes(ch, modeStr, msg.Params[2:]) + if err != nil { + return err + } + + uc.appendLog(ch.Name, msg) + + c := uc.network.channels.Value(name) + if c == nil || !c.Detached { + uc.forEachDownstream(func(dc *downstreamConn) { + params := make([]string, len(msg.Params)) + params[0] = dc.marshalEntity(uc.network, name) + params[1] = modeStr + + copy(params[2:], msg.Params[2:]) + for i, modeParam := range params[2:] { + if _, ok := needMarshaling[i]; ok { + params[2+i] = dc.marshalEntity(uc.network, modeParam) + } + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.marshalUserPrefix(uc.network, msg.Prefix), + Command: "MODE", + Params: params, + }) + }) + } + } + case irc.RPL_UMODEIS: + if err := parseMessageParams(msg, nil); err != nil { + return err + } + modeStr := "" + if len(msg.Params) > 1 { + modeStr = msg.Params[1] + } + + uc.modes = "" + if err := uc.modes.Apply(modeStr); err != nil { + return err + } + + uc.forEachDownstream(func(dc *downstreamConn) { + if dc.upstream() == nil { + return + } + + dc.SendMessage(msg) + }) + case irc.RPL_CHANNELMODEIS: + var channel string + if err := parseMessageParams(msg, nil, &channel); err != nil { + return err + } + modeStr := "" + if len(msg.Params) > 2 { + modeStr = msg.Params[2] + } + + ch, err := uc.getChannel(channel) + if err != nil { + return err + } + + firstMode := ch.modes == nil + ch.modes = make(map[byte]string) + if _, err := applyChannelModes(ch, modeStr, msg.Params[3:]); err != nil { + return err + } + + c := uc.network.channels.Value(channel) + if firstMode && (c == nil || !c.Detached) { + modeStr, modeParams := ch.modes.Format() + + uc.forEachDownstream(func(dc *downstreamConn) { + params := []string{dc.nick, dc.marshalEntity(uc.network, channel), modeStr} + params = append(params, modeParams...) + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_CHANNELMODEIS, + Params: params, + }) + }) + } + case rpl_creationtime: + var channel, creationTime string + if err := parseMessageParams(msg, nil, &channel, &creationTime); err != nil { + return err + } + + ch, err := uc.getChannel(channel) + if err != nil { + return err + } + + firstCreationTime := ch.creationTime == "" + ch.creationTime = creationTime + + c := uc.network.channels.Value(channel) + if firstCreationTime && (c == nil || !c.Detached) { + uc.forEachDownstream(func(dc *downstreamConn) { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: rpl_creationtime, + Params: []string{dc.nick, dc.marshalEntity(uc.network, ch.Name), creationTime}, + }) + }) + } + case rpl_topicwhotime: + var channel, who, timeStr string + if err := parseMessageParams(msg, nil, &channel, &who, &timeStr); err != nil { + return err + } + + ch, err := uc.getChannel(channel) + if err != nil { + return err + } + + firstTopicWhoTime := ch.TopicWho == nil + ch.TopicWho = irc.ParsePrefix(who) + sec, err := strconv.ParseInt(timeStr, 10, 64) + if err != nil { + return fmt.Errorf("failed to parse topic time: %v", err) + } + ch.TopicTime = time.Unix(sec, 0) + + c := uc.network.channels.Value(channel) + if firstTopicWhoTime && (c == nil || !c.Detached) { + uc.forEachDownstream(func(dc *downstreamConn) { + topicWho := dc.marshalUserPrefix(uc.network, ch.TopicWho) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: rpl_topicwhotime, + Params: []string{ + dc.nick, + dc.marshalEntity(uc.network, ch.Name), + topicWho.String(), + timeStr, + }, + }) + }) + } + case irc.RPL_LIST: + var channel, clients, topic string + if err := parseMessageParams(msg, nil, &channel, &clients, &topic); err != nil { + return err + } + + dc, cmd := uc.currentPendingCommand("LIST") + if cmd == nil { + return fmt.Errorf("unexpected RPL_LIST: no matching pending LIST") + } else if dc == nil { + return nil + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_LIST, + Params: []string{dc.nick, dc.marshalEntity(uc.network, channel), clients, topic}, + }) + case irc.RPL_LISTEND: + dc, cmd := uc.dequeueCommand("LIST") + if cmd == nil { + return fmt.Errorf("unexpected RPL_LISTEND: no matching pending LIST") + } else if dc == nil { + return nil + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_LISTEND, + Params: []string{dc.nick, "End of /LIST"}, + }) + case irc.RPL_NAMREPLY: + var name, statusStr, members string + if err := parseMessageParams(msg, nil, &statusStr, &name, &members); err != nil { + return err + } + + ch := uc.channels.Value(name) + if ch == nil { + // NAMES on a channel we have not joined, forward to downstream + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + channel := dc.marshalEntity(uc.network, name) + members := splitSpace(members) + for i, member := range members { + memberships, nick := uc.parseMembershipPrefix(member) + members[i] = memberships.Format(dc) + dc.marshalEntity(uc.network, nick) + } + memberStr := strings.Join(members, " ") + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_NAMREPLY, + Params: []string{dc.nick, statusStr, channel, memberStr}, + }) + }) + return nil + } + + status, err := parseChannelStatus(statusStr) + if err != nil { + return err + } + ch.Status = status + + for _, s := range splitSpace(members) { + memberships, nick := uc.parseMembershipPrefix(s) + ch.Members.SetValue(nick, memberships) + } + case irc.RPL_ENDOFNAMES: + var name string + if err := parseMessageParams(msg, nil, &name); err != nil { + return err + } + + ch := uc.channels.Value(name) + if ch == nil { + // NAMES on a channel we have not joined, forward to downstream + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + channel := dc.marshalEntity(uc.network, name) + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_ENDOFNAMES, + Params: []string{dc.nick, channel, "End of /NAMES list"}, + }) + }) + return nil + } + + if ch.complete { + return fmt.Errorf("received unexpected RPL_ENDOFNAMES") + } + ch.complete = true + + c := uc.network.channels.Value(name) + if c == nil || !c.Detached { + uc.forEachDownstream(func(dc *downstreamConn) { + forwardChannel(ctx, dc, ch) + }) + } + case irc.RPL_WHOREPLY: + var channel, username, host, server, nick, flags, trailing string + if err := parseMessageParams(msg, nil, &channel, &username, &host, &server, &nick, &flags, &trailing); err != nil { + return err + } + + dc, cmd := uc.currentPendingCommand("WHO") + if cmd == nil { + return fmt.Errorf("unexpected RPL_WHOREPLY: no matching pending WHO") + } else if dc == nil { + return nil + } + + if channel != "*" { + channel = dc.marshalEntity(uc.network, channel) + } + nick = dc.marshalEntity(uc.network, nick) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_WHOREPLY, + Params: []string{dc.nick, channel, username, host, server, nick, flags, trailing}, + }) + case rpl_whospcrpl: + dc, cmd := uc.currentPendingCommand("WHO") + if cmd == nil { + return fmt.Errorf("unexpected RPL_WHOSPCRPL: no matching pending WHO") + } else if dc == nil { + return nil + } + + // Only supported in single-upstream mode, so forward as-is + dc.SendMessage(msg) + case irc.RPL_ENDOFWHO: + var name string + if err := parseMessageParams(msg, nil, &name); err != nil { + return err + } + + dc, cmd := uc.dequeueCommand("WHO") + if cmd == nil { + return fmt.Errorf("unexpected RPL_ENDOFWHO: no matching pending WHO") + } else if dc == nil { + return nil + } + + mask := "*" + if len(cmd.Params) > 0 { + mask = cmd.Params[0] + } + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_ENDOFWHO, + Params: []string{dc.nick, mask, "End of /WHO list"}, + }) + case irc.RPL_WHOISUSER: + var nick, username, host, realname string + if err := parseMessageParams(msg, nil, &nick, &username, &host, nil, &realname); err != nil { + return err + } + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + nick := dc.marshalEntity(uc.network, nick) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_WHOISUSER, + Params: []string{dc.nick, nick, username, host, "*", realname}, + }) + }) + case irc.RPL_WHOISSERVER: + var nick, server, serverInfo string + if err := parseMessageParams(msg, nil, &nick, &server, &serverInfo); err != nil { + return err + } + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + nick := dc.marshalEntity(uc.network, nick) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_WHOISSERVER, + Params: []string{dc.nick, nick, server, serverInfo}, + }) + }) + case irc.RPL_WHOISOPERATOR: + var nick string + if err := parseMessageParams(msg, nil, &nick); err != nil { + return err + } + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + nick := dc.marshalEntity(uc.network, nick) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_WHOISOPERATOR, + Params: []string{dc.nick, nick, "is an IRC operator"}, + }) + }) + case irc.RPL_WHOISIDLE: + var nick string + if err := parseMessageParams(msg, nil, &nick, nil); err != nil { + return err + } + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + nick := dc.marshalEntity(uc.network, nick) + params := []string{dc.nick, nick} + params = append(params, msg.Params[2:]...) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_WHOISIDLE, + Params: params, + }) + }) + case irc.RPL_WHOISCHANNELS: + var nick, channelList string + if err := parseMessageParams(msg, nil, &nick, &channelList); err != nil { + return err + } + channels := splitSpace(channelList) + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + nick := dc.marshalEntity(uc.network, nick) + channelList := make([]string, len(channels)) + for i, channel := range channels { + prefix, channel := uc.parseMembershipPrefix(channel) + channel = dc.marshalEntity(uc.network, channel) + channelList[i] = prefix.Format(dc) + channel + } + channels := strings.Join(channelList, " ") + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_WHOISCHANNELS, + Params: []string{dc.nick, nick, channels}, + }) + }) + case irc.RPL_ENDOFWHOIS: + var nick string + if err := parseMessageParams(msg, nil, &nick); err != nil { + return err + } + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + nick := dc.marshalEntity(uc.network, nick) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_ENDOFWHOIS, + Params: []string{dc.nick, nick, "End of /WHOIS list"}, + }) + }) + case "INVITE": + var nick, channel string + if err := parseMessageParams(msg, &nick, &channel); err != nil { + return err + } + + weAreInvited := uc.isOurNick(nick) + + uc.forEachDownstream(func(dc *downstreamConn) { + if !weAreInvited && !dc.caps["invite-notify"] { + return + } + dc.SendMessage(&irc.Message{ + Prefix: dc.marshalUserPrefix(uc.network, msg.Prefix), + Command: "INVITE", + Params: []string{dc.marshalEntity(uc.network, nick), dc.marshalEntity(uc.network, channel)}, + }) + }) + case irc.RPL_INVITING: + var nick, channel string + if err := parseMessageParams(msg, nil, &nick, &channel); err != nil { + return err + } + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_INVITING, + Params: []string{dc.nick, dc.marshalEntity(uc.network, nick), dc.marshalEntity(uc.network, channel)}, + }) + }) + case irc.RPL_MONONLINE, irc.RPL_MONOFFLINE: + var targetsStr string + if err := parseMessageParams(msg, nil, &targetsStr); err != nil { + return err + } + targets := strings.Split(targetsStr, ",") + + online := msg.Command == irc.RPL_MONONLINE + for _, target := range targets { + prefix := irc.ParsePrefix(target) + uc.monitored.SetValue(prefix.Name, online) + } + + // Check if the nick we want is now free + wantNick := GetNick(&uc.user.User, &uc.network.Network) + wantNickCM := uc.network.casemap(wantNick) + if !online && uc.nickCM != wantNickCM { + found := false + for _, target := range targets { + prefix := irc.ParsePrefix(target) + if uc.network.casemap(prefix.Name) == wantNickCM { + found = true + break + } + } + if found { + uc.logger.Printf("desired nick %q is now available", wantNick) + uc.SendMessage(ctx, &irc.Message{ + Command: "NICK", + Params: []string{wantNick}, + }) + } + } + + uc.forEachDownstream(func(dc *downstreamConn) { + for _, target := range targets { + prefix := irc.ParsePrefix(target) + if dc.monitored.Has(prefix.Name) { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: msg.Command, + Params: []string{dc.nick, target}, + }) + } + } + }) + case irc.ERR_MONLISTFULL: + var limit, targetsStr string + if err := parseMessageParams(msg, nil, &limit, &targetsStr); err != nil { + return err + } + + targets := strings.Split(targetsStr, ",") + uc.forEachDownstream(func(dc *downstreamConn) { + for _, target := range targets { + if dc.monitored.Has(target) { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: msg.Command, + Params: []string{dc.nick, limit, target}, + }) + } + } + }) + case irc.RPL_AWAY: + var nick, reason string + if err := parseMessageParams(msg, nil, &nick, &reason); err != nil { + return err + } + + uc.forEachDownstream(func(dc *downstreamConn) { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_AWAY, + Params: []string{dc.nick, dc.marshalEntity(uc.network, nick), reason}, + }) + }) + case "AWAY", "ACCOUNT": + if msg.Prefix == nil { + return fmt.Errorf("expected a prefix") + } + + uc.forEachDownstream(func(dc *downstreamConn) { + dc.SendMessage(&irc.Message{ + Prefix: dc.marshalUserPrefix(uc.network, msg.Prefix), + Command: msg.Command, + Params: msg.Params, + }) + }) + case irc.RPL_BANLIST, irc.RPL_INVITELIST, irc.RPL_EXCEPTLIST: + var channel, mask string + if err := parseMessageParams(msg, nil, &channel, &mask); err != nil { + return err + } + var addNick, addTime string + if len(msg.Params) >= 5 { + addNick = msg.Params[3] + addTime = msg.Params[4] + } + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + channel := dc.marshalEntity(uc.network, channel) + + var params []string + if addNick != "" && addTime != "" { + addNick := dc.marshalEntity(uc.network, addNick) + params = []string{dc.nick, channel, mask, addNick, addTime} + } else { + params = []string{dc.nick, channel, mask} + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: msg.Command, + Params: params, + }) + }) + case irc.RPL_ENDOFBANLIST, irc.RPL_ENDOFINVITELIST, irc.RPL_ENDOFEXCEPTLIST: + var channel, trailing string + if err := parseMessageParams(msg, nil, &channel, &trailing); err != nil { + return err + } + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + upstreamChannel := dc.marshalEntity(uc.network, channel) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: msg.Command, + Params: []string{dc.nick, upstreamChannel, trailing}, + }) + }) + case irc.ERR_UNKNOWNCOMMAND, irc.RPL_TRYAGAIN: + var command, reason string + if err := parseMessageParams(msg, nil, &command, &reason); err != nil { + return err + } + + if dc, _ := uc.dequeueCommand(command); dc != nil && downstreamID == 0 { + downstreamID = dc.id + } + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + dc.SendMessage(&irc.Message{ + Prefix: uc.srv.prefix(), + Command: msg.Command, + Params: []string{dc.nick, command, reason}, + }) + }) + case "FAIL": + var command, code string + if err := parseMessageParams(msg, &command, &code); err != nil { + return err + } + + if !uc.registered && command == "*" && code == "ACCOUNT_REQUIRED" { + return registrationError{msg} + } + + if dc, _ := uc.dequeueCommand(command); dc != nil && downstreamID == 0 { + downstreamID = dc.id + } + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + dc.SendMessage(msg) + }) + case "ACK": + // Ignore + case irc.RPL_NOWAWAY, irc.RPL_UNAWAY: + // Ignore + case irc.RPL_YOURHOST, irc.RPL_CREATED: + // Ignore + case irc.RPL_LUSERCLIENT, irc.RPL_LUSEROP, irc.RPL_LUSERUNKNOWN, irc.RPL_LUSERCHANNELS, irc.RPL_LUSERME: + fallthrough + case irc.RPL_STATSVLINE, rpl_statsping, irc.RPL_STATSBLINE, irc.RPL_STATSDLINE: + fallthrough + case rpl_localusers, rpl_globalusers: + fallthrough + case irc.RPL_MOTDSTART, irc.RPL_MOTD: + // Ignore these messages if they're part of the initial registration + // message burst. Forward them if the user explicitly asked for them. + if !uc.gotMotd { + return nil + } + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + dc.SendMessage(&irc.Message{ + Prefix: uc.srv.prefix(), + Command: msg.Command, + Params: msg.Params, + }) + }) + case irc.RPL_LISTSTART: + // Ignore + case "ERROR": + var text string + if err := parseMessageParams(msg, &text); err != nil { + return err + } + return fmt.Errorf("fatal server error: %v", text) + case irc.ERR_NICKNAMEINUSE: + // At this point, we haven't received ISUPPORT so we don't know the + // maximum nickname length or whether the server supports MONITOR. Many + // servers have NICKLEN=30 so let's just use that. + if !uc.registered && len(uc.nick)+1 < 30 { + uc.nick = uc.nick + "_" + uc.nickCM = uc.network.casemap(uc.nick) + uc.logger.Printf("desired nick is not available, falling back to %q", uc.nick) + uc.SendMessage(ctx, &irc.Message{ + Command: "NICK", + Params: []string{uc.nick}, + }) + return nil + } + fallthrough + case irc.ERR_PASSWDMISMATCH, irc.ERR_ERRONEUSNICKNAME, irc.ERR_NICKCOLLISION, irc.ERR_UNAVAILRESOURCE, irc.ERR_NOPERMFORHOST, irc.ERR_YOUREBANNEDCREEP: + if !uc.registered { + return registrationError{msg} + } + fallthrough + default: + uc.logger.Printf("unhandled message: %v", msg) + + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + // best effort marshaling for unknown messages, replies and errors: + // most numerics start with the user nick, marshal it if that's the case + // otherwise, conservately keep the params without marshaling + params := msg.Params + if _, err := strconv.Atoi(msg.Command); err == nil { // numeric + if len(msg.Params) > 0 && isOurNick(uc.network, msg.Params[0]) { + params[0] = dc.nick + } + } + dc.SendMessage(&irc.Message{ + Prefix: uc.srv.prefix(), + Command: msg.Command, + Params: params, + }) + }) + } + return nil +} + +func (uc *upstreamConn) handleDetachedMessage(ctx context.Context, ch *Channel, msg *irc.Message) { + if uc.network.detachedMessageNeedsRelay(ch, msg) { + uc.forEachDownstream(func(dc *downstreamConn) { + dc.relayDetachedMessage(uc.network, msg) + }) + } + if ch.ReattachOn == FilterMessage || (ch.ReattachOn == FilterHighlight && uc.network.isHighlight(msg)) { + uc.network.attach(ctx, ch) + if err := uc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil { + uc.logger.Printf("failed to update channel %q: %v", ch.Name, err) + } + } +} + +func (uc *upstreamConn) handleChanModes(s string) error { + parts := strings.SplitN(s, ",", 5) + if len(parts) < 4 { + return fmt.Errorf("malformed ISUPPORT CHANMODES value: %v", s) + } + modes := make(map[byte]channelModeType) + for i, mt := range []channelModeType{modeTypeA, modeTypeB, modeTypeC, modeTypeD} { + for j := 0; j < len(parts[i]); j++ { + mode := parts[i][j] + modes[mode] = mt + } + } + uc.availableChannelModes = modes + return nil +} + +func (uc *upstreamConn) handleMemberships(s string) error { + if s == "" { + uc.availableMemberships = nil + return nil + } + + if s[0] != '(' { + return fmt.Errorf("malformed ISUPPORT PREFIX value: %v", s) + } + sep := strings.IndexByte(s, ')') + if sep < 0 || len(s) != sep*2 { + return fmt.Errorf("malformed ISUPPORT PREFIX value: %v", s) + } + memberships := make([]membership, len(s)/2-1) + for i := range memberships { + memberships[i] = membership{ + Mode: s[i+1], + Prefix: s[sep+i+1], + } + } + uc.availableMemberships = memberships + return nil +} + +func (uc *upstreamConn) handleSupportedCaps(capsStr string) { + caps := strings.Fields(capsStr) + for _, s := range caps { + kv := strings.SplitN(s, "=", 2) + k := strings.ToLower(kv[0]) + var v string + if len(kv) == 2 { + v = kv[1] + } + uc.supportedCaps[k] = v + } +} + +func (uc *upstreamConn) requestCaps() { + var requestCaps []string + for c := range permanentUpstreamCaps { + if _, ok := uc.supportedCaps[c]; ok && !uc.caps[c] { + requestCaps = append(requestCaps, c) + } + } + + if len(requestCaps) == 0 { + return + } + + uc.SendMessage(context.TODO(), &irc.Message{ + Command: "CAP", + Params: []string{"REQ", strings.Join(requestCaps, " ")}, + }) +} + +func (uc *upstreamConn) supportsSASL(mech string) bool { + v, ok := uc.supportedCaps["sasl"] + if !ok { + return false + } + + if v == "" { + return true + } + + mechanisms := strings.Split(v, ",") + for _, mech := range mechanisms { + if strings.EqualFold(mech, mech) { + return true + } + } + return false +} + +func (uc *upstreamConn) requestSASL() bool { + if uc.network.SASL.Mechanism == "" { + return false + } + return uc.supportsSASL(uc.network.SASL.Mechanism) +} + +func (uc *upstreamConn) handleCapAck(ctx context.Context, name string, ok bool) error { + uc.caps[name] = ok + + switch name { + case "sasl": + if !uc.requestSASL() { + return nil + } + if !ok { + uc.logger.Printf("server refused to acknowledge the SASL capability") + return nil + } + + auth := &uc.network.SASL + switch auth.Mechanism { + case "PLAIN": + uc.logger.Printf("starting SASL PLAIN authentication with username %q", auth.Plain.Username) + uc.saslClient = sasl.NewPlainClient("", auth.Plain.Username, auth.Plain.Password) + case "EXTERNAL": + uc.logger.Printf("starting SASL EXTERNAL authentication") + uc.saslClient = sasl.NewExternalClient("") + default: + return fmt.Errorf("unsupported SASL mechanism %q", name) + } + + uc.SendMessage(ctx, &irc.Message{ + Command: "AUTHENTICATE", + Params: []string{auth.Mechanism}, + }) + default: + if permanentUpstreamCaps[name] { + break + } + uc.logger.Printf("received CAP ACK/NAK for a cap we don't support: %v", name) + } + return nil +} + +func splitSpace(s string) []string { + return strings.FieldsFunc(s, func(r rune) bool { + return r == ' ' + }) +} + +func (uc *upstreamConn) register(ctx context.Context) { + uc.nick = GetNick(&uc.user.User, &uc.network.Network) + uc.nickCM = uc.network.casemap(uc.nick) + uc.username = GetUsername(&uc.user.User, &uc.network.Network) + uc.realname = GetRealname(&uc.user.User, &uc.network.Network) + + uc.SendMessage(ctx, &irc.Message{ + Command: "CAP", + Params: []string{"LS", "302"}, + }) + + if uc.network.Pass != "" { + uc.SendMessage(ctx, &irc.Message{ + Command: "PASS", + Params: []string{uc.network.Pass}, + }) + } + + uc.SendMessage(ctx, &irc.Message{ + Command: "NICK", + Params: []string{uc.nick}, + }) + uc.SendMessage(ctx, &irc.Message{ + Command: "USER", + Params: []string{uc.username, "0", "*", uc.realname}, + }) +} + +func (uc *upstreamConn) ReadMessage() (*irc.Message, error) { + msg, err := uc.conn.ReadMessage() + if err != nil { + return nil, err + } + return msg, nil +} + +func (uc *upstreamConn) runUntilRegistered(ctx context.Context) error { + for !uc.registered { + msg, err := uc.ReadMessage() + if err != nil { + return fmt.Errorf("failed to read message: %v", err) + } + + if err := uc.handleMessage(ctx, msg); err != nil { + if _, ok := err.(registrationError); ok { + return err + } else { + msg.Tags = nil // prevent message tags from cluttering logs + return fmt.Errorf("failed to handle message %q: %v", msg, err) + } + } + } + + for _, command := range uc.network.ConnectCommands { + m, err := irc.ParseMessage(command) + if err != nil { + uc.logger.Printf("failed to parse connect command %q: %v", command, err) + } else { + uc.SendMessage(ctx, m) + } + } + + return nil +} + +func (uc *upstreamConn) readMessages(ch chan<- event) error { + for { + msg, err := uc.ReadMessage() + if errors.Is(err, io.EOF) { + break + } else if err != nil { + return fmt.Errorf("failed to read IRC command: %v", err) + } + + ch <- eventUpstreamMessage{msg, uc} + } + + return nil +} + +func (uc *upstreamConn) SendMessage(ctx context.Context, msg *irc.Message) { + if !uc.caps["message-tags"] { + msg = msg.Copy() + msg.Tags = nil + } + + uc.conn.SendMessage(ctx, msg) +} + +func (uc *upstreamConn) SendMessageLabeled(ctx context.Context, downstreamID uint64, msg *irc.Message) { + if uc.caps["labeled-response"] { + if msg.Tags == nil { + msg.Tags = make(map[string]irc.TagValue) + } + msg.Tags["label"] = irc.TagValue(fmt.Sprintf("sd-%d-%d", downstreamID, uc.nextLabelID)) + uc.nextLabelID++ + } + uc.SendMessage(ctx, msg) +} + +// appendLog appends a message to the log file. +// +// The internal message ID is returned. If the message isn't recorded in the +// log file, an empty string is returned. +func (uc *upstreamConn) appendLog(entity string, msg *irc.Message) (msgID string) { + if uc.user.msgStore == nil { + return "" + } + + // Don't store messages with a server mask target + if strings.HasPrefix(entity, "$") { + return "" + } + + entityCM := uc.network.casemap(entity) + if entityCM == "nickserv" { + // The messages sent/received from NickServ may contain + // security-related information (like passwords). Don't store these. + return "" + } + + if !uc.network.delivered.HasTarget(entity) { + // This is the first message we receive from this target. Save the last + // message ID in delivery receipts, so that we can send the new message + // in the backlog if an offline client reconnects. + lastID, err := uc.user.msgStore.LastMsgID(&uc.network.Network, entityCM, time.Now()) + if err != nil { + uc.logger.Printf("failed to log message: failed to get last message ID: %v", err) + return "" + } + + uc.network.delivered.ForEachClient(func(clientName string) { + uc.network.delivered.StoreID(entity, clientName, lastID) + }) + } + + msgID, err := uc.user.msgStore.Append(&uc.network.Network, entityCM, msg) + if err != nil { + uc.logger.Printf("failed to append message to store: %v", err) + return "" + } + + return msgID +} + +// produce appends a message to the logs and forwards it to connected downstream +// connections. +// +// If origin is not nil and origin doesn't support echo-message, the message is +// forwarded to all connections except origin. +func (uc *upstreamConn) produce(target string, msg *irc.Message, origin *downstreamConn) { + var msgID string + if target != "" { + msgID = uc.appendLog(target, msg) + } + + // Don't forward messages if it's a detached channel + ch := uc.network.channels.Value(target) + detached := ch != nil && ch.Detached + + uc.forEachDownstream(func(dc *downstreamConn) { + if !detached && (dc != origin || dc.caps["echo-message"]) { + dc.sendMessageWithID(dc.marshalMessage(msg, uc.network), msgID) + } else { + dc.advanceMessageWithID(msg, msgID) + } + }) +} + +func (uc *upstreamConn) updateAway() { + ctx := context.TODO() + + away := true + uc.forEachDownstream(func(*downstreamConn) { + away = false + }) + if away == uc.away { + return + } + if away { + uc.SendMessage(ctx, &irc.Message{ + Command: "AWAY", + Params: []string{"Auto away"}, + }) + } else { + uc.SendMessage(ctx, &irc.Message{ + Command: "AWAY", + }) + } + uc.away = away +} + +func (uc *upstreamConn) updateChannelAutoDetach(name string) { + uch := uc.channels.Value(name) + if uch == nil { + return + } + ch := uc.network.channels.Value(name) + if ch == nil || ch.Detached { + return + } + uch.updateAutoDetach(ch.DetachAfter) +} + +func (uc *upstreamConn) updateMonitor() { + if _, ok := uc.isupport["MONITOR"]; !ok { + return + } + + ctx := context.TODO() + + add := make(map[string]struct{}) + var addList []string + seen := make(map[string]struct{}) + uc.forEachDownstream(func(dc *downstreamConn) { + for targetCM := range dc.monitored.innerMap { + if !uc.monitored.Has(targetCM) { + if _, ok := add[targetCM]; !ok { + addList = append(addList, targetCM) + add[targetCM] = struct{}{} + } + } else { + seen[targetCM] = struct{}{} + } + } + }) + + wantNick := GetNick(&uc.user.User, &uc.network.Network) + wantNickCM := uc.network.casemap(wantNick) + if _, ok := add[wantNickCM]; !ok && !uc.monitored.Has(wantNick) && !uc.isOurNick(wantNick) { + addList = append(addList, wantNickCM) + add[wantNickCM] = struct{}{} + } + + removeAll := true + var removeList []string + for targetCM, entry := range uc.monitored.innerMap { + if _, ok := seen[targetCM]; ok { + removeAll = false + } else { + removeList = append(removeList, entry.originalKey) + } + } + + // TODO: better handle the case where len(uc.monitored) + len(addList) + // exceeds the limit, probably by immediately sending ERR_MONLISTFULL? + + if removeAll && len(addList) == 0 && len(removeList) > 0 { + // Optimization when the last MONITOR-aware downstream disconnects + uc.SendMessage(ctx, &irc.Message{ + Command: "MONITOR", + Params: []string{"C"}, + }) + } else { + msgs := generateMonitor("-", removeList) + msgs = append(msgs, generateMonitor("+", addList)...) + for _, msg := range msgs { + uc.SendMessage(ctx, msg) + } + } + + for _, target := range removeList { + uc.monitored.Delete(target) + } +} diff --git a/trunk/user.go b/trunk/user.go new file mode 100644 index 0000000..0b523ce --- /dev/null +++ b/trunk/user.go @@ -0,0 +1,1068 @@ +package suika + +import ( + "context" + "crypto/sha256" + "encoding/binary" + "encoding/hex" + "fmt" + "math/big" + "net" + "sort" + "strings" + "time" + + "gopkg.in/irc.v3" +) + +type event interface{} + +type eventUpstreamMessage struct { + msg *irc.Message + uc *upstreamConn +} + +type eventUpstreamConnectionError struct { + net *network + err error +} + +type eventUpstreamConnected struct { + uc *upstreamConn +} + +type eventUpstreamDisconnected struct { + uc *upstreamConn +} + +type eventUpstreamError struct { + uc *upstreamConn + err error +} + +type eventDownstreamMessage struct { + msg *irc.Message + dc *downstreamConn +} + +type eventDownstreamConnected struct { + dc *downstreamConn +} + +type eventDownstreamDisconnected struct { + dc *downstreamConn +} + +type eventChannelDetach struct { + uc *upstreamConn + name string +} + +type eventBroadcast struct { + msg *irc.Message +} + +type eventStop struct{} + +type eventUserUpdate struct { + password *string + admin *bool + done chan error +} + +type deliveredClientMap map[string]string // client name -> msg ID + +type deliveredStore struct { + m deliveredCasemapMap +} + +func newDeliveredStore() deliveredStore { + return deliveredStore{deliveredCasemapMap{newCasemapMap(0)}} +} + +func (ds deliveredStore) HasTarget(target string) bool { + return ds.m.Value(target) != nil +} + +func (ds deliveredStore) LoadID(target, clientName string) string { + clients := ds.m.Value(target) + if clients == nil { + return "" + } + return clients[clientName] +} + +func (ds deliveredStore) StoreID(target, clientName, msgID string) { + clients := ds.m.Value(target) + if clients == nil { + clients = make(deliveredClientMap) + ds.m.SetValue(target, clients) + } + clients[clientName] = msgID +} + +func (ds deliveredStore) ForEachTarget(f func(target string)) { + for _, entry := range ds.m.innerMap { + f(entry.originalKey) + } +} + +func (ds deliveredStore) ForEachClient(f func(clientName string)) { + clients := make(map[string]struct{}) + for _, entry := range ds.m.innerMap { + delivered := entry.value.(deliveredClientMap) + for clientName := range delivered { + clients[clientName] = struct{}{} + } + } + + for clientName := range clients { + f(clientName) + } +} + +type network struct { + Network + user *user + logger Logger + stopped chan struct{} + + conn *upstreamConn + channels channelCasemapMap + delivered deliveredStore + lastError error + casemap casemapping +} + +func newNetwork(user *user, record *Network, channels []Channel) *network { + logger := &prefixLogger{user.logger, fmt.Sprintf("network %q: ", record.GetName())} + + m := channelCasemapMap{newCasemapMap(0)} + for _, ch := range channels { + ch := ch + m.SetValue(ch.Name, &ch) + } + + return &network{ + Network: *record, + user: user, + logger: logger, + stopped: make(chan struct{}), + channels: m, + delivered: newDeliveredStore(), + casemap: casemapRFC1459, + } +} + +func (net *network) forEachDownstream(f func(*downstreamConn)) { + net.user.forEachDownstream(func(dc *downstreamConn) { + if dc.network == nil && !dc.isMultiUpstream { + return + } + if dc.network != nil && dc.network != net { + return + } + f(dc) + }) +} + +func (net *network) isStopped() bool { + select { + case <-net.stopped: + return true + default: + return false + } +} + +func userIdent(u *User) string { + // The ident is a string we will send to upstream servers in clear-text. + // For privacy reasons, make sure it doesn't expose any meaningful user + // metadata. We just use the base64-encoded hashed ID, so that people don't + // start relying on the string being an integer or following a pattern. + var b [64]byte + binary.LittleEndian.PutUint64(b[:], uint64(u.ID)) + h := sha256.Sum256(b[:]) + return hex.EncodeToString(h[:16]) +} + +func (net *network) run() { + if !net.Enabled { + return + } + + var lastTry time.Time + backoff := newBackoffer(retryConnectMinDelay, retryConnectMaxDelay, retryConnectJitter) + for { + if net.isStopped() { + return + } + + delay := backoff.Next() - time.Now().Sub(lastTry) + if delay > 0 { + net.logger.Printf("waiting %v before trying to reconnect to %q", delay.Truncate(time.Second), net.Addr) + time.Sleep(delay) + } + lastTry = time.Now() + + + uc, err := connectToUpstream(context.TODO(), net) + if err != nil { + net.logger.Printf("failed to connect to upstream server %q: %v", net.Addr, err) + net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("failed to connect: %v", err)} + continue + } + + uc.register(context.TODO()) + if err := uc.runUntilRegistered(context.TODO()); err != nil { + text := err.Error() + temp := true + if regErr, ok := err.(registrationError); ok { + text = regErr.Reason() + temp = regErr.Temporary() + } + uc.logger.Printf("failed to register: %v", text) + net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("failed to register: %v", text)} + uc.Close() + if !temp { + return + } + continue + } + + // TODO: this is racy with net.stopped. If the network is stopped + // before the user goroutine receives eventUpstreamConnected, the + // connection won't be closed. + net.user.events <- eventUpstreamConnected{uc} + if err := uc.readMessages(net.user.events); err != nil { + uc.logger.Printf("failed to handle messages: %v", err) + net.user.events <- eventUpstreamError{uc, fmt.Errorf("failed to handle messages: %v", err)} + } + uc.Close() + net.user.events <- eventUpstreamDisconnected{uc} + + backoff.Reset() + } +} + +func (net *network) stop() { + if !net.isStopped() { + close(net.stopped) + } + + if net.conn != nil { + net.conn.Close() + } +} + +func (net *network) detach(ch *Channel) { + if ch.Detached { + return + } + + net.logger.Printf("detaching channel %q", ch.Name) + + ch.Detached = true + + if net.user.msgStore != nil { + nameCM := net.casemap(ch.Name) + lastID, err := net.user.msgStore.LastMsgID(&net.Network, nameCM, time.Now()) + if err != nil { + net.logger.Printf("failed to get last message ID for channel %q: %v", ch.Name, err) + } + ch.DetachedInternalMsgID = lastID + } + + if net.conn != nil { + uch := net.conn.channels.Value(ch.Name) + if uch != nil { + uch.updateAutoDetach(0) + } + } + + net.forEachDownstream(func(dc *downstreamConn) { + dc.SendMessage(&irc.Message{ + Prefix: dc.prefix(), + Command: "PART", + Params: []string{dc.marshalEntity(net, ch.Name), "Detach"}, + }) + }) +} + +func (net *network) attach(ctx context.Context, ch *Channel) { + if !ch.Detached { + return + } + + net.logger.Printf("attaching channel %q", ch.Name) + + detachedMsgID := ch.DetachedInternalMsgID + ch.Detached = false + ch.DetachedInternalMsgID = "" + + var uch *upstreamChannel + if net.conn != nil { + uch = net.conn.channels.Value(ch.Name) + + net.conn.updateChannelAutoDetach(ch.Name) + } + + net.forEachDownstream(func(dc *downstreamConn) { + dc.SendMessage(&irc.Message{ + Prefix: dc.prefix(), + Command: "JOIN", + Params: []string{dc.marshalEntity(net, ch.Name)}, + }) + + if uch != nil { + forwardChannel(ctx, dc, uch) + } + + if detachedMsgID != "" { + dc.sendTargetBacklog(ctx, net, ch.Name, detachedMsgID) + } + }) +} + +func (net *network) deleteChannel(ctx context.Context, name string) error { + ch := net.channels.Value(name) + if ch == nil { + return fmt.Errorf("unknown channel %q", name) + } + if net.conn != nil { + uch := net.conn.channels.Value(ch.Name) + if uch != nil { + uch.updateAutoDetach(0) + } + } + + if err := net.user.srv.db.DeleteChannel(ctx, ch.ID); err != nil { + return err + } + net.channels.Delete(name) + return nil +} + +func (net *network) updateCasemapping(newCasemap casemapping) { + net.casemap = newCasemap + net.channels.SetCasemapping(newCasemap) + net.delivered.m.SetCasemapping(newCasemap) + if uc := net.conn; uc != nil { + uc.channels.SetCasemapping(newCasemap) + for _, entry := range uc.channels.innerMap { + uch := entry.value.(*upstreamChannel) + uch.Members.SetCasemapping(newCasemap) + } + uc.monitored.SetCasemapping(newCasemap) + } + net.forEachDownstream(func(dc *downstreamConn) { + dc.monitored.SetCasemapping(newCasemap) + }) +} + +func (net *network) storeClientDeliveryReceipts(ctx context.Context, clientName string) { + if !net.user.hasPersistentMsgStore() { + return + } + + var receipts []DeliveryReceipt + net.delivered.ForEachTarget(func(target string) { + msgID := net.delivered.LoadID(target, clientName) + if msgID == "" { + return + } + receipts = append(receipts, DeliveryReceipt{ + Target: target, + InternalMsgID: msgID, + }) + }) + + if err := net.user.srv.db.StoreClientDeliveryReceipts(ctx, net.ID, clientName, receipts); err != nil { + net.logger.Printf("failed to store delivery receipts for client %q: %v", clientName, err) + } +} + +func (net *network) isHighlight(msg *irc.Message) bool { + if msg.Command != "PRIVMSG" && msg.Command != "NOTICE" { + return false + } + + text := msg.Params[1] + + nick := net.Nick + if net.conn != nil { + nick = net.conn.nick + } + + // TODO: use case-mapping aware comparison here + return msg.Prefix.Name != nick && isHighlight(text, nick) +} + +func (net *network) detachedMessageNeedsRelay(ch *Channel, msg *irc.Message) bool { + highlight := net.isHighlight(msg) + return ch.RelayDetached == FilterMessage || ((ch.RelayDetached == FilterHighlight || ch.RelayDetached == FilterDefault) && highlight) +} + +func (net *network) autoSaveSASLPlain(ctx context.Context, username, password string) { + // User may have e.g. EXTERNAL mechanism configured. We do not want to + // automatically erase the key pair or any other credentials. + if net.SASL.Mechanism != "" && net.SASL.Mechanism != "PLAIN" { + return + } + + net.logger.Printf("auto-saving SASL PLAIN credentials with username %q", username) + net.SASL.Mechanism = "PLAIN" + net.SASL.Plain.Username = username + net.SASL.Plain.Password = password + if err := net.user.srv.db.StoreNetwork(ctx, net.user.ID, &net.Network); err != nil { + net.logger.Printf("failed to save SASL PLAIN credentials: %v", err) + } +} + +type user struct { + User + srv *Server + logger Logger + + events chan event + done chan struct{} + + networks []*network + downstreamConns []*downstreamConn + msgStore messageStore +} + +func newUser(srv *Server, record *User) *user { + logger := &prefixLogger{srv.Logger, fmt.Sprintf("user %q: ", record.Username)} + + var msgStore messageStore + if logPath := srv.Config().LogPath; logPath != "" { + msgStore = newFSMessageStore(logPath, record) + } else { + msgStore = newMemoryMessageStore() + } + + return &user{ + User: *record, + srv: srv, + logger: logger, + events: make(chan event, 64), + done: make(chan struct{}), + msgStore: msgStore, + } +} + +func (u *user) forEachUpstream(f func(uc *upstreamConn)) { + for _, network := range u.networks { + if network.conn == nil { + continue + } + f(network.conn) + } +} + +func (u *user) forEachDownstream(f func(dc *downstreamConn)) { + for _, dc := range u.downstreamConns { + f(dc) + } +} + +func (u *user) getNetwork(name string) *network { + for _, network := range u.networks { + if network.Addr == name { + return network + } + if network.Name != "" && network.Name == name { + return network + } + } + return nil +} + +func (u *user) getNetworkByID(id int64) *network { + for _, net := range u.networks { + if net.ID == id { + return net + } + } + return nil +} + +func (u *user) run() { + defer func() { + if u.msgStore != nil { + if err := u.msgStore.Close(); err != nil { + u.logger.Printf("failed to close message store for user %q: %v", u.Username, err) + } + } + close(u.done) + }() + + networks, err := u.srv.db.ListNetworks(context.TODO(), u.ID) + if err != nil { + u.logger.Printf("failed to list networks for user %q: %v", u.Username, err) + return + } + + sort.Slice(networks, func(i, j int) bool { + return networks[i].ID < networks[j].ID + }) + + for _, record := range networks { + record := record + channels, err := u.srv.db.ListChannels(context.TODO(), record.ID) + if err != nil { + u.logger.Printf("failed to list channels for user %q, network %q: %v", u.Username, record.GetName(), err) + continue + } + + network := newNetwork(u, &record, channels) + u.networks = append(u.networks, network) + + if u.hasPersistentMsgStore() { + receipts, err := u.srv.db.ListDeliveryReceipts(context.TODO(), record.ID) + if err != nil { + u.logger.Printf("failed to load delivery receipts for user %q, network %q: %v", u.Username, network.GetName(), err) + return + } + + for _, rcpt := range receipts { + network.delivered.StoreID(rcpt.Target, rcpt.Client, rcpt.InternalMsgID) + } + } + + go network.run() + } + + for e := range u.events { + switch e := e.(type) { + case eventUpstreamConnected: + uc := e.uc + + uc.network.conn = uc + + uc.updateAway() + uc.updateMonitor() + + netIDStr := fmt.Sprintf("%v", uc.network.ID) + uc.forEachDownstream(func(dc *downstreamConn) { + dc.updateSupportedCaps() + + if !dc.caps["soju.im/bouncer-networks"] { + sendServiceNOTICE(dc, fmt.Sprintf("connected to %s", uc.network.GetName())) + } + + dc.updateNick() + dc.updateRealname() + dc.updateAccount() + }) + u.forEachDownstream(func(dc *downstreamConn) { + if dc.caps["soju.im/bouncer-networks-notify"] { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "BOUNCER", + Params: []string{"NETWORK", netIDStr, "state=connected"}, + }) + } + }) + uc.network.lastError = nil + case eventUpstreamDisconnected: + u.handleUpstreamDisconnected(e.uc) + case eventUpstreamConnectionError: + net := e.net + + stopped := false + select { + case <-net.stopped: + stopped = true + default: + } + + if !stopped && (net.lastError == nil || net.lastError.Error() != e.err.Error()) { + net.forEachDownstream(func(dc *downstreamConn) { + sendServiceNOTICE(dc, fmt.Sprintf("failed connecting/registering to %s: %v", net.GetName(), e.err)) + }) + } + net.lastError = e.err + case eventUpstreamError: + uc := e.uc + + uc.forEachDownstream(func(dc *downstreamConn) { + sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s: %v", uc.network.GetName(), e.err)) + }) + uc.network.lastError = e.err + case eventUpstreamMessage: + msg, uc := e.msg, e.uc + if uc.isClosed() { + uc.logger.Printf("ignoring message on closed connection: %v", msg) + break + } + if err := uc.handleMessage(context.TODO(), msg); err != nil { + uc.logger.Printf("failed to handle message %q: %v", msg, err) + } + case eventChannelDetach: + uc, name := e.uc, e.name + c := uc.network.channels.Value(name) + if c == nil || c.Detached { + continue + } + uc.network.detach(c) + if err := uc.srv.db.StoreChannel(context.TODO(), uc.network.ID, c); err != nil { + u.logger.Printf("failed to store updated detached channel %q: %v", c.Name, err) + } + case eventDownstreamConnected: + dc := e.dc + + if dc.network != nil { + dc.monitored.SetCasemapping(dc.network.casemap) + } + + if err := dc.welcome(context.TODO()); err != nil { + dc.logger.Printf("failed to handle new registered connection: %v", err) + break + } + + u.downstreamConns = append(u.downstreamConns, dc) + + dc.forEachNetwork(func(network *network) { + if network.lastError != nil { + sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s: %v", network.GetName(), network.lastError)) + } + }) + + u.forEachUpstream(func(uc *upstreamConn) { + uc.updateAway() + }) + case eventDownstreamDisconnected: + dc := e.dc + + for i := range u.downstreamConns { + if u.downstreamConns[i] == dc { + u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...) + break + } + } + + dc.forEachNetwork(func(net *network) { + net.storeClientDeliveryReceipts(context.TODO(), dc.clientName) + }) + + u.forEachUpstream(func(uc *upstreamConn) { + uc.cancelPendingCommandsByDownstreamID(dc.id) + uc.updateAway() + uc.updateMonitor() + }) + case eventDownstreamMessage: + msg, dc := e.msg, e.dc + if dc.isClosed() { + dc.logger.Printf("ignoring message on closed connection: %v", msg) + break + } + err := dc.handleMessage(context.TODO(), msg) + if ircErr, ok := err.(ircError); ok { + ircErr.Message.Prefix = dc.srv.prefix() + dc.SendMessage(ircErr.Message) + } else if err != nil { + dc.logger.Printf("failed to handle message %q: %v", msg, err) + dc.Close() + } + case eventBroadcast: + msg := e.msg + u.forEachDownstream(func(dc *downstreamConn) { + dc.SendMessage(msg) + }) + case eventUserUpdate: + // copy the user record because we'll mutate it + record := u.User + + if e.password != nil { + record.Password = *e.password + } + if e.admin != nil { + record.Admin = *e.admin + } + + e.done <- u.updateUser(context.TODO(), &record) + + // If the password was updated, kill all downstream connections to + // force them to re-authenticate with the new credentials. + if e.password != nil { + u.forEachDownstream(func(dc *downstreamConn) { + dc.Close() + }) + } + case eventStop: + u.forEachDownstream(func(dc *downstreamConn) { + dc.Close() + }) + for _, n := range u.networks { + n.stop() + + n.delivered.ForEachClient(func(clientName string) { + n.storeClientDeliveryReceipts(context.TODO(), clientName) + }) + } + return + default: + panic(fmt.Sprintf("received unknown event type: %T", e)) + } + } +} + +func (u *user) handleUpstreamDisconnected(uc *upstreamConn) { + uc.network.conn = nil + + uc.abortPendingCommands() + + for _, entry := range uc.channels.innerMap { + uch := entry.value.(*upstreamChannel) + uch.updateAutoDetach(0) + } + + netIDStr := fmt.Sprintf("%v", uc.network.ID) + uc.forEachDownstream(func(dc *downstreamConn) { + dc.updateSupportedCaps() + }) + + // If the network has been removed, don't send a state change notification + found := false + for _, net := range u.networks { + if net == uc.network { + found = true + break + } + } + if !found { + return + } + + u.forEachDownstream(func(dc *downstreamConn) { + if dc.caps["soju.im/bouncer-networks-notify"] { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "BOUNCER", + Params: []string{"NETWORK", netIDStr, "state=disconnected"}, + }) + } + }) + + if uc.network.lastError == nil { + uc.forEachDownstream(func(dc *downstreamConn) { + if !dc.caps["soju.im/bouncer-networks"] { + sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s", uc.network.GetName())) + } + }) + } +} + +func (u *user) addNetwork(network *network) { + u.networks = append(u.networks, network) + + sort.Slice(u.networks, func(i, j int) bool { + return u.networks[i].ID < u.networks[j].ID + }) + + go network.run() +} + +func (u *user) removeNetwork(network *network) { + network.stop() + + u.forEachDownstream(func(dc *downstreamConn) { + if dc.network != nil && dc.network == network { + dc.Close() + } + }) + + for i, net := range u.networks { + if net == network { + u.networks = append(u.networks[:i], u.networks[i+1:]...) + return + } + } + + panic("tried to remove a non-existing network") +} + +func (u *user) checkNetwork(record *Network) error { + url, err := record.URL() + if err != nil { + return err + } + if url.User != nil { + return fmt.Errorf("%v:// URL must not have username and password information", url.Scheme) + } + if url.RawQuery != "" { + return fmt.Errorf("%v:// URL must not have query values", url.Scheme) + } + if url.Fragment != "" { + return fmt.Errorf("%v:// URL must not have a fragment", url.Scheme) + } + switch url.Scheme { + case "ircs", "irc": + if url.Host == "" { + return fmt.Errorf("%v:// URL must have a host", url.Scheme) + } + if url.Path != "" { + return fmt.Errorf("%v:// URL must not have a path", url.Scheme) + } + case "irc+unix", "unix": + if url.Host != "" { + return fmt.Errorf("%v:// URL must not have a host", url.Scheme) + } + if url.Path == "" { + return fmt.Errorf("%v:// URL must have a path", url.Scheme) + } + default: + return fmt.Errorf("unknown URL scheme %q", url.Scheme) + } + + if record.GetName() == "" { + return fmt.Errorf("network name cannot be empty") + } + if strings.HasPrefix(record.GetName(), "-") { + // Can be mixed up with flags when sending commands to the service + return fmt.Errorf("network name cannot start with a dash character") + } + + for _, net := range u.networks { + if net.GetName() == record.GetName() && net.ID != record.ID { + return fmt.Errorf("a network with the name %q already exists", record.GetName()) + } + } + + return nil +} + +func (u *user) createNetwork(ctx context.Context, record *Network) (*network, error) { + if record.ID != 0 { + panic("tried creating an already-existing network") + } + + if err := u.checkNetwork(record); err != nil { + return nil, err + } + + if max := u.srv.Config().MaxUserNetworks; max >= 0 && len(u.networks) >= max { + return nil, fmt.Errorf("maximum number of networks reached") + } + + network := newNetwork(u, record, nil) + err := u.srv.db.StoreNetwork(ctx, u.ID, &network.Network) + if err != nil { + return nil, err + } + + u.addNetwork(network) + + idStr := fmt.Sprintf("%v", network.ID) + attrs := getNetworkAttrs(network) + u.forEachDownstream(func(dc *downstreamConn) { + if dc.caps["soju.im/bouncer-networks-notify"] { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "BOUNCER", + Params: []string{"NETWORK", idStr, attrs.String()}, + }) + } + }) + + return network, nil +} + +func (u *user) updateNetwork(ctx context.Context, record *Network) (*network, error) { + if record.ID == 0 { + panic("tried updating a new network") + } + + // If the realname is reset to the default, just wipe the per-network + // setting + if record.Realname == u.Realname { + record.Realname = "" + } + + if err := u.checkNetwork(record); err != nil { + return nil, err + } + + network := u.getNetworkByID(record.ID) + if network == nil { + panic("tried updating a non-existing network") + } + + if err := u.srv.db.StoreNetwork(ctx, u.ID, record); err != nil { + return nil, err + } + + // Most network changes require us to re-connect to the upstream server + + channels := make([]Channel, 0, network.channels.Len()) + for _, entry := range network.channels.innerMap { + ch := entry.value.(*Channel) + channels = append(channels, *ch) + } + + updatedNetwork := newNetwork(u, record, channels) + + // If we're currently connected, disconnect and perform the necessary + // bookkeeping + if network.conn != nil { + network.stop() + // Note: this will set network.conn to nil + u.handleUpstreamDisconnected(network.conn) + } + + // Patch downstream connections to use our fresh updated network + u.forEachDownstream(func(dc *downstreamConn) { + if dc.network != nil && dc.network == network { + dc.network = updatedNetwork + } + }) + + // We need to remove the network after patching downstream connections, + // otherwise they'll get closed + u.removeNetwork(network) + + // The filesystem message store needs to be notified whenever the network + // is renamed + fsMsgStore, isFS := u.msgStore.(*fsMessageStore) + if isFS && updatedNetwork.GetName() != network.GetName() { + if err := fsMsgStore.RenameNetwork(&network.Network, &updatedNetwork.Network); err != nil { + network.logger.Printf("failed to update FS message store network name to %q: %v", updatedNetwork.GetName(), err) + } + } + + // This will re-connect to the upstream server + u.addNetwork(updatedNetwork) + + // TODO: only broadcast attributes that have changed + idStr := fmt.Sprintf("%v", updatedNetwork.ID) + attrs := getNetworkAttrs(updatedNetwork) + u.forEachDownstream(func(dc *downstreamConn) { + if dc.caps["soju.im/bouncer-networks-notify"] { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "BOUNCER", + Params: []string{"NETWORK", idStr, attrs.String()}, + }) + } + }) + + return updatedNetwork, nil +} + +func (u *user) deleteNetwork(ctx context.Context, id int64) error { + network := u.getNetworkByID(id) + if network == nil { + panic("tried deleting a non-existing network") + } + + if err := u.srv.db.DeleteNetwork(ctx, network.ID); err != nil { + return err + } + + u.removeNetwork(network) + + idStr := fmt.Sprintf("%v", network.ID) + u.forEachDownstream(func(dc *downstreamConn) { + if dc.caps["soju.im/bouncer-networks-notify"] { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "BOUNCER", + Params: []string{"NETWORK", idStr, "*"}, + }) + } + }) + + return nil +} + +func (u *user) updateUser(ctx context.Context, record *User) error { + if u.ID != record.ID { + panic("ID mismatch when updating user") + } + + realnameUpdated := u.Realname != record.Realname + if err := u.srv.db.StoreUser(ctx, record); err != nil { + return fmt.Errorf("failed to update user %q: %v", u.Username, err) + } + u.User = *record + + if realnameUpdated { + // Re-connect to networks which use the default realname + var needUpdate []Network + for _, net := range u.networks { + if net.Realname == "" { + needUpdate = append(needUpdate, net.Network) + } + } + + var netErr error + for _, net := range needUpdate { + if _, err := u.updateNetwork(ctx, &net); err != nil { + netErr = err + } + } + if netErr != nil { + return netErr + } + } + + return nil +} + +func (u *user) stop() { + u.events <- eventStop{} + <-u.done +} + +func (u *user) hasPersistentMsgStore() bool { + if u.msgStore == nil { + return false + } + _, isMem := u.msgStore.(*memoryMessageStore) + return !isMem +} + +// localAddrForHost returns the local address to use when connecting to host. +// A nil address is returned when the OS should automatically pick one. +func (u *user) localTCPAddrForHost(ctx context.Context, host string) (*net.TCPAddr, error) { + upstreamUserIPs := u.srv.Config().UpstreamUserIPs + if len(upstreamUserIPs) == 0 { + return nil, nil + } + + ips, err := net.DefaultResolver.LookupIP(ctx, "ip", host) + if err != nil { + return nil, err + } + + wantIPv6 := false + for _, ip := range ips { + if ip.To4() == nil { + wantIPv6 = true + break + } + } + + var ipNet *net.IPNet + for _, in := range upstreamUserIPs { + if wantIPv6 == (in.IP.To4() == nil) { + ipNet = in + break + } + } + if ipNet == nil { + return nil, nil + } + + var ipInt big.Int + ipInt.SetBytes(ipNet.IP) + ipInt.Add(&ipInt, big.NewInt(u.ID+1)) + ip := net.IP(ipInt.Bytes()) + if !ipNet.Contains(ip) { + return nil, fmt.Errorf("IP network %v too small", ipNet) + } + + return &net.TCPAddr{IP: ip}, nil +} diff --git a/trunk/version.go b/trunk/version.go new file mode 100644 index 0000000..0e5d7a6 --- /dev/null +++ b/trunk/version.go @@ -0,0 +1,50 @@ +package suika + +import ( + "fmt" + "runtime/debug" + "strings" +) + +const ( + defaultVersion = "0.0.0" + defaultCommit = "HEAD" + defaultBuild = "0000-01-01:00:00+00:00" +) + +var ( + // Version is the tagged release version in the form .. + // following semantic versioning and is overwritten by the build system. + Version = defaultVersion + + // Commit is the commit sha of the build (normally from Git) and is overwritten + // by the build system. + Commit = defaultCommit + + // Build is the date and time of the build as an RFC3339 formatted string + // and is overwritten by the build system. + Build = defaultBuild +) + +// FullVersion display the full version and build +func FullVersion() string { + var sb strings.Builder + + isDefault := Version == defaultVersion && Commit == defaultCommit && Build == defaultBuild + + if !isDefault { + sb.WriteString(fmt.Sprintf("%s@%s %s", Version, Commit, Build)) + } + + if info, ok := debug.ReadBuildInfo(); ok { + if isDefault { + sb.WriteString(fmt.Sprintf(" %s", info.Main.Version)) + } + sb.WriteString(fmt.Sprintf(" %s", info.GoVersion)) + if info.Main.Sum != "" { + sb.WriteString(fmt.Sprintf(" %s", info.Main.Sum)) + } + } + + return sb.String() +}