--- /dev/null
+vendor
+/suika
+/suikadb
+/suika-znc-import
+/suika.db
--- /dev/null
+ GNU AFFERO GENERAL PUBLIC LICENSE
+ Version 3, 19 November 2007
+
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+ Preamble
+
+ The GNU Affero General Public License is a free, copyleft license for
+software and other kinds of works, specifically designed to ensure
+cooperation with the community in the case of network server software.
+
+ The licenses for most software and other practical works are designed
+to take away your freedom to share and change the works. By contrast,
+our General Public Licenses are intended to guarantee your freedom to
+share and change all versions of a program--to make sure it remains free
+software for all its users.
+
+ When we speak of free software, we are referring to freedom, not
+price. Our General Public Licenses are designed to make sure that you
+have the freedom to distribute copies of free software (and charge for
+them if you wish), that you receive source code or can get it if you
+want it, that you can change the software or use pieces of it in new
+free programs, and that you know you can do these things.
+
+ Developers that use our General Public Licenses protect your rights
+with two steps: (1) assert copyright on the software, and (2) offer
+you this License which gives you legal permission to copy, distribute
+and/or modify the software.
+
+ A secondary benefit of defending all users' freedom is that
+improvements made in alternate versions of the program, if they
+receive widespread use, become available for other developers to
+incorporate. Many developers of free software are heartened and
+encouraged by the resulting cooperation. However, in the case of
+software used on network servers, this result may fail to come about.
+The GNU General Public License permits making a modified version and
+letting the public access it on a server without ever releasing its
+source code to the public.
+
+ The GNU Affero General Public License is designed specifically to
+ensure that, in such cases, the modified source code becomes available
+to the community. It requires the operator of a network server to
+provide the source code of the modified version running there to the
+users of that server. Therefore, public use of a modified version, on
+a publicly accessible server, gives the public access to the source
+code of the modified version.
+
+ An older license, called the Affero General Public License and
+published by Affero, was designed to accomplish similar goals. This is
+a different license, not a version of the Affero GPL, but Affero has
+released a new version of the Affero GPL which permits relicensing under
+this license.
+
+ The precise terms and conditions for copying, distribution and
+modification follow.
+
+ TERMS AND CONDITIONS
+
+ 0. Definitions.
+
+ "This License" refers to version 3 of the GNU Affero General Public License.
+
+ "Copyright" also means copyright-like laws that apply to other kinds of
+works, such as semiconductor masks.
+
+ "The Program" refers to any copyrightable work licensed under this
+License. Each licensee is addressed as "you". "Licensees" and
+"recipients" may be individuals or organizations.
+
+ To "modify" a work means to copy from or adapt all or part of the work
+in a fashion requiring copyright permission, other than the making of an
+exact copy. The resulting work is called a "modified version" of the
+earlier work or a work "based on" the earlier work.
+
+ A "covered work" means either the unmodified Program or a work based
+on the Program.
+
+ To "propagate" a work means to do anything with it that, without
+permission, would make you directly or secondarily liable for
+infringement under applicable copyright law, except executing it on a
+computer or modifying a private copy. Propagation includes copying,
+distribution (with or without modification), making available to the
+public, and in some countries other activities as well.
+
+ To "convey" a work means any kind of propagation that enables other
+parties to make or receive copies. Mere interaction with a user through
+a computer network, with no transfer of a copy, is not conveying.
+
+ An interactive user interface displays "Appropriate Legal Notices"
+to the extent that it includes a convenient and prominently visible
+feature that (1) displays an appropriate copyright notice, and (2)
+tells the user that there is no warranty for the work (except to the
+extent that warranties are provided), that licensees may convey the
+work under this License, and how to view a copy of this License. If
+the interface presents a list of user commands or options, such as a
+menu, a prominent item in the list meets this criterion.
+
+ 1. Source Code.
+
+ The "source code" for a work means the preferred form of the work
+for making modifications to it. "Object code" means any non-source
+form of a work.
+
+ A "Standard Interface" means an interface that either is an official
+standard defined by a recognized standards body, or, in the case of
+interfaces specified for a particular programming language, one that
+is widely used among developers working in that language.
+
+ The "System Libraries" of an executable work include anything, other
+than the work as a whole, that (a) is included in the normal form of
+packaging a Major Component, but which is not part of that Major
+Component, and (b) serves only to enable use of the work with that
+Major Component, or to implement a Standard Interface for which an
+implementation is available to the public in source code form. A
+"Major Component", in this context, means a major essential component
+(kernel, window system, and so on) of the specific operating system
+(if any) on which the executable work runs, or a compiler used to
+produce the work, or an object code interpreter used to run it.
+
+ The "Corresponding Source" for a work in object code form means all
+the source code needed to generate, install, and (for an executable
+work) run the object code and to modify the work, including scripts to
+control those activities. However, it does not include the work's
+System Libraries, or general-purpose tools or generally available free
+programs which are used unmodified in performing those activities but
+which are not part of the work. For example, Corresponding Source
+includes interface definition files associated with source files for
+the work, and the source code for shared libraries and dynamically
+linked subprograms that the work is specifically designed to require,
+such as by intimate data communication or control flow between those
+subprograms and other parts of the work.
+
+ The Corresponding Source need not include anything that users
+can regenerate automatically from other parts of the Corresponding
+Source.
+
+ The Corresponding Source for a work in source code form is that
+same work.
+
+ 2. Basic Permissions.
+
+ All rights granted under this License are granted for the term of
+copyright on the Program, and are irrevocable provided the stated
+conditions are met. This License explicitly affirms your unlimited
+permission to run the unmodified Program. The output from running a
+covered work is covered by this License only if the output, given its
+content, constitutes a covered work. This License acknowledges your
+rights of fair use or other equivalent, as provided by copyright law.
+
+ You may make, run and propagate covered works that you do not
+convey, without conditions so long as your license otherwise remains
+in force. You may convey covered works to others for the sole purpose
+of having them make modifications exclusively for you, or provide you
+with facilities for running those works, provided that you comply with
+the terms of this License in conveying all material for which you do
+not control copyright. Those thus making or running the covered works
+for you must do so exclusively on your behalf, under your direction
+and control, on terms that prohibit them from making any copies of
+your copyrighted material outside their relationship with you.
+
+ Conveying under any other circumstances is permitted solely under
+the conditions stated below. Sublicensing is not allowed; section 10
+makes it unnecessary.
+
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
+
+ No covered work shall be deemed part of an effective technological
+measure under any applicable law fulfilling obligations under article
+11 of the WIPO copyright treaty adopted on 20 December 1996, or
+similar laws prohibiting or restricting circumvention of such
+measures.
+
+ When you convey a covered work, you waive any legal power to forbid
+circumvention of technological measures to the extent such circumvention
+is effected by exercising rights under this License with respect to
+the covered work, and you disclaim any intention to limit operation or
+modification of the work as a means of enforcing, against the work's
+users, your or third parties' legal rights to forbid circumvention of
+technological measures.
+
+ 4. Conveying Verbatim Copies.
+
+ You may convey verbatim copies of the Program's source code as you
+receive it, in any medium, provided that you conspicuously and
+appropriately publish on each copy an appropriate copyright notice;
+keep intact all notices stating that this License and any
+non-permissive terms added in accord with section 7 apply to the code;
+keep intact all notices of the absence of any warranty; and give all
+recipients a copy of this License along with the Program.
+
+ You may charge any price or no price for each copy that you convey,
+and you may offer support or warranty protection for a fee.
+
+ 5. Conveying Modified Source Versions.
+
+ You may convey a work based on the Program, or the modifications to
+produce it from the Program, in the form of source code under the
+terms of section 4, provided that you also meet all of these conditions:
+
+ a) The work must carry prominent notices stating that you modified
+ it, and giving a relevant date.
+
+ b) The work must carry prominent notices stating that it is
+ released under this License and any conditions added under section
+ 7. This requirement modifies the requirement in section 4 to
+ "keep intact all notices".
+
+ c) You must license the entire work, as a whole, under this
+ License to anyone who comes into possession of a copy. This
+ License will therefore apply, along with any applicable section 7
+ additional terms, to the whole of the work, and all its parts,
+ regardless of how they are packaged. This License gives no
+ permission to license the work in any other way, but it does not
+ invalidate such permission if you have separately received it.
+
+ d) If the work has interactive user interfaces, each must display
+ Appropriate Legal Notices; however, if the Program has interactive
+ interfaces that do not display Appropriate Legal Notices, your
+ work need not make them do so.
+
+ A compilation of a covered work with other separate and independent
+works, which are not by their nature extensions of the covered work,
+and which are not combined with it such as to form a larger program,
+in or on a volume of a storage or distribution medium, is called an
+"aggregate" if the compilation and its resulting copyright are not
+used to limit the access or legal rights of the compilation's users
+beyond what the individual works permit. Inclusion of a covered work
+in an aggregate does not cause this License to apply to the other
+parts of the aggregate.
+
+ 6. Conveying Non-Source Forms.
+
+ You may convey a covered work in object code form under the terms
+of sections 4 and 5, provided that you also convey the
+machine-readable Corresponding Source under the terms of this License,
+in one of these ways:
+
+ a) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by the
+ Corresponding Source fixed on a durable physical medium
+ customarily used for software interchange.
+
+ b) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by a
+ written offer, valid for at least three years and valid for as
+ long as you offer spare parts or customer support for that product
+ model, to give anyone who possesses the object code either (1) a
+ copy of the Corresponding Source for all the software in the
+ product that is covered by this License, on a durable physical
+ medium customarily used for software interchange, for a price no
+ more than your reasonable cost of physically performing this
+ conveying of source, or (2) access to copy the
+ Corresponding Source from a network server at no charge.
+
+ c) Convey individual copies of the object code with a copy of the
+ written offer to provide the Corresponding Source. This
+ alternative is allowed only occasionally and noncommercially, and
+ only if you received the object code with such an offer, in accord
+ with subsection 6b.
+
+ d) Convey the object code by offering access from a designated
+ place (gratis or for a charge), and offer equivalent access to the
+ Corresponding Source in the same way through the same place at no
+ further charge. You need not require recipients to copy the
+ Corresponding Source along with the object code. If the place to
+ copy the object code is a network server, the Corresponding Source
+ may be on a different server (operated by you or a third party)
+ that supports equivalent copying facilities, provided you maintain
+ clear directions next to the object code saying where to find the
+ Corresponding Source. Regardless of what server hosts the
+ Corresponding Source, you remain obligated to ensure that it is
+ available for as long as needed to satisfy these requirements.
+
+ e) Convey the object code using peer-to-peer transmission, provided
+ you inform other peers where the object code and Corresponding
+ Source of the work are being offered to the general public at no
+ charge under subsection 6d.
+
+ A separable portion of the object code, whose source code is excluded
+from the Corresponding Source as a System Library, need not be
+included in conveying the object code work.
+
+ A "User Product" is either (1) a "consumer product", which means any
+tangible personal property which is normally used for personal, family,
+or household purposes, or (2) anything designed or sold for incorporation
+into a dwelling. In determining whether a product is a consumer product,
+doubtful cases shall be resolved in favor of coverage. For a particular
+product received by a particular user, "normally used" refers to a
+typical or common use of that class of product, regardless of the status
+of the particular user or of the way in which the particular user
+actually uses, or expects or is expected to use, the product. A product
+is a consumer product regardless of whether the product has substantial
+commercial, industrial or non-consumer uses, unless such uses represent
+the only significant mode of use of the product.
+
+ "Installation Information" for a User Product means any methods,
+procedures, authorization keys, or other information required to install
+and execute modified versions of a covered work in that User Product from
+a modified version of its Corresponding Source. The information must
+suffice to ensure that the continued functioning of the modified object
+code is in no case prevented or interfered with solely because
+modification has been made.
+
+ If you convey an object code work under this section in, or with, or
+specifically for use in, a User Product, and the conveying occurs as
+part of a transaction in which the right of possession and use of the
+User Product is transferred to the recipient in perpetuity or for a
+fixed term (regardless of how the transaction is characterized), the
+Corresponding Source conveyed under this section must be accompanied
+by the Installation Information. But this requirement does not apply
+if neither you nor any third party retains the ability to install
+modified object code on the User Product (for example, the work has
+been installed in ROM).
+
+ The requirement to provide Installation Information does not include a
+requirement to continue to provide support service, warranty, or updates
+for a work that has been modified or installed by the recipient, or for
+the User Product in which it has been modified or installed. Access to a
+network may be denied when the modification itself materially and
+adversely affects the operation of the network or violates the rules and
+protocols for communication across the network.
+
+ Corresponding Source conveyed, and Installation Information provided,
+in accord with this section must be in a format that is publicly
+documented (and with an implementation available to the public in
+source code form), and must require no special password or key for
+unpacking, reading or copying.
+
+ 7. Additional Terms.
+
+ "Additional permissions" are terms that supplement the terms of this
+License by making exceptions from one or more of its conditions.
+Additional permissions that are applicable to the entire Program shall
+be treated as though they were included in this License, to the extent
+that they are valid under applicable law. If additional permissions
+apply only to part of the Program, that part may be used separately
+under those permissions, but the entire Program remains governed by
+this License without regard to the additional permissions.
+
+ When you convey a copy of a covered work, you may at your option
+remove any additional permissions from that copy, or from any part of
+it. (Additional permissions may be written to require their own
+removal in certain cases when you modify the work.) You may place
+additional permissions on material, added by you to a covered work,
+for which you have or can give appropriate copyright permission.
+
+ Notwithstanding any other provision of this License, for material you
+add to a covered work, you may (if authorized by the copyright holders of
+that material) supplement the terms of this License with terms:
+
+ a) Disclaiming warranty or limiting liability differently from the
+ terms of sections 15 and 16 of this License; or
+
+ b) Requiring preservation of specified reasonable legal notices or
+ author attributions in that material or in the Appropriate Legal
+ Notices displayed by works containing it; or
+
+ c) Prohibiting misrepresentation of the origin of that material, or
+ requiring that modified versions of such material be marked in
+ reasonable ways as different from the original version; or
+
+ d) Limiting the use for publicity purposes of names of licensors or
+ authors of the material; or
+
+ e) Declining to grant rights under trademark law for use of some
+ trade names, trademarks, or service marks; or
+
+ f) Requiring indemnification of licensors and authors of that
+ material by anyone who conveys the material (or modified versions of
+ it) with contractual assumptions of liability to the recipient, for
+ any liability that these contractual assumptions directly impose on
+ those licensors and authors.
+
+ All other non-permissive additional terms are considered "further
+restrictions" within the meaning of section 10. If the Program as you
+received it, or any part of it, contains a notice stating that it is
+governed by this License along with a term that is a further
+restriction, you may remove that term. If a license document contains
+a further restriction but permits relicensing or conveying under this
+License, you may add to a covered work material governed by the terms
+of that license document, provided that the further restriction does
+not survive such relicensing or conveying.
+
+ If you add terms to a covered work in accord with this section, you
+must place, in the relevant source files, a statement of the
+additional terms that apply to those files, or a notice indicating
+where to find the applicable terms.
+
+ Additional terms, permissive or non-permissive, may be stated in the
+form of a separately written license, or stated as exceptions;
+the above requirements apply either way.
+
+ 8. Termination.
+
+ You may not propagate or modify a covered work except as expressly
+provided under this License. Any attempt otherwise to propagate or
+modify it is void, and will automatically terminate your rights under
+this License (including any patent licenses granted under the third
+paragraph of section 11).
+
+ However, if you cease all violation of this License, then your
+license from a particular copyright holder is reinstated (a)
+provisionally, unless and until the copyright holder explicitly and
+finally terminates your license, and (b) permanently, if the copyright
+holder fails to notify you of the violation by some reasonable means
+prior to 60 days after the cessation.
+
+ Moreover, your license from a particular copyright holder is
+reinstated permanently if the copyright holder notifies you of the
+violation by some reasonable means, this is the first time you have
+received notice of violation of this License (for any work) from that
+copyright holder, and you cure the violation prior to 30 days after
+your receipt of the notice.
+
+ Termination of your rights under this section does not terminate the
+licenses of parties who have received copies or rights from you under
+this License. If your rights have been terminated and not permanently
+reinstated, you do not qualify to receive new licenses for the same
+material under section 10.
+
+ 9. Acceptance Not Required for Having Copies.
+
+ You are not required to accept this License in order to receive or
+run a copy of the Program. Ancillary propagation of a covered work
+occurring solely as a consequence of using peer-to-peer transmission
+to receive a copy likewise does not require acceptance. However,
+nothing other than this License grants you permission to propagate or
+modify any covered work. These actions infringe copyright if you do
+not accept this License. Therefore, by modifying or propagating a
+covered work, you indicate your acceptance of this License to do so.
+
+ 10. Automatic Licensing of Downstream Recipients.
+
+ Each time you convey a covered work, the recipient automatically
+receives a license from the original licensors, to run, modify and
+propagate that work, subject to this License. You are not responsible
+for enforcing compliance by third parties with this License.
+
+ An "entity transaction" is a transaction transferring control of an
+organization, or substantially all assets of one, or subdividing an
+organization, or merging organizations. If propagation of a covered
+work results from an entity transaction, each party to that
+transaction who receives a copy of the work also receives whatever
+licenses to the work the party's predecessor in interest had or could
+give under the previous paragraph, plus a right to possession of the
+Corresponding Source of the work from the predecessor in interest, if
+the predecessor has it or can get it with reasonable efforts.
+
+ You may not impose any further restrictions on the exercise of the
+rights granted or affirmed under this License. For example, you may
+not impose a license fee, royalty, or other charge for exercise of
+rights granted under this License, and you may not initiate litigation
+(including a cross-claim or counterclaim in a lawsuit) alleging that
+any patent claim is infringed by making, using, selling, offering for
+sale, or importing the Program or any portion of it.
+
+ 11. Patents.
+
+ A "contributor" is a copyright holder who authorizes use under this
+License of the Program or a work on which the Program is based. The
+work thus licensed is called the contributor's "contributor version".
+
+ A contributor's "essential patent claims" are all patent claims
+owned or controlled by the contributor, whether already acquired or
+hereafter acquired, that would be infringed by some manner, permitted
+by this License, of making, using, or selling its contributor version,
+but do not include claims that would be infringed only as a
+consequence of further modification of the contributor version. For
+purposes of this definition, "control" includes the right to grant
+patent sublicenses in a manner consistent with the requirements of
+this License.
+
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
+patent license under the contributor's essential patent claims, to
+make, use, sell, offer for sale, import and otherwise run, modify and
+propagate the contents of its contributor version.
+
+ In the following three paragraphs, a "patent license" is any express
+agreement or commitment, however denominated, not to enforce a patent
+(such as an express permission to practice a patent or covenant not to
+sue for patent infringement). To "grant" such a patent license to a
+party means to make such an agreement or commitment not to enforce a
+patent against the party.
+
+ If you convey a covered work, knowingly relying on a patent license,
+and the Corresponding Source of the work is not available for anyone
+to copy, free of charge and under the terms of this License, through a
+publicly available network server or other readily accessible means,
+then you must either (1) cause the Corresponding Source to be so
+available, or (2) arrange to deprive yourself of the benefit of the
+patent license for this particular work, or (3) arrange, in a manner
+consistent with the requirements of this License, to extend the patent
+license to downstream recipients. "Knowingly relying" means you have
+actual knowledge that, but for the patent license, your conveying the
+covered work in a country, or your recipient's use of the covered work
+in a country, would infringe one or more identifiable patents in that
+country that you have reason to believe are valid.
+
+ If, pursuant to or in connection with a single transaction or
+arrangement, you convey, or propagate by procuring conveyance of, a
+covered work, and grant a patent license to some of the parties
+receiving the covered work authorizing them to use, propagate, modify
+or convey a specific copy of the covered work, then the patent license
+you grant is automatically extended to all recipients of the covered
+work and works based on it.
+
+ A patent license is "discriminatory" if it does not include within
+the scope of its coverage, prohibits the exercise of, or is
+conditioned on the non-exercise of one or more of the rights that are
+specifically granted under this License. You may not convey a covered
+work if you are a party to an arrangement with a third party that is
+in the business of distributing software, under which you make payment
+to the third party based on the extent of your activity of conveying
+the work, and under which the third party grants, to any of the
+parties who would receive the covered work from you, a discriminatory
+patent license (a) in connection with copies of the covered work
+conveyed by you (or copies made from those copies), or (b) primarily
+for and in connection with specific products or compilations that
+contain the covered work, unless you entered into that arrangement,
+or that patent license was granted, prior to 28 March 2007.
+
+ Nothing in this License shall be construed as excluding or limiting
+any implied license or other defenses to infringement that may
+otherwise be available to you under applicable patent law.
+
+ 12. No Surrender of Others' Freedom.
+
+ If conditions are imposed on you (whether by court order, agreement or
+otherwise) that contradict the conditions of this License, they do not
+excuse you from the conditions of this License. If you cannot convey a
+covered work so as to satisfy simultaneously your obligations under this
+License and any other pertinent obligations, then as a consequence you may
+not convey it at all. For example, if you agree to terms that obligate you
+to collect a royalty for further conveying from those to whom you convey
+the Program, the only way you could satisfy both those terms and this
+License would be to refrain entirely from conveying the Program.
+
+ 13. Remote Network Interaction; Use with the GNU General Public License.
+
+ Notwithstanding any other provision of this License, if you modify the
+Program, your modified version must prominently offer all users
+interacting with it remotely through a computer network (if your version
+supports such interaction) an opportunity to receive the Corresponding
+Source of your version by providing access to the Corresponding Source
+from a network server at no charge, through some standard or customary
+means of facilitating copying of software. This Corresponding Source
+shall include the Corresponding Source for any work covered by version 3
+of the GNU General Public License that is incorporated pursuant to the
+following paragraph.
+
+ Notwithstanding any other provision of this License, you have
+permission to link or combine any covered work with a work licensed
+under version 3 of the GNU General Public License into a single
+combined work, and to convey the resulting work. The terms of this
+License will continue to apply to the part which is the covered work,
+but the work with which it is combined will remain governed by version
+3 of the GNU General Public License.
+
+ 14. Revised Versions of this License.
+
+ The Free Software Foundation may publish revised and/or new versions of
+the GNU Affero General Public License from time to time. Such new versions
+will be similar in spirit to the present version, but may differ in detail to
+address new problems or concerns.
+
+ Each version is given a distinguishing version number. If the
+Program specifies that a certain numbered version of the GNU Affero General
+Public License "or any later version" applies to it, you have the
+option of following the terms and conditions either of that numbered
+version or of any later version published by the Free Software
+Foundation. If the Program does not specify a version number of the
+GNU Affero General Public License, you may choose any version ever published
+by the Free Software Foundation.
+
+ If the Program specifies that a proxy can decide which future
+versions of the GNU Affero General Public License can be used, that proxy's
+public statement of acceptance of a version permanently authorizes you
+to choose that version for the Program.
+
+ Later license versions may give you additional or different
+permissions. However, no additional obligations are imposed on any
+author or copyright holder as a result of your choosing to follow a
+later version.
+
+ 15. Disclaimer of Warranty.
+
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
+APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
+HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
+OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
+THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
+IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
+ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
+
+ 16. Limitation of Liability.
+
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
+WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
+THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
+GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
+USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
+DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
+PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
+EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
+SUCH DAMAGES.
+
+ 17. Interpretation of Sections 15 and 16.
+
+ If the disclaimer of warranty and limitation of liability provided
+above cannot be given local legal effect according to their terms,
+reviewing courts shall apply local law that most closely approximates
+an absolute waiver of all civil liability in connection with the
+Program, unless a warranty or assumption of liability accompanies a
+copy of the Program in return for a fee.
+
+ END OF TERMS AND CONDITIONS
+
+ How to Apply These Terms to Your New Programs
+
+ If you develop a new program, and you want it to be of the greatest
+possible use to the public, the best way to achieve this is to make it
+free software which everyone can redistribute and change under these terms.
+
+ To do so, attach the following notices to the program. It is safest
+to attach them to the start of each source file to most effectively
+state the exclusion of warranty; and each file should have at least
+the "copyright" line and a pointer to where the full notice is found.
+
+ <one line to give the program's name and a brief idea of what it does.>
+ Copyright (C) <year> <name of author>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published
+ by the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+Also add information on how to contact you by electronic and paper mail.
+
+ If your software can interact with users remotely through a computer
+network, you should also make sure that it provides a way for users to
+get its source. For example, if your program is a web application, its
+interface could display a "Source" link that leads users to an archive
+of the code. There are many ways you could offer source, and different
+solutions will be better for different programs; see section 13 for the
+specific requirements.
+
+ You should also get your employer (if you work as a programmer) or school,
+if any, to sign a "copyright disclaimer" for the program, if necessary.
+For more information on this, and how to apply and follow the GNU AGPL, see
+<https://www.gnu.org/licenses/>.
--- /dev/null
+GO ?= go
+RM ?= rm
+GOFLAGS ?= -v -ldflags "-w -X `go list`.Version=${VERSION} -X `go list`.Commit=${COMMIT} -X `go list`.Build=${BUILD}" -mod=vendor
+PREFIX ?= /usr/local
+BINDIR ?= bin
+MANDIR ?= share/man
+MKDIR ?= mkdir
+CP ?= cp
+SYSCONFDIR ?= /etc
+ASCIIDOCTOR ?= asciidoctor
+
+VERSION = `git describe --abbrev=0 --tags 2>/dev/null || echo "$VERSION"`
+COMMIT = `git rev-parse --short HEAD || echo "$COMMIT"`
+BRANCH = `git rev-parse --abbrev-ref HEAD`
+BUILD = `git show -s --pretty=format:%cI`
+
+GOARCH ?= amd64
+GOOS ?= linux
+
+all: build
+
+build: vendor
+ ${GO} build ${GOFLAGS} ./cmd/suika
+ ${GO} build ${GOFLAGS} ./cmd/suikadb
+ ${GO} build ${GOFLAGS} ./cmd/suika-znc-import
+clean:
+ ${RM} -f suika suikadb suika-znc-import
+install:
+ ${MKDIR} -p ${DESTDIR}${PREFIX}/${BINDIR}
+ ${MKDIR} -p ${DESTDIR}${PREFIX}/${MANDIR}/man1
+ ${MKDIR} -p ${DESTDIR}${PREFIX}/${MANDIR}/man5
+ ${MKDIR} -p ${DESTDIR}${PREFIX}/${MANDIR}/man7
+ ${MKDIR} -p ${DESTDIR}${SYSCONFDIR}/suika
+ ${MKDIR} -p ${DESTDIR}/var/lib/suika
+ ${CP} -f suika suikadb suika-znc-import ${DESTDIR}${PREFIX}/${BINDIR}
+ ${CP} -f doc/suika.1 ${DESTDIR}${PREFIX}/${MANDIR}/man1
+ ${CP} -f doc/suikadb.1 ${DESTDIR}${PREFIX}/${MANDIR}/man1
+ ${CP} -f doc/suika-znc-import.1 ${DESTDIR}/${MANDIR}/man1
+ ${CP} -f doc/suika-config.5 ${DESTDIR}${PREFIX}/${MANDIR}/man5
+ [ -f ${DESTDIR}${SYSCONFDIR}/suika/config ] || ${CP} -f config.in ${DESTDIR}${SYSCONFDIR}/suika/config
+test:
+ go test
+vendor:
+ go mod vendor
+.PHONY: build clean install
--- /dev/null
+# suika
+
+[![Go Documentation](https://godocs.io/marisa.chaotic.ninja/suika?status.svg)](https://godocs.io/marisa.chaotic.ninja/suika)
+
+A user-friendly IRC bouncer. Hard-fork of the 0.3 series of [soju](https://soju.im), named after [Suika Ibuki](https://en.touhouwiki.net/wiki/Suika_Ibuki) from [Touhou 7.5: Immaterial and Missing Power](https://en.touhouwiki.net/wiki/Immaterial_and_Missing_Power)
+
+- Multi-user
+- Support multiple clients for a single user, with proper backlog
+ synchronization
+- Support connecting to multiple upstream servers via a single IRC connection
+ to the bouncer
+
+## Building and installing
+
+Dependencies:
+
+- Go
+- BSD or GNU make
+
+For end users, a `Makefile` is provided:
+
+ make
+ doas make install
+
+For development, you can use `go run ./cmd/suika` as usual.
+
+## License
+AGPLv3, see [LICENSE](LICENSE).
+
+* Copyright (C) 2020 The soju Contributors
+* Copyright (C) 2023-present Izuru Yakumo
+
+The code for `version.go` is stolen verbatim from one of [@prologic](https://git.mills.io/prologic)'s projects. It's probably under MIT
--- /dev/null
+package suika
+
+import (
+ "context"
+ "fmt"
+ "strconv"
+ "strings"
+
+ "gopkg.in/irc.v3"
+)
+
+func forwardChannel(ctx context.Context, dc *downstreamConn, ch *upstreamChannel) {
+ if !ch.complete {
+ panic("Tried to forward a partial channel")
+ }
+
+ // RPL_NOTOPIC shouldn't be sent on JOIN
+ if ch.Topic != "" {
+ sendTopic(dc, ch)
+ }
+
+ if dc.caps["soju.im/read"] {
+ channelCM := ch.conn.network.casemap(ch.Name)
+ r, err := dc.srv.db.GetReadReceipt(ctx, ch.conn.network.ID, channelCM)
+ if err != nil {
+ dc.logger.Printf("failed to get the read receipt for %q: %v", ch.Name, err)
+ } else {
+ timestampStr := "*"
+ if r != nil {
+ timestampStr = fmt.Sprintf("timestamp=%s", formatServerTime(r.Timestamp))
+ }
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.prefix(),
+ Command: "READ",
+ Params: []string{dc.marshalEntity(ch.conn.network, ch.Name), timestampStr},
+ })
+ }
+ }
+
+ sendNames(dc, ch)
+}
+
+func sendTopic(dc *downstreamConn, ch *upstreamChannel) {
+ downstreamName := dc.marshalEntity(ch.conn.network, ch.Name)
+
+ if ch.Topic != "" {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_TOPIC,
+ Params: []string{dc.nick, downstreamName, ch.Topic},
+ })
+ if ch.TopicWho != nil {
+ topicWho := dc.marshalUserPrefix(ch.conn.network, ch.TopicWho)
+ topicTime := strconv.FormatInt(ch.TopicTime.Unix(), 10)
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: rpl_topicwhotime,
+ Params: []string{dc.nick, downstreamName, topicWho.String(), topicTime},
+ })
+ }
+ } else {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_NOTOPIC,
+ Params: []string{dc.nick, downstreamName, "No topic is set"},
+ })
+ }
+}
+
+func sendNames(dc *downstreamConn, ch *upstreamChannel) {
+ downstreamName := dc.marshalEntity(ch.conn.network, ch.Name)
+
+ emptyNameReply := &irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_NAMREPLY,
+ Params: []string{dc.nick, string(ch.Status), downstreamName, ""},
+ }
+ maxLength := maxMessageLength - len(emptyNameReply.String())
+
+ var buf strings.Builder
+ for _, entry := range ch.Members.innerMap {
+ nick := entry.originalKey
+ memberships := entry.value.(*memberships)
+ s := memberships.Format(dc) + dc.marshalEntity(ch.conn.network, nick)
+
+ n := buf.Len() + 1 + len(s)
+ if buf.Len() != 0 && n > maxLength {
+ // There's not enough space for the next space + nick.
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_NAMREPLY,
+ Params: []string{dc.nick, string(ch.Status), downstreamName, buf.String()},
+ })
+ buf.Reset()
+ }
+
+ if buf.Len() != 0 {
+ buf.WriteByte(' ')
+ }
+ buf.WriteString(s)
+ }
+
+ if buf.Len() != 0 {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_NAMREPLY,
+ Params: []string{dc.nick, string(ch.Status), downstreamName, buf.String()},
+ })
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_ENDOFNAMES,
+ Params: []string{dc.nick, downstreamName, "End of /NAMES list"},
+ })
+}
--- /dev/null
+package suika
+
+import (
+ "crypto"
+ "crypto/ecdsa"
+ "crypto/ed25519"
+ "crypto/elliptic"
+ "crypto/rand"
+ "crypto/rsa"
+ "crypto/x509"
+ "crypto/x509/pkix"
+ "math/big"
+ "time"
+)
+
+func generateCertFP(keyType string, bits int) (privKeyBytes, certBytes []byte, err error) {
+ var (
+ privKey crypto.PrivateKey
+ pubKey crypto.PublicKey
+ )
+ switch keyType {
+ case "rsa":
+ key, err := rsa.GenerateKey(rand.Reader, bits)
+ if err != nil {
+ return nil, nil, err
+ }
+ privKey = key
+ pubKey = key.Public()
+ case "ecdsa":
+ key, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
+ if err != nil {
+ return nil, nil, err
+ }
+ privKey = key
+ pubKey = key.Public()
+ case "ed25519":
+ var err error
+ pubKey, privKey, err = ed25519.GenerateKey(rand.Reader)
+ if err != nil {
+ return nil, nil, err
+ }
+ }
+
+ // Using PKCS#8 allows easier extension for new key types.
+ privKeyBytes, err = x509.MarshalPKCS8PrivateKey(privKey)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ notBefore := time.Now()
+ // Lets make a fair assumption nobody will use the same cert for more than 20 years...
+ notAfter := notBefore.Add(24 * time.Hour * 365 * 20)
+ serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
+ serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
+ if err != nil {
+ return nil, nil, err
+ }
+ cert := &x509.Certificate{
+ SerialNumber: serialNumber,
+ Subject: pkix.Name{CommonName: "suika auto-generated certificate"},
+ NotBefore: notBefore,
+ NotAfter: notAfter,
+ KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
+ ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
+ }
+ certBytes, err = x509.CreateCertificate(rand.Reader, cert, cert, pubKey, privKey)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ return privKeyBytes, certBytes, nil
+}
--- /dev/null
+package main
+
+import (
+ "bufio"
+ "context"
+ "flag"
+ "fmt"
+ "io"
+ "log"
+ "net/url"
+ "os"
+ "strings"
+ "unicode"
+
+ "marisa.chaotic.ninja/suika"
+ "marisa.chaotic.ninja/suika/config"
+)
+
+const usage = `usage: suika-znc-import [options...] <znc config path>
+
+Imports configuration from a ZNC file. Users and networks are merged if they
+already exist in the suika database. ZNC settings overwrite existing suika
+settings.
+
+Options:
+
+ -help Show this help message
+ -config <path> Path to suika config file
+ -user <username> Limit import to username (may be specified multiple times)
+ -network <name> Limit import to network (may be specified multiple times)
+`
+
+func init() {
+ flag.Usage = func() {
+ fmt.Fprintf(flag.CommandLine.Output(), usage)
+ }
+}
+
+func main() {
+ var configPath string
+ users := make(map[string]bool)
+ networks := make(map[string]bool)
+ flag.StringVar(&configPath, "config", "", "path to configuration file")
+ flag.Var((*stringSetFlag)(&users), "user", "")
+ flag.Var((*stringSetFlag)(&networks), "network", "")
+ flag.Parse()
+
+ zncPath := flag.Arg(0)
+ if zncPath == "" {
+ flag.Usage()
+ os.Exit(1)
+ }
+
+ var cfg *config.Server
+ if configPath != "" {
+ var err error
+ cfg, err = config.Load(configPath)
+ if err != nil {
+ log.Fatalf("failed to load config file: %v", err)
+ }
+ } else {
+ cfg = config.Defaults()
+ }
+
+ ctx := context.Background()
+
+ db, err := suika.OpenDB(cfg.SQLDriver, cfg.SQLSource)
+ if err != nil {
+ log.Fatalf("failed to open database: %v", err)
+ }
+ defer db.Close()
+
+ f, err := os.Open(zncPath)
+ if err != nil {
+ log.Fatalf("failed to open ZNC configuration file: %v", err)
+ }
+ defer f.Close()
+
+ zp := zncParser{bufio.NewReader(f), 1}
+ root, err := zp.sectionBody("", "")
+ if err != nil {
+ log.Fatalf("failed to parse %q: line %v: %v", zncPath, zp.line, err)
+ }
+
+ l, err := db.ListUsers(ctx)
+ if err != nil {
+ log.Fatalf("failed to list users in DB: %v", err)
+ }
+ existingUsers := make(map[string]*suika.User, len(l))
+ for i, u := range l {
+ existingUsers[u.Username] = &l[i]
+ }
+
+ usersCreated := 0
+ usersImported := 0
+ networksImported := 0
+ channelsImported := 0
+ root.ForEach("User", func(section *zncSection) {
+ username := section.Name
+ if len(users) > 0 && !users[username] {
+ return
+ }
+ usersImported++
+
+ u, ok := existingUsers[username]
+ if ok {
+ log.Printf("user %q: updating existing user", username)
+ } else {
+ // "!!" is an invalid crypt format, thus disables password auth
+ u = &suika.User{Username: username, Password: "!!"}
+ usersCreated++
+ log.Printf("user %q: creating new user", username)
+ }
+
+ u.Admin = section.Values.Get("Admin") == "true"
+
+ if err := db.StoreUser(ctx, u); err != nil {
+ log.Fatalf("failed to store user %q: %v", username, err)
+ }
+ userID := u.ID
+
+ l, err := db.ListNetworks(ctx, userID)
+ if err != nil {
+ log.Fatalf("failed to list networks for user %q: %v", username, err)
+ }
+ existingNetworks := make(map[string]*suika.Network, len(l))
+ for i, n := range l {
+ existingNetworks[n.GetName()] = &l[i]
+ }
+
+ nick := section.Values.Get("Nick")
+ realname := section.Values.Get("RealName")
+ ident := section.Values.Get("Ident")
+
+ section.ForEach("Network", func(section *zncSection) {
+ netName := section.Name
+ if len(networks) > 0 && !networks[netName] {
+ return
+ }
+ networksImported++
+
+ logPrefix := fmt.Sprintf("user %q: network %q: ", username, netName)
+ logger := log.New(os.Stderr, logPrefix, log.LstdFlags|log.Lmsgprefix)
+
+ netNick := section.Values.Get("Nick")
+ if netNick == "" {
+ netNick = nick
+ }
+ netRealname := section.Values.Get("RealName")
+ if netRealname == "" {
+ netRealname = realname
+ }
+ netIdent := section.Values.Get("Ident")
+ if netIdent == "" {
+ netIdent = ident
+ }
+
+ for _, name := range section.Values["LoadModule"] {
+ switch name {
+ case "sasl":
+ logger.Printf("warning: SASL credentials not imported")
+ case "nickserv":
+ logger.Printf("warning: NickServ credentials not imported")
+ case "perform":
+ logger.Printf("warning: \"perform\" plugin commands not imported")
+ }
+ }
+
+ u, pass, err := importNetworkServer(section.Values.Get("Server"))
+ if err != nil {
+ logger.Fatalf("failed to import server %q: %v", section.Values.Get("Server"), err)
+ }
+
+ n, ok := existingNetworks[netName]
+ if ok {
+ logger.Printf("updating existing network")
+ } else {
+ n = &suika.Network{Name: netName}
+ logger.Printf("creating new network")
+ }
+
+ n.Addr = u.String()
+ n.Nick = netNick
+ n.Username = netIdent
+ n.Realname = netRealname
+ n.Pass = pass
+ n.Enabled = section.Values.Get("IRCConnectEnabled") != "false"
+
+ if err := db.StoreNetwork(ctx, userID, n); err != nil {
+ logger.Fatalf("failed to store network: %v", err)
+ }
+
+ l, err := db.ListChannels(ctx, n.ID)
+ if err != nil {
+ logger.Fatalf("failed to list channels: %v", err)
+ }
+ existingChannels := make(map[string]*suika.Channel, len(l))
+ for i, ch := range l {
+ existingChannels[ch.Name] = &l[i]
+ }
+
+ section.ForEach("Chan", func(section *zncSection) {
+ chName := section.Name
+
+ if section.Values.Get("Disabled") == "true" {
+ logger.Printf("skipping import of disabled channel %q", chName)
+ return
+ }
+
+ channelsImported++
+
+ ch, ok := existingChannels[chName]
+ if ok {
+ logger.Printf("channel %q: updating existing channel", chName)
+ } else {
+ ch = &suika.Channel{Name: chName}
+ logger.Printf("channel %q: creating new channel", chName)
+ }
+
+ ch.Key = section.Values.Get("Key")
+ ch.Detached = section.Values.Get("Detached") == "true"
+
+ if err := db.StoreChannel(ctx, n.ID, ch); err != nil {
+ logger.Printf("channel %q: failed to store channel: %v", chName, err)
+ }
+ })
+ })
+ })
+
+ if err := db.Close(); err != nil {
+ log.Printf("failed to close database: %v", err)
+ }
+
+ if usersCreated > 0 {
+ log.Printf("warning: user passwords haven't been imported, please set them with `suikactl change-password <username>`")
+ }
+
+ log.Printf("imported %v users, %v networks and %v channels", usersImported, networksImported, channelsImported)
+}
+
+func importNetworkServer(s string) (u *url.URL, pass string, err error) {
+ parts := strings.Fields(s)
+ if len(parts) < 2 {
+ return nil, "", fmt.Errorf("expected space-separated host and port")
+ }
+
+ scheme := "irc"
+ host := parts[0]
+ port := parts[1]
+ if strings.HasPrefix(port, "+") {
+ port = port[1:]
+ scheme = "ircs"
+ }
+
+ if len(parts) > 2 {
+ pass = parts[2]
+ }
+
+ u = &url.URL{
+ Scheme: scheme,
+ Host: host + ":" + port,
+ }
+ return u, pass, nil
+}
+
+type zncSection struct {
+ Type string
+ Name string
+ Values zncValues
+ Children []zncSection
+}
+
+func (s *zncSection) ForEach(typ string, f func(*zncSection)) {
+ for _, section := range s.Children {
+ if section.Type == typ {
+ f(§ion)
+ }
+ }
+}
+
+type zncValues map[string][]string
+
+func (zv zncValues) Get(k string) string {
+ if len(zv[k]) == 0 {
+ return ""
+ }
+ return zv[k][0]
+}
+
+type zncParser struct {
+ br *bufio.Reader
+ line int
+}
+
+func (zp *zncParser) readByte() (byte, error) {
+ b, err := zp.br.ReadByte()
+ if b == '\n' {
+ zp.line++
+ }
+ return b, err
+}
+
+func (zp *zncParser) readRune() (rune, int, error) {
+ r, n, err := zp.br.ReadRune()
+ if r == '\n' {
+ zp.line++
+ }
+ return r, n, err
+}
+
+func (zp *zncParser) sectionBody(typ, name string) (*zncSection, error) {
+ section := &zncSection{Type: typ, Name: name, Values: make(zncValues)}
+
+Loop:
+ for {
+ if err := zp.skipSpace(); err != nil {
+ return nil, err
+ }
+
+ b, err := zp.br.Peek(2)
+ if err == io.EOF {
+ break
+ } else if err != nil {
+ return nil, err
+ }
+
+ switch b[0] {
+ case '<':
+ if b[1] == '/' {
+ break Loop
+ } else {
+ childType, childName, err := zp.sectionHeader()
+ if err != nil {
+ return nil, err
+ }
+ child, err := zp.sectionBody(childType, childName)
+ if err != nil {
+ return nil, err
+ }
+ if footerType, err := zp.sectionFooter(); err != nil {
+ return nil, err
+ } else if footerType != childType {
+ return nil, fmt.Errorf("invalid section footer: expected type %q, got %q", childType, footerType)
+ }
+ section.Children = append(section.Children, *child)
+ }
+ case '/':
+ if b[1] == '/' {
+ if err := zp.skipComment(); err != nil {
+ return nil, err
+ }
+ break
+ }
+ fallthrough
+ default:
+ k, v, err := zp.keyValuePair()
+ if err != nil {
+ return nil, err
+ }
+ section.Values[k] = append(section.Values[k], v)
+ }
+ }
+
+ return section, nil
+}
+
+func (zp *zncParser) skipSpace() error {
+ for {
+ r, _, err := zp.readRune()
+ if err == io.EOF {
+ return nil
+ } else if err != nil {
+ return err
+ }
+
+ if !unicode.IsSpace(r) {
+ zp.br.UnreadRune()
+ return nil
+ }
+ }
+}
+
+func (zp *zncParser) skipComment() error {
+ if err := zp.expectRune('/'); err != nil {
+ return err
+ }
+ if err := zp.expectRune('/'); err != nil {
+ return err
+ }
+
+ for {
+ b, err := zp.readByte()
+ if err == io.EOF {
+ return nil
+ } else if err != nil {
+ return err
+ }
+
+ if b == '\n' {
+ return nil
+ }
+ }
+}
+
+func (zp *zncParser) sectionHeader() (string, string, error) {
+ if err := zp.expectRune('<'); err != nil {
+ return "", "", err
+ }
+ typ, err := zp.readWord(' ')
+ if err != nil {
+ return "", "", err
+ }
+ name, err := zp.readWord('>')
+ return typ, name, err
+}
+
+func (zp *zncParser) sectionFooter() (string, error) {
+ if err := zp.expectRune('<'); err != nil {
+ return "", err
+ }
+ if err := zp.expectRune('/'); err != nil {
+ return "", err
+ }
+ return zp.readWord('>')
+}
+
+func (zp *zncParser) keyValuePair() (string, string, error) {
+ k, err := zp.readWord('=')
+ if err != nil {
+ return "", "", err
+ }
+ v, err := zp.readWord('\n')
+ return strings.TrimSpace(k), strings.TrimSpace(v), err
+}
+
+func (zp *zncParser) expectRune(expected rune) error {
+ r, _, err := zp.readRune()
+ if err != nil {
+ return err
+ } else if r != expected {
+ return fmt.Errorf("expected %q, got %q", expected, r)
+ }
+ return nil
+}
+
+func (zp *zncParser) readWord(delim byte) (string, error) {
+ var sb strings.Builder
+ for {
+ b, err := zp.readByte()
+ if err != nil {
+ return "", err
+ }
+
+ if b == delim {
+ return sb.String(), nil
+ }
+ if b == '\n' {
+ return "", fmt.Errorf("expected %q before newline", delim)
+ }
+
+ sb.WriteByte(b)
+ }
+}
+
+type stringSetFlag map[string]bool
+
+func (v *stringSetFlag) String() string {
+ return fmt.Sprint(map[string]bool(*v))
+}
+
+func (v *stringSetFlag) Set(s string) error {
+ (*v)[s] = true
+ return nil
+}
--- /dev/null
+package main
+
+import (
+ "context"
+ "crypto/tls"
+ "flag"
+ "fmt"
+ "log"
+ "net"
+ "net/url"
+ "os"
+ "os/signal"
+ "strings"
+ "sync/atomic"
+ "syscall"
+ "time"
+
+ "marisa.chaotic.ninja/suika"
+ "marisa.chaotic.ninja/suika/config"
+)
+
+// TCP keep-alive interval for downstream TCP connections
+const downstreamKeepAlive = 1 * time.Hour
+
+type stringSliceFlag []string
+
+func (v *stringSliceFlag) String() string {
+ return fmt.Sprint([]string(*v))
+}
+
+func (v *stringSliceFlag) Set(s string) error {
+ *v = append(*v, s)
+ return nil
+}
+
+func bumpOpenedFileLimit() error {
+ var rlimit syscall.Rlimit
+ if err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rlimit); err != nil {
+ return fmt.Errorf("failed to get RLIMIT_NOFILE: %v", err)
+ }
+ rlimit.Cur = rlimit.Max
+ if err := syscall.Setrlimit(syscall.RLIMIT_NOFILE, &rlimit); err != nil {
+ return fmt.Errorf("failed to set RLIMIT_NOFILE: %v", err)
+ }
+ return nil
+}
+
+var (
+ configPath string
+ debug bool
+
+ tlsCert atomic.Value // *tls.Certificate
+)
+
+func loadConfig() (*config.Server, *suika.Config, error) {
+ var raw *config.Server
+ if configPath != "" {
+ var err error
+ raw, err = config.Load(configPath)
+ if err != nil {
+ return nil, nil, fmt.Errorf("failed to load config file: %v", err)
+ }
+ } else {
+ raw = config.Defaults()
+ }
+
+ var motd string
+ if raw.MOTDPath != "" {
+ b, err := os.ReadFile(raw.MOTDPath)
+ if err != nil {
+ return nil, nil, fmt.Errorf("failed to load MOTD: %v", err)
+ }
+ motd = strings.TrimSuffix(string(b), "\n")
+ }
+
+ if raw.TLS != nil {
+ cert, err := tls.LoadX509KeyPair(raw.TLS.CertPath, raw.TLS.KeyPath)
+ if err != nil {
+ return nil, nil, fmt.Errorf("failed to load TLS certificate and key: %v", err)
+ }
+ tlsCert.Store(&cert)
+ }
+
+ cfg := &suika.Config{
+ Hostname: raw.Hostname,
+ Title: raw.Title,
+ LogPath: raw.LogPath,
+ MaxUserNetworks: raw.MaxUserNetworks,
+ MultiUpstream: raw.MultiUpstream,
+ UpstreamUserIPs: raw.UpstreamUserIPs,
+ MOTD: motd,
+ }
+ return raw, cfg, nil
+}
+
+func main() {
+ var listen []string
+ flag.Var((*stringSliceFlag)(&listen), "listen", "listening address")
+ flag.StringVar(&configPath, "config", "", "path to configuration file")
+ flag.BoolVar(&debug, "debug", false, "enable debug logging")
+ flag.Parse()
+
+ cfg, serverCfg, err := loadConfig()
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ cfg.Listen = append(cfg.Listen, listen...)
+ if len(cfg.Listen) == 0 {
+ cfg.Listen = []string{":6667"}
+ }
+
+ if err := bumpOpenedFileLimit(); err != nil {
+ log.Printf("failed to bump max number of opened files: %v", err)
+ }
+
+ db, err := suika.OpenDB(cfg.SQLDriver, cfg.SQLSource)
+ if err != nil {
+ log.Fatalf("failed to open database: %v", err)
+ }
+
+ var tlsCfg *tls.Config
+ if cfg.TLS != nil {
+ tlsCfg = &tls.Config{
+ GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
+ return tlsCert.Load().(*tls.Certificate), nil
+ },
+ }
+ }
+
+ srv := suika.NewServer(db)
+ srv.SetConfig(serverCfg)
+ srv.Logger = suika.NewLogger(log.Writer(), debug)
+
+ for _, listen := range cfg.Listen {
+ listen := listen // copy
+ listenURI := listen
+ if !strings.Contains(listenURI, ":/") {
+ // This is a raw domain name, make it an URL with an empty scheme
+ listenURI = "//" + listenURI
+ }
+ u, err := url.Parse(listenURI)
+ if err != nil {
+ log.Fatalf("failed to parse listen URI %q: %v", listen, err)
+ }
+
+ switch u.Scheme {
+ case "ircs", "":
+ if tlsCfg == nil {
+ log.Fatalf("failed to listen on %q: missing TLS configuration", listen)
+ }
+ host := u.Host
+ if _, _, err := net.SplitHostPort(host); err != nil {
+ host = host + ":6697"
+ }
+ ircsTLSCfg := tlsCfg.Clone()
+ ircsTLSCfg.NextProtos = []string{"irc"}
+ lc := net.ListenConfig{
+ KeepAlive: downstreamKeepAlive,
+ }
+ l, err := lc.Listen(context.Background(), "tcp", host)
+ if err != nil {
+ log.Fatalf("failed to start TLS listener on %q: %v", listen, err)
+ }
+ ln := tls.NewListener(l, ircsTLSCfg)
+ go func() {
+ if err := srv.Serve(ln); err != nil {
+ log.Printf("serving %q: %v", listen, err)
+ }
+ }()
+ case "irc":
+ host := u.Host
+ if _, _, err := net.SplitHostPort(host); err != nil {
+ host = host + ":6667"
+ }
+ lc := net.ListenConfig{
+ KeepAlive: downstreamKeepAlive,
+ }
+ ln, err := lc.Listen(context.Background(), "tcp", host)
+ if err != nil {
+ log.Fatalf("failed to start listener on %q: %v", listen, err)
+ }
+ go func() {
+ if err := srv.Serve(ln); err != nil {
+ log.Printf("serving %q: %v", listen, err)
+ }
+ }()
+ case "unix":
+ ln, err := net.Listen("unix", u.Path)
+ if err != nil {
+ log.Fatalf("failed to start listener on %q: %v", listen, err)
+ }
+ go func() {
+ if err := srv.Serve(ln); err != nil {
+ log.Printf("serving %q: %v", listen, err)
+ }
+ }()
+ default:
+ log.Fatalf("failed to listen on %q: unsupported scheme", listen)
+ }
+
+ log.Printf("starting suika version %v\n", suika.FullVersion())
+ log.Printf("server listening on %q", listen)
+ }
+
+ sigCh := make(chan os.Signal, 1)
+ signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)
+
+ if err := srv.Start(); err != nil {
+ log.Fatal(err)
+ }
+
+ for sig := range sigCh {
+ switch sig {
+ case syscall.SIGHUP:
+ log.Print("reloading configuration")
+ _, serverCfg, err := loadConfig()
+ if err != nil {
+ log.Printf("failed to reloading configuration: %v", err)
+ } else {
+ srv.SetConfig(serverCfg)
+ }
+ case syscall.SIGINT, syscall.SIGTERM:
+ log.Print("shutting down server")
+ srv.Shutdown()
+ return
+ }
+ }
+}
--- /dev/null
+package main
+
+import (
+ "bufio"
+ "context"
+ "flag"
+ "fmt"
+ "io"
+ "log"
+ "os"
+
+ "marisa.chaotic.ninja/suika"
+ "marisa.chaotic.ninja/suika/config"
+ "golang.org/x/crypto/bcrypt"
+ "golang.org/x/term"
+)
+
+const usage = `usage: suikadb [-config path] <action> [options...]
+
+ create-user <username> [-admin] Create a new user
+ change-password <username> Change password for a user
+ help Show this help message
+`
+
+func init() {
+ flag.Usage = func() {
+ fmt.Fprintf(flag.CommandLine.Output(), usage)
+ }
+}
+
+func main() {
+ var configPath string
+ flag.StringVar(&configPath, "config", "", "path to configuration file")
+ flag.Parse()
+
+ var cfg *config.Server
+ if configPath != "" {
+ var err error
+ cfg, err = config.Load(configPath)
+ if err != nil {
+ log.Fatalf("failed to load config file: %v", err)
+ }
+ } else {
+ cfg = config.Defaults()
+ }
+
+ db, err := suika.OpenDB(cfg.SQLDriver, cfg.SQLSource)
+ if err != nil {
+ log.Fatalf("failed to open database: %v", err)
+ }
+
+ ctx := context.Background()
+
+ switch cmd := flag.Arg(0); cmd {
+ case "create-user":
+ username := flag.Arg(1)
+ if username == "" {
+ flag.Usage()
+ os.Exit(1)
+ }
+
+ fs := flag.NewFlagSet("", flag.ExitOnError)
+ admin := fs.Bool("admin", false, "make the new user admin")
+ fs.Parse(flag.Args()[2:])
+
+ password, err := readPassword()
+ if err != nil {
+ log.Fatalf("failed to read password: %v", err)
+ }
+
+ hashed, err := bcrypt.GenerateFromPassword(password, bcrypt.DefaultCost)
+ if err != nil {
+ log.Fatalf("failed to hash password: %v", err)
+ }
+
+ user := suika.User{
+ Username: username,
+ Password: string(hashed),
+ Admin: *admin,
+ }
+ if err := db.StoreUser(ctx, &user); err != nil {
+ log.Fatalf("failed to create user: %v", err)
+ }
+ case "change-password":
+ username := flag.Arg(1)
+ if username == "" {
+ flag.Usage()
+ os.Exit(1)
+ }
+
+ user, err := db.GetUser(ctx, username)
+ if err != nil {
+ log.Fatalf("failed to get user: %v", err)
+ }
+
+ password, err := readPassword()
+ if err != nil {
+ log.Fatalf("failed to read password: %v", err)
+ }
+
+ hashed, err := bcrypt.GenerateFromPassword(password, bcrypt.DefaultCost)
+ if err != nil {
+ log.Fatalf("failed to hash password: %v", err)
+ }
+
+ user.Password = string(hashed)
+ if err := db.StoreUser(ctx, user); err != nil {
+ log.Fatalf("failed to update password: %v", err)
+ }
+ case "version":
+ fmt.Printf("%v\n", suika.FullVersion())
+ default:
+ flag.Usage()
+ if cmd != "help" {
+ os.Exit(1)
+ }
+ }
+}
+
+func readPassword() ([]byte, error) {
+ var password []byte
+ var err error
+ fd := int(os.Stdin.Fd())
+
+ if term.IsTerminal(fd) {
+ fmt.Printf("Password: ")
+ password, err = term.ReadPassword(int(os.Stdin.Fd()))
+ if err != nil {
+ return nil, err
+ }
+ fmt.Printf("\n")
+ } else {
+ fmt.Fprintf(os.Stderr, "Warning: Reading password from stdin.\n")
+ // TODO: the buffering messes up repeated calls to readPassword
+ scanner := bufio.NewScanner(os.Stdin)
+ if !scanner.Scan() {
+ if err := scanner.Err(); err != nil {
+ return nil, err
+ }
+ return nil, io.ErrUnexpectedEOF
+ }
+ password = scanner.Bytes()
+
+ if len(password) == 0 {
+ return nil, fmt.Errorf("zero length password")
+ }
+ }
+
+ return password, nil
+}
--- /dev/null
+db sqlite3 /var/lib/suika/main.db
+log fs /var/lib/suika/logs/
--- /dev/null
+package config
+
+import (
+ "fmt"
+ "net"
+ "os"
+ "strconv"
+
+ "git.sr.ht/~emersion/go-scfg"
+)
+
+type TLS struct {
+ CertPath, KeyPath string
+}
+
+type Server struct {
+ Listen []string
+ TLS *TLS
+ Hostname string
+ Title string
+ MOTDPath string
+
+ SQLDriver string
+ SQLSource string
+ LogPath string
+
+ MaxUserNetworks int
+ MultiUpstream bool
+ UpstreamUserIPs []*net.IPNet
+}
+
+func Defaults() *Server {
+ hostname, err := os.Hostname()
+ if err != nil {
+ hostname = "localhost"
+ }
+ return &Server{
+ Hostname: hostname,
+ SQLDriver: "sqlite3",
+ SQLSource: "suika.db",
+ MaxUserNetworks: -1,
+ MultiUpstream: true,
+ }
+}
+
+func Load(path string) (*Server, error) {
+ cfg, err := scfg.Load(path)
+ if err != nil {
+ return nil, err
+ }
+ return parse(cfg)
+}
+
+func parse(cfg scfg.Block) (*Server, error) {
+ srv := Defaults()
+ for _, d := range cfg {
+ switch d.Name {
+ case "listen":
+ var uri string
+ if err := d.ParseParams(&uri); err != nil {
+ return nil, err
+ }
+ srv.Listen = append(srv.Listen, uri)
+ case "hostname":
+ if err := d.ParseParams(&srv.Hostname); err != nil {
+ return nil, err
+ }
+ case "title":
+ if err := d.ParseParams(&srv.Title); err != nil {
+ return nil, err
+ }
+ case "motd":
+ if err := d.ParseParams(&srv.MOTDPath); err != nil {
+ return nil, err
+ }
+ case "tls":
+ tls := &TLS{}
+ if err := d.ParseParams(&tls.CertPath, &tls.KeyPath); err != nil {
+ return nil, err
+ }
+ srv.TLS = tls
+ case "db":
+ if err := d.ParseParams(&srv.SQLDriver, &srv.SQLSource); err != nil {
+ return nil, err
+ }
+ case "log":
+ var driver string
+ if err := d.ParseParams(&driver, &srv.LogPath); err != nil {
+ return nil, err
+ }
+ if driver != "fs" {
+ return nil, fmt.Errorf("directive %q: unknown driver %q", d.Name, driver)
+ }
+ case "max-user-networks":
+ var max string
+ if err := d.ParseParams(&max); err != nil {
+ return nil, err
+ }
+ var err error
+ if srv.MaxUserNetworks, err = strconv.Atoi(max); err != nil {
+ return nil, fmt.Errorf("directive %q: %v", d.Name, err)
+ }
+ case "multi-upstream-mode":
+ var str string
+ if err := d.ParseParams(&str); err != nil {
+ return nil, err
+ }
+ v, err := strconv.ParseBool(str)
+ if err != nil {
+ return nil, fmt.Errorf("directive %q: %v", d.Name, err)
+ }
+ srv.MultiUpstream = v
+ case "upstream-user-ip":
+ if len(srv.UpstreamUserIPs) > 0 {
+ return nil, fmt.Errorf("directive %q: can only be specified once", d.Name)
+ }
+ var hasIPv4, hasIPv6 bool
+ for _, s := range d.Params {
+ _, n, err := net.ParseCIDR(s)
+ if err != nil {
+ return nil, fmt.Errorf("directive %q: failed to parse CIDR: %v", d.Name, err)
+ }
+ if n.IP.To4() == nil {
+ if hasIPv6 {
+ return nil, fmt.Errorf("directive %q: found two IPv6 CIDRs", d.Name)
+ }
+ hasIPv6 = true
+ } else {
+ if hasIPv4 {
+ return nil, fmt.Errorf("directive %q: found two IPv4 CIDRs", d.Name)
+ }
+ hasIPv4 = true
+ }
+ srv.UpstreamUserIPs = append(srv.UpstreamUserIPs, n)
+ }
+ default:
+ return nil, fmt.Errorf("unknown directive %q", d.Name)
+ }
+ }
+
+ return srv, nil
+}
--- /dev/null
+package suika
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "net"
+ "sync"
+ "time"
+
+ "golang.org/x/time/rate"
+ "gopkg.in/irc.v3"
+)
+
+// ircConn is a generic IRC connection. It's similar to net.Conn but focuses on
+// reading and writing IRC messages.
+type ircConn interface {
+ ReadMessage() (*irc.Message, error)
+ WriteMessage(*irc.Message) error
+ Close() error
+ SetReadDeadline(time.Time) error
+ SetWriteDeadline(time.Time) error
+ RemoteAddr() net.Addr
+ LocalAddr() net.Addr
+}
+
+func newNetIRCConn(c net.Conn) ircConn {
+ type netConn net.Conn
+ return struct {
+ *irc.Conn
+ netConn
+ }{irc.NewConn(c), c}
+}
+
+type connOptions struct {
+ Logger Logger
+ RateLimitDelay time.Duration
+ RateLimitBurst int
+}
+
+type conn struct {
+ conn ircConn
+ srv *Server
+ logger Logger
+
+ lock sync.Mutex
+ outgoing chan<- *irc.Message
+ closed bool
+ closedCh chan struct{}
+}
+
+func newConn(srv *Server, ic ircConn, options *connOptions) *conn {
+ outgoing := make(chan *irc.Message, 64)
+ c := &conn{
+ conn: ic,
+ srv: srv,
+ outgoing: outgoing,
+ logger: options.Logger,
+ closedCh: make(chan struct{}),
+ }
+
+ go func() {
+ ctx, cancel := c.NewContext(context.Background())
+ defer cancel()
+
+ rl := rate.NewLimiter(rate.Every(options.RateLimitDelay), options.RateLimitBurst)
+ for msg := range outgoing {
+ if err := rl.Wait(ctx); err != nil {
+ break
+ }
+
+ c.logger.Debugf("sent: %v", msg)
+ c.conn.SetWriteDeadline(time.Now().Add(writeTimeout))
+ if err := c.conn.WriteMessage(msg); err != nil {
+ c.logger.Printf("failed to write message: %v", err)
+ break
+ }
+ }
+ if err := c.conn.Close(); err != nil && !isErrClosed(err) {
+ c.logger.Printf("failed to close connection: %v", err)
+ } else {
+ c.logger.Debugf("connection closed")
+ }
+ // Drain the outgoing channel to prevent SendMessage from blocking
+ for range outgoing {
+ // This space is intentionally left blank
+ }
+ }()
+
+ c.logger.Debugf("new connection")
+ return c
+}
+
+func (c *conn) isClosed() bool {
+ c.lock.Lock()
+ defer c.lock.Unlock()
+ return c.closed
+}
+
+// Close closes the connection. It is safe to call from any goroutine.
+func (c *conn) Close() error {
+ c.lock.Lock()
+ defer c.lock.Unlock()
+
+ if c.closed {
+ return fmt.Errorf("connection already closed")
+ }
+
+ err := c.conn.Close()
+ c.closed = true
+ close(c.outgoing)
+ close(c.closedCh)
+ return err
+}
+
+func (c *conn) ReadMessage() (*irc.Message, error) {
+ msg, err := c.conn.ReadMessage()
+ if isErrClosed(err) {
+ return nil, io.EOF
+ } else if err != nil {
+ return nil, err
+ }
+
+ c.logger.Debugf("received: %v", msg)
+ return msg, nil
+}
+
+// SendMessage queues a new outgoing message. It is safe to call from any
+// goroutine.
+//
+// If the connection is closed before the message is sent, SendMessage silently
+// drops the message.
+func (c *conn) SendMessage(ctx context.Context, msg *irc.Message) {
+ c.lock.Lock()
+ defer c.lock.Unlock()
+
+ if c.closed {
+ return
+ }
+
+ select {
+ case c.outgoing <- msg:
+ // Success
+ case <-ctx.Done():
+ c.logger.Printf("failed to send message: %v", ctx.Err())
+ }
+}
+
+func (c *conn) RemoteAddr() net.Addr {
+ return c.conn.RemoteAddr()
+}
+
+func (c *conn) LocalAddr() net.Addr {
+ return c.conn.LocalAddr()
+}
+
+// NewContext returns a copy of the parent context with a new Done channel. The
+// returned context's Done channel is closed when the connection is closed,
+// when the returned cancel function is called, or when the parent context's
+// Done channel is closed, whichever happens first.
+//
+// Canceling this context releases resources associated with it, so code should
+// call cancel as soon as the operations running in this Context complete.
+func (c *conn) NewContext(parent context.Context) (context.Context, context.CancelFunc) {
+ ctx, cancel := context.WithCancel(parent)
+
+ go func() {
+ defer cancel()
+
+ select {
+ case <-ctx.Done():
+ // The parent context has been cancelled, or the caller has called
+ // cancel()
+ case <-c.closedCh:
+ // The connection has been closed
+ }
+ }()
+
+ return ctx, cancel
+}
--- /dev/null
+#!/bin/sh -eu
+
+# Converts a log dir to its case-mapped form.
+#
+# suika needs to be stopped for this script to work properly. The script may
+# re-order messages that happened within the same second interval if merging
+# two daily log files is necessary.
+#
+# usage: casemap-logs.sh <directory>
+
+root="$1"
+
+for net_dir in "$root"/*/*; do
+ for chan in $(ls "$net_dir"); do
+ cm_chan="$(echo $chan | tr '[:upper:]' '[:lower:]')"
+ if [ "$chan" = "$cm_chan" ]; then
+ continue
+ fi
+
+ if ! [ -d "$net_dir/$cm_chan" ]; then
+ echo >&2 "Moving case-mapped channel dir: '$net_dir/$chan' -> '$cm_chan'"
+ mv "$net_dir/$chan" "$net_dir/$cm_chan"
+ continue
+ fi
+
+ echo "Merging case-mapped channel dir: '$net_dir/$chan' -> '$cm_chan'"
+ for day in $(ls "$net_dir/$chan"); do
+ if ! [ -e "$net_dir/$cm_chan/$day" ]; then
+ echo >&2 " Moving log file: '$day'"
+ mv "$net_dir/$chan/$day" "$net_dir/$cm_chan/$day"
+ continue
+ fi
+
+ echo >&2 " Merging log file: '$day'"
+ sort "$net_dir/$chan/$day" "$net_dir/$cm_chan/$day" >"$net_dir/$cm_chan/$day.new"
+ mv "$net_dir/$cm_chan/$day.new" "$net_dir/$cm_chan/$day"
+ rm "$net_dir/$chan/$day"
+ done
+
+ rmdir "$net_dir/$chan"
+ done
+done
--- /dev/null
+# Clients
+
+This page describes how to configure IRC clients to better integrate with soju.
+
+Also see the [IRCv3 support tables] for a more general list of clients.
+
+# catgirl
+
+catgirl doesn't properly implement cap-3.2, so many capabilities will be
+disabled. catgirl developers have publicly stated that supporting bouncers such
+as soju is a non-goal.
+
+# [Emacs]
+
+There are two clients provided with Emacs. They require some setup to work
+properly.
+
+## Erc
+
+You need to explicitly set the username, which is the defcustom
+`erc-email-userid`.
+
+```elisp
+(setq erc-email-userid "<username>/irc.libera.chat") ;; Example with Libera.Chat
+(defun run-erc ()
+ (interactive)
+ (erc-tls :server "<server>"
+ :port 6697
+ :nick "<nick>"
+ :password "<password>"))
+```
+
+Then run `M-x run-erc`.
+
+## Rcirc
+
+The only thing needed here is the general config:
+
+```elisp
+(setq rcirc-server-alist
+ '(("<server>"
+ :port 6697
+ :encryption tls
+ :nick "<nick>"
+ :user-name "<username>/irc.libera.chat" ;; Example with Libera.Chat
+ :password "<password>")))
+```
+
+Then run `M-x irc`.
+
+# [gamja]
+
+gamja has been designed together with soju, so should have excellent
+integration. gamja supports many IRCv3 features including chat history.
+gamja also provides UI to manage soju networks via the
+`soju.im/bouncer-networks` extension.
+
+# [goguma]
+
+Much like gamja, goguma has been designed together with soju, so should have
+excellent integration. goguma supports many IRCv3 features including chat
+history. goguma should seamlessly connect to all networks configured in soju via
+the `soju.im/bouncer-networks` extension.
+
+# [Hexchat]
+
+Hexchat has support for a small set of IRCv3 capabilities. To prevent
+automatically reconnecting to channels parted from soju, and prevent buffering
+outgoing messages:
+
+ /set irc_reconnect_rejoin off
+ /set net_throttle off
+
+# [senpai]
+
+senpai is being developed with soju in mind, so should have excellent
+integration. senpai supports many IRCv3 features including chat history.
+
+# [Weechat]
+
+A [Weechat script] is available to provide better integration with soju.
+The script will automatically connect to all of your networks once a
+single connection to soju is set up in Weechat.
+
+On WeeChat 3.2-, no IRCv3 capabilities are enabled by default. To enable them:
+
+ /set irc.server_default.capabilities account-notify,away-notify,cap-notify,chghost,extended-join,invite-notify,multi-prefix,server-time,userhost-in-names
+ /save
+ /reconnect -all
+
+See `/help cap` for more information.
+
+[IRCv3 support tables]: https://ircv3.net/software/clients
+[gamja]: https://sr.ht/~emersion/gamja/
+[goguma]: https://sr.ht/~emersion/goguma/
+[senpai]: https://sr.ht/~taiite/senpai/
+[Weechat]: https://weechat.org/
+[Weechat script]: https://github.com/weechat/scripts/blob/master/python/soju.py
+[Hexchat]: https://hexchat.github.io/
+[Emacs]: https://www.gnu.org/software/emacs/
--- /dev/null
+package suika
+
+import (
+ "context"
+ "fmt"
+ "net/url"
+ "strings"
+ "time"
+)
+
+type Database interface {
+ Close() error
+ Stats(ctx context.Context) (*DatabaseStats, error)
+
+ ListUsers(ctx context.Context) ([]User, error)
+ GetUser(ctx context.Context, username string) (*User, error)
+ StoreUser(ctx context.Context, user *User) error
+ DeleteUser(ctx context.Context, id int64) error
+
+ ListNetworks(ctx context.Context, userID int64) ([]Network, error)
+ StoreNetwork(ctx context.Context, userID int64, network *Network) error
+ DeleteNetwork(ctx context.Context, id int64) error
+ ListChannels(ctx context.Context, networkID int64) ([]Channel, error)
+ StoreChannel(ctx context.Context, networKID int64, ch *Channel) error
+ DeleteChannel(ctx context.Context, id int64) error
+
+ ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error)
+ StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error
+
+ GetReadReceipt(ctx context.Context, networkID int64, name string) (*ReadReceipt, error)
+ StoreReadReceipt(ctx context.Context, networkID int64, receipt *ReadReceipt) error
+}
+
+func OpenDB(driver, source string) (Database, error) {
+ switch driver {
+ case "sqlite3":
+ return OpenSqliteDB(source)
+ case "postgres":
+ return OpenPostgresDB(source)
+ default:
+ return nil, fmt.Errorf("unsupported database driver: %q", driver)
+ }
+}
+
+type DatabaseStats struct {
+ Users int64
+ Networks int64
+ Channels int64
+}
+
+type User struct {
+ ID int64
+ Username string
+ Password string // hashed
+ Realname string
+ Admin bool
+}
+
+type SASL struct {
+ Mechanism string
+
+ Plain struct {
+ Username string
+ Password string
+ }
+
+ // TLS client certificate authentication.
+ External struct {
+ // X.509 certificate in DER form.
+ CertBlob []byte
+ // PKCS#8 private key in DER form.
+ PrivKeyBlob []byte
+ }
+}
+
+type Network struct {
+ ID int64
+ Name string
+ Addr string
+ Nick string
+ Username string
+ Realname string
+ Pass string
+ ConnectCommands []string
+ SASL SASL
+ Enabled bool
+}
+
+func (net *Network) GetName() string {
+ if net.Name != "" {
+ return net.Name
+ }
+ return net.Addr
+}
+
+func (net *Network) URL() (*url.URL, error) {
+ s := net.Addr
+ if !strings.Contains(s, "://") {
+ // This is a raw domain name, make it an URL with the default scheme
+ s = "ircs://" + s
+ }
+
+ u, err := url.Parse(s)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse upstream server URL: %v", err)
+ }
+
+ return u, nil
+}
+
+func GetNick(user *User, net *Network) string {
+ if net.Nick != "" {
+ return net.Nick
+ }
+ return user.Username
+}
+
+func GetUsername(user *User, net *Network) string {
+ if net.Username != "" {
+ return net.Username
+ }
+ return GetNick(user, net)
+}
+
+func GetRealname(user *User, net *Network) string {
+ if net.Realname != "" {
+ return net.Realname
+ }
+ if user.Realname != "" {
+ return user.Realname
+ }
+ return GetNick(user, net)
+}
+
+type MessageFilter int
+
+const (
+ // TODO: use customizable user defaults for FilterDefault
+ FilterDefault MessageFilter = iota
+ FilterNone
+ FilterHighlight
+ FilterMessage
+)
+
+func parseFilter(filter string) (MessageFilter, error) {
+ switch filter {
+ case "default":
+ return FilterDefault, nil
+ case "none":
+ return FilterNone, nil
+ case "highlight":
+ return FilterHighlight, nil
+ case "message":
+ return FilterMessage, nil
+ }
+ return 0, fmt.Errorf("unknown filter: %q", filter)
+}
+
+type Channel struct {
+ ID int64
+ Name string
+ Key string
+
+ Detached bool
+ DetachedInternalMsgID string
+
+ RelayDetached MessageFilter
+ ReattachOn MessageFilter
+ DetachAfter time.Duration
+ DetachOn MessageFilter
+}
+
+type DeliveryReceipt struct {
+ ID int64
+ Target string // channel or nick
+ Client string
+ InternalMsgID string
+}
+
+type ReadReceipt struct {
+ ID int64
+ Target string // channel or nick
+ Timestamp time.Time
+}
--- /dev/null
+package suika
+
+import (
+ "context"
+ "database/sql"
+ _ "embed"
+ "errors"
+ "fmt"
+ "math"
+ "strings"
+ "time"
+
+ _ "github.com/lib/pq"
+)
+
+const postgresQueryTimeout = 5 * time.Second
+
+const postgresConfigSchema = `
+CREATE TABLE IF NOT EXISTS "Config" (
+ id SMALLINT PRIMARY KEY,
+ version INTEGER NOT NULL,
+ CHECK(id = 1)
+);
+`
+//go:embed suika_psql_schema.sql
+var postgresSchema string
+
+var postgresMigrations = []string{
+ "", // migration #0 is reserved for schema initialization
+ `ALTER TABLE "Network" ALTER COLUMN nick DROP NOT NULL`,
+ `
+ CREATE TYPE sasl_mechanism AS ENUM ('PLAIN', 'EXTERNAL');
+ ALTER TABLE "Network"
+ ALTER COLUMN sasl_mechanism
+ TYPE sasl_mechanism
+ USING sasl_mechanism::sasl_mechanism;
+ `,
+ `
+ CREATE TABLE IF NOT EXISTS "ReadReceipt" (
+ id SERIAL PRIMARY KEY,
+ network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
+ target VARCHAR(255) NOT NULL,
+ timestamp TIMESTAMP WITH TIME ZONE NOT NULL,
+ UNIQUE(network, target)
+ );
+ `,
+}
+
+type PostgresDB struct {
+ db *sql.DB
+}
+
+func OpenPostgresDB(source string) (Database, error) {
+ sqlPostgresDB, err := sql.Open("postgres", source)
+ if err != nil {
+ return nil, err
+ }
+
+ db := &PostgresDB{db: sqlPostgresDB}
+ if err := db.upgrade(); err != nil {
+ sqlPostgresDB.Close()
+ return nil, err
+ }
+
+ return db, nil
+}
+
+func (db *PostgresDB) upgrade() error {
+ tx, err := db.db.Begin()
+ if err != nil {
+ return err
+ }
+ defer tx.Rollback()
+
+ if _, err := tx.Exec(postgresConfigSchema); err != nil {
+ return fmt.Errorf("failed to create Config table: %s", err)
+ }
+
+ var version int
+ err = tx.QueryRow(`SELECT version FROM "Config"`).Scan(&version)
+ if err != nil && !errors.Is(err, sql.ErrNoRows) {
+ return fmt.Errorf("failed to query schema version: %s", err)
+ }
+
+ if version == len(postgresMigrations) {
+ return nil
+ }
+ if version > len(postgresMigrations) {
+ return fmt.Errorf("suika (version %d) older than schema (version %d)", len(postgresMigrations), version)
+ }
+
+ if version == 0 {
+ if _, err := tx.Exec(postgresSchema); err != nil {
+ return fmt.Errorf("failed to initialize schema: %s", err)
+ }
+ } else {
+ for i := version; i < len(postgresMigrations); i++ {
+ if _, err := tx.Exec(postgresMigrations[i]); err != nil {
+ return fmt.Errorf("failed to execute migration #%v: %v", i, err)
+ }
+ }
+ }
+
+ _, err = tx.Exec(`INSERT INTO "Config" (id, version) VALUES (1, $1)
+ ON CONFLICT (id) DO UPDATE SET version = $1`, len(postgresMigrations))
+ if err != nil {
+ return fmt.Errorf("failed to bump schema version: %v", err)
+ }
+
+ return tx.Commit()
+}
+
+func (db *PostgresDB) Close() error {
+ return db.db.Close()
+}
+
+func (db *PostgresDB) Stats(ctx context.Context) (*DatabaseStats, error) {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ var stats DatabaseStats
+ row := db.db.QueryRowContext(ctx, `SELECT
+ (SELECT COUNT(*) FROM "User") AS users,
+ (SELECT COUNT(*) FROM "Network") AS networks,
+ (SELECT COUNT(*) FROM "Channel") AS channels`)
+ if err := row.Scan(&stats.Users, &stats.Networks, &stats.Channels); err != nil {
+ return nil, err
+ }
+
+ return &stats, nil
+}
+
+func (db *PostgresDB) ListUsers(ctx context.Context) ([]User, error) {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ rows, err := db.db.QueryContext(ctx,
+ `SELECT id, username, password, admin, realname FROM "User"`)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ var users []User
+ for rows.Next() {
+ var user User
+ var password, realname sql.NullString
+ if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &realname); err != nil {
+ return nil, err
+ }
+ user.Password = password.String
+ user.Realname = realname.String
+ users = append(users, user)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+
+ return users, nil
+}
+
+func (db *PostgresDB) GetUser(ctx context.Context, username string) (*User, error) {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ user := &User{Username: username}
+
+ var password, realname sql.NullString
+ row := db.db.QueryRowContext(ctx,
+ `SELECT id, password, admin, realname FROM "User" WHERE username = $1`,
+ username)
+ if err := row.Scan(&user.ID, &password, &user.Admin, &realname); err != nil {
+ return nil, err
+ }
+ user.Password = password.String
+ user.Realname = realname.String
+ return user, nil
+}
+
+func (db *PostgresDB) StoreUser(ctx context.Context, user *User) error {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ password := toNullString(user.Password)
+ realname := toNullString(user.Realname)
+
+ var err error
+ if user.ID == 0 {
+ err = db.db.QueryRowContext(ctx, `
+ INSERT INTO "User" (username, password, admin, realname)
+ VALUES ($1, $2, $3, $4)
+ RETURNING id`,
+ user.Username, password, user.Admin, realname).Scan(&user.ID)
+ } else {
+ _, err = db.db.ExecContext(ctx, `
+ UPDATE "User"
+ SET password = $1, admin = $2, realname = $3
+ WHERE id = $4`,
+ password, user.Admin, realname, user.ID)
+ }
+ return err
+}
+
+func (db *PostgresDB) DeleteUser(ctx context.Context, id int64) error {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ _, err := db.db.ExecContext(ctx, `DELETE FROM "User" WHERE id = $1`, id)
+ return err
+}
+
+func (db *PostgresDB) ListNetworks(ctx context.Context, userID int64) ([]Network, error) {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ rows, err := db.db.QueryContext(ctx, `
+ SELECT id, name, addr, nick, username, realname, pass, connect_commands, sasl_mechanism,
+ sasl_plain_username, sasl_plain_password, sasl_external_cert, sasl_external_key, enabled
+ FROM "Network"
+ WHERE "user" = $1`, userID)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ var networks []Network
+ for rows.Next() {
+ var net Network
+ var name, nick, username, realname, pass, connectCommands sql.NullString
+ var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
+ err := rows.Scan(&net.ID, &name, &net.Addr, &nick, &username, &realname,
+ &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword,
+ &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob, &net.Enabled)
+ if err != nil {
+ return nil, err
+ }
+ net.Name = name.String
+ net.Nick = nick.String
+ net.Username = username.String
+ net.Realname = realname.String
+ net.Pass = pass.String
+ if connectCommands.Valid {
+ net.ConnectCommands = strings.Split(connectCommands.String, "\r\n")
+ }
+ net.SASL.Mechanism = saslMechanism.String
+ net.SASL.Plain.Username = saslPlainUsername.String
+ net.SASL.Plain.Password = saslPlainPassword.String
+ networks = append(networks, net)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+
+ return networks, nil
+}
+
+func (db *PostgresDB) StoreNetwork(ctx context.Context, userID int64, network *Network) error {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ netName := toNullString(network.Name)
+ nick := toNullString(network.Nick)
+ netUsername := toNullString(network.Username)
+ realname := toNullString(network.Realname)
+ pass := toNullString(network.Pass)
+ connectCommands := toNullString(strings.Join(network.ConnectCommands, "\r\n"))
+
+ var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
+ if network.SASL.Mechanism != "" {
+ saslMechanism = toNullString(network.SASL.Mechanism)
+ switch network.SASL.Mechanism {
+ case "PLAIN":
+ saslPlainUsername = toNullString(network.SASL.Plain.Username)
+ saslPlainPassword = toNullString(network.SASL.Plain.Password)
+ network.SASL.External.CertBlob = nil
+ network.SASL.External.PrivKeyBlob = nil
+ case "EXTERNAL":
+ // keep saslPlain* nil
+ default:
+ return fmt.Errorf("suika: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
+ }
+ }
+
+ var err error
+ if network.ID == 0 {
+ err = db.db.QueryRowContext(ctx, `
+ INSERT INTO "Network" ("user", name, addr, nick, username, realname, pass, connect_commands,
+ sasl_mechanism, sasl_plain_username, sasl_plain_password, sasl_external_cert,
+ sasl_external_key, enabled)
+ VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
+ RETURNING id`,
+ userID, netName, network.Addr, nick, netUsername, realname, pass, connectCommands,
+ saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
+ network.SASL.External.PrivKeyBlob, network.Enabled).Scan(&network.ID)
+ } else {
+ _, err = db.db.ExecContext(ctx, `
+ UPDATE "Network"
+ SET name = $2, addr = $3, nick = $4, username = $5, realname = $6, pass = $7,
+ connect_commands = $8, sasl_mechanism = $9, sasl_plain_username = $10,
+ sasl_plain_password = $11, sasl_external_cert = $12, sasl_external_key = $13,
+ enabled = $14
+ WHERE id = $1`,
+ network.ID, netName, network.Addr, nick, netUsername, realname, pass, connectCommands,
+ saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
+ network.SASL.External.PrivKeyBlob, network.Enabled)
+ }
+ return err
+}
+
+func (db *PostgresDB) DeleteNetwork(ctx context.Context, id int64) error {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ _, err := db.db.ExecContext(ctx, `DELETE FROM "Network" WHERE id = $1`, id)
+ return err
+}
+
+func (db *PostgresDB) ListChannels(ctx context.Context, networkID int64) ([]Channel, error) {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ rows, err := db.db.QueryContext(ctx, `
+ SELECT id, name, key, detached, detached_internal_msgid, relay_detached, reattach_on, detach_after,
+ detach_on
+ FROM "Channel"
+ WHERE network = $1`, networkID)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ var channels []Channel
+ for rows.Next() {
+ var ch Channel
+ var key, detachedInternalMsgID sql.NullString
+ var detachAfter int64
+ if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached, &detachedInternalMsgID, &ch.RelayDetached, &ch.ReattachOn, &detachAfter, &ch.DetachOn); err != nil {
+ return nil, err
+ }
+ ch.Key = key.String
+ ch.DetachedInternalMsgID = detachedInternalMsgID.String
+ ch.DetachAfter = time.Duration(detachAfter) * time.Second
+ channels = append(channels, ch)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+
+ return channels, nil
+}
+
+func (db *PostgresDB) StoreChannel(ctx context.Context, networkID int64, ch *Channel) error {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ key := toNullString(ch.Key)
+ detachAfter := int64(math.Ceil(ch.DetachAfter.Seconds()))
+
+ var err error
+ if ch.ID == 0 {
+ err = db.db.QueryRowContext(ctx, `
+ INSERT INTO "Channel" (network, name, key, detached, detached_internal_msgid, relay_detached, reattach_on,
+ detach_after, detach_on)
+ VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
+ RETURNING id`,
+ networkID, ch.Name, key, ch.Detached, toNullString(ch.DetachedInternalMsgID),
+ ch.RelayDetached, ch.ReattachOn, detachAfter, ch.DetachOn).Scan(&ch.ID)
+ } else {
+ _, err = db.db.ExecContext(ctx, `
+ UPDATE "Channel"
+ SET name = $2, key = $3, detached = $4, detached_internal_msgid = $5,
+ relay_detached = $6, reattach_on = $7, detach_after = $8, detach_on = $9
+ WHERE id = $1`,
+ ch.ID, ch.Name, key, ch.Detached, toNullString(ch.DetachedInternalMsgID),
+ ch.RelayDetached, ch.ReattachOn, detachAfter, ch.DetachOn)
+ }
+ return err
+}
+
+func (db *PostgresDB) DeleteChannel(ctx context.Context, id int64) error {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ _, err := db.db.ExecContext(ctx, `DELETE FROM "Channel" WHERE id = $1`, id)
+ return err
+}
+
+func (db *PostgresDB) ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error) {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ rows, err := db.db.QueryContext(ctx, `
+ SELECT id, target, client, internal_msgid
+ FROM "DeliveryReceipt"
+ WHERE network = $1`, networkID)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ var receipts []DeliveryReceipt
+ for rows.Next() {
+ var rcpt DeliveryReceipt
+ if err := rows.Scan(&rcpt.ID, &rcpt.Target, &rcpt.Client, &rcpt.InternalMsgID); err != nil {
+ return nil, err
+ }
+ receipts = append(receipts, rcpt)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+
+ return receipts, nil
+}
+
+func (db *PostgresDB) StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ tx, err := db.db.Begin()
+ if err != nil {
+ return err
+ }
+ defer tx.Rollback()
+
+ _, err = tx.ExecContext(ctx,
+ `DELETE FROM "DeliveryReceipt" WHERE network = $1 AND client = $2`,
+ networkID, client)
+ if err != nil {
+ return err
+ }
+
+ stmt, err := tx.PrepareContext(ctx, `
+ INSERT INTO "DeliveryReceipt" (network, target, client, internal_msgid)
+ VALUES ($1, $2, $3, $4)
+ RETURNING id`)
+ if err != nil {
+ return err
+ }
+ defer stmt.Close()
+
+ for i := range receipts {
+ rcpt := &receipts[i]
+ err := stmt.
+ QueryRowContext(ctx, networkID, rcpt.Target, client, rcpt.InternalMsgID).
+ Scan(&rcpt.ID)
+ if err != nil {
+ return err
+ }
+ }
+
+ return tx.Commit()
+}
+
+func (db *PostgresDB) GetReadReceipt(ctx context.Context, networkID int64, name string) (*ReadReceipt, error) {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ receipt := &ReadReceipt{
+ Target: name,
+ }
+
+ row := db.db.QueryRowContext(ctx,
+ `SELECT id, timestamp FROM "ReadReceipt" WHERE network = $1 AND target = $2`,
+ networkID, name)
+ if err := row.Scan(&receipt.ID, &receipt.Timestamp); err != nil {
+ if err == sql.ErrNoRows {
+ return nil, nil
+ }
+ return nil, err
+ }
+ return receipt, nil
+}
+
+func (db *PostgresDB) StoreReadReceipt(ctx context.Context, networkID int64, receipt *ReadReceipt) error {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ var err error
+ if receipt.ID != 0 {
+ _, err = db.db.ExecContext(ctx, `
+ UPDATE "ReadReceipt"
+ SET timestamp = $1
+ WHERE id = $2`,
+ receipt.Timestamp, receipt.ID)
+ } else {
+ err = db.db.QueryRowContext(ctx, `
+ INSERT INTO "ReadReceipt" (network, target, timestamp)
+ VALUES ($1, $2, $3)
+ RETURNING id`,
+ networkID, receipt.Target, receipt.Timestamp).Scan(&receipt.ID)
+ }
+ return err
+}
--- /dev/null
+package suika
+
+import (
+ "database/sql"
+ "os"
+ "testing"
+)
+
+// PostgreSQL version 0 schema. DO NOT EDIT.
+const postgresV0Schema = `
+CREATE TABLE "Config" (
+ id SMALLINT PRIMARY KEY,
+ version INTEGER NOT NULL,
+ CHECK(id = 1)
+);
+
+INSERT INTO "Config" (id, version) VALUES (1, 1);
+
+CREATE TABLE "User" (
+ id SERIAL PRIMARY KEY,
+ username VARCHAR(255) NOT NULL UNIQUE,
+ password VARCHAR(255),
+ admin BOOLEAN NOT NULL DEFAULT FALSE,
+ realname VARCHAR(255)
+);
+
+CREATE TABLE "Network" (
+ id SERIAL PRIMARY KEY,
+ name VARCHAR(255),
+ "user" INTEGER NOT NULL REFERENCES "User"(id) ON DELETE CASCADE,
+ addr VARCHAR(255) NOT NULL,
+ nick VARCHAR(255) NOT NULL,
+ username VARCHAR(255),
+ realname VARCHAR(255),
+ pass VARCHAR(255),
+ connect_commands VARCHAR(1023),
+ sasl_mechanism VARCHAR(255),
+ sasl_plain_username VARCHAR(255),
+ sasl_plain_password VARCHAR(255),
+ sasl_external_cert BYTEA DEFAULT NULL,
+ sasl_external_key BYTEA DEFAULT NULL,
+ enabled BOOLEAN NOT NULL DEFAULT TRUE,
+ UNIQUE("user", addr, nick),
+ UNIQUE("user", name)
+);
+
+CREATE TABLE "Channel" (
+ id SERIAL PRIMARY KEY,
+ network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
+ name VARCHAR(255) NOT NULL,
+ key VARCHAR(255),
+ detached BOOLEAN NOT NULL DEFAULT FALSE,
+ detached_internal_msgid VARCHAR(255),
+ relay_detached INTEGER NOT NULL DEFAULT 0,
+ reattach_on INTEGER NOT NULL DEFAULT 0,
+ detach_after INTEGER NOT NULL DEFAULT 0,
+ detach_on INTEGER NOT NULL DEFAULT 0,
+ UNIQUE(network, name)
+);
+
+CREATE TABLE "DeliveryReceipt" (
+ id SERIAL PRIMARY KEY,
+ network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
+ target VARCHAR(255) NOT NULL,
+ client VARCHAR(255) NOT NULL DEFAULT '',
+ internal_msgid VARCHAR(255) NOT NULL,
+ UNIQUE(network, target, client)
+);
+`
+
+func openTempPostgresDB(t *testing.T) *sql.DB {
+ source, ok := os.LookupEnv("SOJU_TEST_POSTGRES")
+ if !ok {
+ t.Skip("set SOJU_TEST_POSTGRES to a connection string to execute PostgreSQL tests")
+ }
+
+ db, err := sql.Open("postgres", source)
+ if err != nil {
+ t.Fatalf("failed to connect to PostgreSQL: %v", err)
+ }
+
+ // Store all tables in a temporary schema which will be dropped when the
+ // connection to PostgreSQL is closed.
+ db.SetMaxOpenConns(1)
+ if _, err := db.Exec("SET search_path TO pg_temp"); err != nil {
+ t.Fatalf("failed to set PostgreSQL search_path: %v", err)
+ }
+
+ return db
+}
+
+func TestPostgresMigrations(t *testing.T) {
+ sqlDB := openTempPostgresDB(t)
+ if _, err := sqlDB.Exec(postgresV0Schema); err != nil {
+ t.Fatalf("DB.Exec() failed for v0 schema: %v", err)
+ }
+
+ db := &PostgresDB{db: sqlDB}
+ defer db.Close()
+
+ if err := db.upgrade(); err != nil {
+ t.Fatalf("PostgresDB.Upgrade() failed: %v", err)
+ }
+}
--- /dev/null
+package suika
+
+import (
+ "context"
+ "database/sql"
+ _ "embed"
+ "fmt"
+ "math"
+ "strings"
+ "sync"
+ "time"
+
+ _ "modernc.org/sqlite"
+)
+
+const sqliteQueryTimeout = 5 * time.Second
+
+//go:embed suika_sqlite_schema.sql
+var sqliteSchema string
+
+var sqliteMigrations = []string{
+ "", // migration #0 is reserved for schema initialization
+ "ALTER TABLE Network ADD COLUMN connect_commands VARCHAR(1023)",
+ "ALTER TABLE Channel ADD COLUMN detached INTEGER NOT NULL DEFAULT 0",
+ "ALTER TABLE Network ADD COLUMN sasl_external_cert BLOB DEFAULT NULL",
+ "ALTER TABLE Network ADD COLUMN sasl_external_key BLOB DEFAULT NULL",
+ "ALTER TABLE User ADD COLUMN admin INTEGER NOT NULL DEFAULT 0",
+ `
+ CREATE TABLE IF NOT EXISTS UserNew (
+ id INTEGER PRIMARY KEY,
+ username VARCHAR(255) NOT NULL UNIQUE,
+ password VARCHAR(255),
+ admin INTEGER NOT NULL DEFAULT 0
+ );
+ INSERT INTO UserNew SELECT rowid, username, password, admin FROM User;
+ DROP TABLE User;
+ ALTER TABLE UserNew RENAME TO User;
+ `,
+ `
+ CREATE TABLE IF NOT EXISTS NetworkNew (
+ id INTEGER PRIMARY KEY,
+ name VARCHAR(255),
+ user INTEGER NOT NULL,
+ addr VARCHAR(255) NOT NULL,
+ nick VARCHAR(255) NOT NULL,
+ username VARCHAR(255),
+ realname VARCHAR(255),
+ pass VARCHAR(255),
+ connect_commands VARCHAR(1023),
+ sasl_mechanism VARCHAR(255),
+ sasl_plain_username VARCHAR(255),
+ sasl_plain_password VARCHAR(255),
+ sasl_external_cert BLOB DEFAULT NULL,
+ sasl_external_key BLOB DEFAULT NULL,
+ FOREIGN KEY(user) REFERENCES User(id),
+ UNIQUE(user, addr, nick),
+ UNIQUE(user, name)
+ );
+ INSERT INTO NetworkNew
+ SELECT Network.id, name, User.id as user, addr, nick,
+ Network.username, realname, pass, connect_commands,
+ sasl_mechanism, sasl_plain_username, sasl_plain_password,
+ sasl_external_cert, sasl_external_key
+ FROM Network
+ JOIN User ON Network.user = User.username;
+ DROP TABLE Network;
+ ALTER TABLE NetworkNew RENAME TO Network;
+ `,
+ `
+ ALTER TABLE Channel ADD COLUMN relay_detached INTEGER NOT NULL DEFAULT 0;
+ ALTER TABLE Channel ADD COLUMN reattach_on INTEGER NOT NULL DEFAULT 0;
+ ALTER TABLE Channel ADD COLUMN detach_after INTEGER NOT NULL DEFAULT 0;
+ ALTER TABLE Channel ADD COLUMN detach_on INTEGER NOT NULL DEFAULT 0;
+ `,
+ `
+ CREATE TABLE IF NOT EXISTS DeliveryReceipt (
+ id INTEGER PRIMARY KEY,
+ network INTEGER NOT NULL,
+ target VARCHAR(255) NOT NULL,
+ client VARCHAR(255),
+ internal_msgid VARCHAR(255) NOT NULL,
+ FOREIGN KEY(network) REFERENCES Network(id),
+ UNIQUE(network, target, client)
+ );
+ `,
+ "ALTER TABLE Channel ADD COLUMN detached_internal_msgid VARCHAR(255)",
+ "ALTER TABLE Network ADD COLUMN enabled INTEGER NOT NULL DEFAULT 1",
+ "ALTER TABLE User ADD COLUMN realname VARCHAR(255)",
+ `
+ CREATE TABLE IF NOT EXISTS NetworkNew (
+ id INTEGER PRIMARY KEY,
+ name TEXT,
+ user INTEGER NOT NULL,
+ addr TEXT NOT NULL,
+ nick TEXT,
+ username TEXT,
+ realname TEXT,
+ pass TEXT,
+ connect_commands TEXT,
+ sasl_mechanism TEXT,
+ sasl_plain_username TEXT,
+ sasl_plain_password TEXT,
+ sasl_external_cert BLOB,
+ sasl_external_key BLOB,
+ enabled INTEGER NOT NULL DEFAULT 1,
+ FOREIGN KEY(user) REFERENCES User(id),
+ UNIQUE(user, addr, nick),
+ UNIQUE(user, name)
+ );
+ INSERT INTO NetworkNew
+ SELECT id, name, user, addr, nick, username, realname, pass,
+ connect_commands, sasl_mechanism, sasl_plain_username,
+ sasl_plain_password, sasl_external_cert, sasl_external_key,
+ enabled
+ FROM Network;
+ DROP TABLE Network;
+ ALTER TABLE NetworkNew RENAME TO Network;
+ `,
+ `
+ CREATE TABLE IF NOT EXISTS ReadReceipt (
+ id INTEGER PRIMARY KEY,
+ network INTEGER NOT NULL,
+ target TEXT NOT NULL,
+ timestamp TEXT NOT NULL,
+ FOREIGN KEY(network) REFERENCES Network(id),
+ UNIQUE(network, target)
+ );
+ `,
+}
+
+type SqliteDB struct {
+ lock sync.RWMutex
+ db *sql.DB
+}
+
+func OpenSqliteDB(source string) (Database, error) {
+ sqlSqliteDB, err := sql.Open("sqlite", source)
+ if err != nil {
+ return nil, err
+ }
+
+ db := &SqliteDB{db: sqlSqliteDB}
+ if err := db.upgrade(); err != nil {
+ sqlSqliteDB.Close()
+ return nil, err
+ }
+
+ return db, nil
+}
+
+func (db *SqliteDB) Close() error {
+ db.lock.Lock()
+ defer db.lock.Unlock()
+ return db.db.Close()
+}
+
+func (db *SqliteDB) upgrade() error {
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ var version int
+ if err := db.db.QueryRow("PRAGMA user_version").Scan(&version); err != nil {
+ return fmt.Errorf("failed to query schema version: %v", err)
+ }
+
+ if version == len(sqliteMigrations) {
+ return nil
+ } else if version > len(sqliteMigrations) {
+ return fmt.Errorf("suika (version %d) older than schema (version %d)", len(sqliteMigrations), version)
+ }
+
+ tx, err := db.db.Begin()
+ if err != nil {
+ return err
+ }
+ defer tx.Rollback()
+
+ if version == 0 {
+ if _, err := tx.Exec(sqliteSchema); err != nil {
+ return fmt.Errorf("failed to initialize schema: %v", err)
+ }
+ } else {
+ for i := version; i < len(sqliteMigrations); i++ {
+ if _, err := tx.Exec(sqliteMigrations[i]); err != nil {
+ return fmt.Errorf("failed to execute migration #%v: %v", i, err)
+ }
+ }
+ }
+
+ // For some reason prepared statements don't work here
+ _, err = tx.Exec(fmt.Sprintf("PRAGMA user_version = %d", len(sqliteMigrations)))
+ if err != nil {
+ return fmt.Errorf("failed to bump schema version: %v", err)
+ }
+
+ return tx.Commit()
+}
+
+func (db *SqliteDB) Stats(ctx context.Context) (*DatabaseStats, error) {
+ db.lock.RLock()
+ defer db.lock.RUnlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ var stats DatabaseStats
+ row := db.db.QueryRowContext(ctx, `SELECT
+ (SELECT COUNT(*) FROM User) AS users,
+ (SELECT COUNT(*) FROM Network) AS networks,
+ (SELECT COUNT(*) FROM Channel) AS channels`)
+ if err := row.Scan(&stats.Users, &stats.Networks, &stats.Channels); err != nil {
+ return nil, err
+ }
+
+ return &stats, nil
+}
+
+func toNullString(s string) sql.NullString {
+ return sql.NullString{
+ String: s,
+ Valid: s != "",
+ }
+}
+
+func (db *SqliteDB) ListUsers(ctx context.Context) ([]User, error) {
+ db.lock.RLock()
+ defer db.lock.RUnlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ rows, err := db.db.QueryContext(ctx,
+ "SELECT id, username, password, admin, realname FROM User")
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ var users []User
+ for rows.Next() {
+ var user User
+ var password, realname sql.NullString
+ if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &realname); err != nil {
+ return nil, err
+ }
+ user.Password = password.String
+ user.Realname = realname.String
+ users = append(users, user)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+
+ return users, nil
+}
+
+func (db *SqliteDB) GetUser(ctx context.Context, username string) (*User, error) {
+ db.lock.RLock()
+ defer db.lock.RUnlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ user := &User{Username: username}
+
+ var password, realname sql.NullString
+ row := db.db.QueryRowContext(ctx,
+ "SELECT id, password, admin, realname FROM User WHERE username = ?",
+ username)
+ if err := row.Scan(&user.ID, &password, &user.Admin, &realname); err != nil {
+ return nil, err
+ }
+ user.Password = password.String
+ user.Realname = realname.String
+ return user, nil
+}
+
+func (db *SqliteDB) StoreUser(ctx context.Context, user *User) error {
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ args := []interface{}{
+ sql.Named("username", user.Username),
+ sql.Named("password", toNullString(user.Password)),
+ sql.Named("admin", user.Admin),
+ sql.Named("realname", toNullString(user.Realname)),
+ }
+
+ var err error
+ if user.ID != 0 {
+ _, err = db.db.ExecContext(ctx, `
+ UPDATE User SET password = :password, admin = :admin,
+ realname = :realname WHERE username = :username`,
+ args...)
+ } else {
+ var res sql.Result
+ res, err = db.db.ExecContext(ctx, `
+ INSERT INTO
+ User(username, password, admin, realname)
+ VALUES (:username, :password, :admin, :realname)`,
+ args...)
+ if err != nil {
+ return err
+ }
+ user.ID, err = res.LastInsertId()
+ }
+
+ return err
+}
+
+func (db *SqliteDB) DeleteUser(ctx context.Context, id int64) error {
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ tx, err := db.db.Begin()
+ if err != nil {
+ return err
+ }
+ defer tx.Rollback()
+
+ _, err = tx.ExecContext(ctx, `DELETE FROM DeliveryReceipt
+ WHERE id IN (
+ SELECT DeliveryReceipt.id
+ FROM DeliveryReceipt
+ JOIN Network ON DeliveryReceipt.network = Network.id
+ WHERE Network.user = ?
+ )`, id)
+ if err != nil {
+ return err
+ }
+
+ _, err = tx.ExecContext(ctx, `DELETE FROM ReadReceipt
+ WHERE id IN (
+ SELECT ReadReceipt.id
+ FROM ReadReceipt
+ JOIN Network ON ReadReceipt.network = Network.id
+ WHERE Network.user = ?
+ )`, id)
+ if err != nil {
+ return err
+ }
+
+ _, err = tx.ExecContext(ctx, `DELETE FROM Channel
+ WHERE id IN (
+ SELECT Channel.id
+ FROM Channel
+ JOIN Network ON Channel.network = Network.id
+ WHERE Network.user = ?
+ )`, id)
+ if err != nil {
+ return err
+ }
+
+ _, err = tx.ExecContext(ctx, "DELETE FROM Network WHERE user = ?", id)
+ if err != nil {
+ return err
+ }
+
+ _, err = tx.ExecContext(ctx, "DELETE FROM User WHERE id = ?", id)
+ if err != nil {
+ return err
+ }
+
+ return tx.Commit()
+}
+
+func (db *SqliteDB) ListNetworks(ctx context.Context, userID int64) ([]Network, error) {
+ db.lock.RLock()
+ defer db.lock.RUnlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ rows, err := db.db.QueryContext(ctx, `
+ SELECT id, name, addr, nick, username, realname, pass,
+ connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password,
+ sasl_external_cert, sasl_external_key, enabled
+ FROM Network
+ WHERE user = ?`,
+ userID)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ var networks []Network
+ for rows.Next() {
+ var net Network
+ var name, nick, username, realname, pass, connectCommands sql.NullString
+ var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
+ err := rows.Scan(&net.ID, &name, &net.Addr, &nick, &username, &realname,
+ &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword,
+ &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob, &net.Enabled)
+ if err != nil {
+ return nil, err
+ }
+ net.Name = name.String
+ net.Nick = nick.String
+ net.Username = username.String
+ net.Realname = realname.String
+ net.Pass = pass.String
+ if connectCommands.Valid {
+ net.ConnectCommands = strings.Split(connectCommands.String, "\r\n")
+ }
+ net.SASL.Mechanism = saslMechanism.String
+ net.SASL.Plain.Username = saslPlainUsername.String
+ net.SASL.Plain.Password = saslPlainPassword.String
+ networks = append(networks, net)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+
+ return networks, nil
+}
+
+func (db *SqliteDB) StoreNetwork(ctx context.Context, userID int64, network *Network) error {
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
+ if network.SASL.Mechanism != "" {
+ saslMechanism = toNullString(network.SASL.Mechanism)
+ switch network.SASL.Mechanism {
+ case "PLAIN":
+ saslPlainUsername = toNullString(network.SASL.Plain.Username)
+ saslPlainPassword = toNullString(network.SASL.Plain.Password)
+ network.SASL.External.CertBlob = nil
+ network.SASL.External.PrivKeyBlob = nil
+ case "EXTERNAL":
+ // keep saslPlain* nil
+ default:
+ return fmt.Errorf("suika: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
+ }
+ }
+
+ args := []interface{}{
+ sql.Named("name", toNullString(network.Name)),
+ sql.Named("addr", network.Addr),
+ sql.Named("nick", toNullString(network.Nick)),
+ sql.Named("username", toNullString(network.Username)),
+ sql.Named("realname", toNullString(network.Realname)),
+ sql.Named("pass", toNullString(network.Pass)),
+ sql.Named("connect_commands", toNullString(strings.Join(network.ConnectCommands, "\r\n"))),
+ sql.Named("sasl_mechanism", saslMechanism),
+ sql.Named("sasl_plain_username", saslPlainUsername),
+ sql.Named("sasl_plain_password", saslPlainPassword),
+ sql.Named("sasl_external_cert", network.SASL.External.CertBlob),
+ sql.Named("sasl_external_key", network.SASL.External.PrivKeyBlob),
+ sql.Named("enabled", network.Enabled),
+
+ sql.Named("id", network.ID), // only for UPDATE
+ sql.Named("user", userID), // only for INSERT
+ }
+
+ var err error
+ if network.ID != 0 {
+ _, err = db.db.ExecContext(ctx, `
+ UPDATE Network
+ SET name = :name, addr = :addr, nick = :nick, username = :username,
+ realname = :realname, pass = :pass, connect_commands = :connect_commands,
+ sasl_mechanism = :sasl_mechanism, sasl_plain_username = :sasl_plain_username, sasl_plain_password = :sasl_plain_password,
+ sasl_external_cert = :sasl_external_cert, sasl_external_key = :sasl_external_key,
+ enabled = :enabled
+ WHERE id = :id`, args...)
+ } else {
+ var res sql.Result
+ res, err = db.db.ExecContext(ctx, `
+ INSERT INTO Network(user, name, addr, nick, username, realname, pass,
+ connect_commands, sasl_mechanism, sasl_plain_username,
+ sasl_plain_password, sasl_external_cert, sasl_external_key, enabled)
+ VALUES (:user, :name, :addr, :nick, :username, :realname, :pass,
+ :connect_commands, :sasl_mechanism, :sasl_plain_username,
+ :sasl_plain_password, :sasl_external_cert, :sasl_external_key, :enabled)`,
+ args...)
+ if err != nil {
+ return err
+ }
+ network.ID, err = res.LastInsertId()
+ }
+ return err
+}
+
+func (db *SqliteDB) DeleteNetwork(ctx context.Context, id int64) error {
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ tx, err := db.db.Begin()
+ if err != nil {
+ return err
+ }
+ defer tx.Rollback()
+
+ _, err = tx.ExecContext(ctx, "DELETE FROM DeliveryReceipt WHERE network = ?", id)
+ if err != nil {
+ return err
+ }
+
+ _, err = tx.ExecContext(ctx, "DELETE FROM ReadReceipt WHERE network = ?", id)
+ if err != nil {
+ return err
+ }
+
+ _, err = tx.ExecContext(ctx, "DELETE FROM Channel WHERE network = ?", id)
+ if err != nil {
+ return err
+ }
+
+ _, err = tx.ExecContext(ctx, "DELETE FROM Network WHERE id = ?", id)
+ if err != nil {
+ return err
+ }
+
+ return tx.Commit()
+}
+
+func (db *SqliteDB) ListChannels(ctx context.Context, networkID int64) ([]Channel, error) {
+ db.lock.RLock()
+ defer db.lock.RUnlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ rows, err := db.db.QueryContext(ctx, `SELECT
+ id, name, key, detached, detached_internal_msgid,
+ relay_detached, reattach_on, detach_after, detach_on
+ FROM Channel
+ WHERE network = ?`, networkID)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ var channels []Channel
+ for rows.Next() {
+ var ch Channel
+ var key, detachedInternalMsgID sql.NullString
+ var detachAfter int64
+ if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached, &detachedInternalMsgID, &ch.RelayDetached, &ch.ReattachOn, &detachAfter, &ch.DetachOn); err != nil {
+ return nil, err
+ }
+ ch.Key = key.String
+ ch.DetachedInternalMsgID = detachedInternalMsgID.String
+ ch.DetachAfter = time.Duration(detachAfter) * time.Second
+ channels = append(channels, ch)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+
+ return channels, nil
+}
+
+func (db *SqliteDB) StoreChannel(ctx context.Context, networkID int64, ch *Channel) error {
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ args := []interface{}{
+ sql.Named("network", networkID),
+ sql.Named("name", ch.Name),
+ sql.Named("key", toNullString(ch.Key)),
+ sql.Named("detached", ch.Detached),
+ sql.Named("detached_internal_msgid", toNullString(ch.DetachedInternalMsgID)),
+ sql.Named("relay_detached", ch.RelayDetached),
+ sql.Named("reattach_on", ch.ReattachOn),
+ sql.Named("detach_after", int64(math.Ceil(ch.DetachAfter.Seconds()))),
+ sql.Named("detach_on", ch.DetachOn),
+
+ sql.Named("id", ch.ID), // only for UPDATE
+ }
+
+ var err error
+ if ch.ID != 0 {
+ _, err = db.db.ExecContext(ctx, `UPDATE Channel
+ SET network = :network, name = :name, key = :key, detached = :detached,
+ detached_internal_msgid = :detached_internal_msgid, relay_detached = :relay_detached,
+ reattach_on = :reattach_on, detach_after = :detach_after, detach_on = :detach_on
+ WHERE id = :id`, args...)
+ } else {
+ var res sql.Result
+ res, err = db.db.ExecContext(ctx, `INSERT INTO Channel(network, name, key, detached, detached_internal_msgid, relay_detached, reattach_on, detach_after, detach_on)
+ VALUES (:network, :name, :key, :detached, :detached_internal_msgid, :relay_detached, :reattach_on, :detach_after, :detach_on)`, args...)
+ if err != nil {
+ return err
+ }
+ ch.ID, err = res.LastInsertId()
+ }
+ return err
+}
+
+func (db *SqliteDB) DeleteChannel(ctx context.Context, id int64) error {
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ _, err := db.db.ExecContext(ctx, "DELETE FROM Channel WHERE id = ?", id)
+ return err
+}
+
+func (db *SqliteDB) ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error) {
+ db.lock.RLock()
+ defer db.lock.RUnlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ rows, err := db.db.QueryContext(ctx, `
+ SELECT id, target, client, internal_msgid
+ FROM DeliveryReceipt
+ WHERE network = ?`, networkID)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ var receipts []DeliveryReceipt
+ for rows.Next() {
+ var rcpt DeliveryReceipt
+ var client sql.NullString
+ if err := rows.Scan(&rcpt.ID, &rcpt.Target, &client, &rcpt.InternalMsgID); err != nil {
+ return nil, err
+ }
+ rcpt.Client = client.String
+ receipts = append(receipts, rcpt)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+
+ return receipts, nil
+}
+
+func (db *SqliteDB) StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error {
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ tx, err := db.db.Begin()
+ if err != nil {
+ return err
+ }
+ defer tx.Rollback()
+
+ _, err = tx.ExecContext(ctx, "DELETE FROM DeliveryReceipt WHERE network = ? AND client IS ?",
+ networkID, toNullString(client))
+ if err != nil {
+ return err
+ }
+
+ for i := range receipts {
+ rcpt := &receipts[i]
+
+ res, err := tx.ExecContext(ctx, `
+ INSERT INTO DeliveryReceipt(network, target, client, internal_msgid)
+ VALUES (:network, :target, :client, :internal_msgid)`,
+ sql.Named("network", networkID),
+ sql.Named("target", rcpt.Target),
+ sql.Named("client", toNullString(client)),
+ sql.Named("internal_msgid", rcpt.InternalMsgID))
+ if err != nil {
+ return err
+ }
+ rcpt.ID, err = res.LastInsertId()
+ if err != nil {
+ return err
+ }
+ }
+
+ return tx.Commit()
+}
+
+func (db *SqliteDB) GetReadReceipt(ctx context.Context, networkID int64, name string) (*ReadReceipt, error) {
+ db.lock.RLock()
+ defer db.lock.RUnlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ receipt := &ReadReceipt{
+ Target: name,
+ }
+
+ row := db.db.QueryRowContext(ctx, `
+ SELECT id, timestamp FROM ReadReceipt WHERE network = :network AND target = :target`,
+ sql.Named("network", networkID),
+ sql.Named("target", name),
+ )
+ var timestamp string
+ if err := row.Scan(&receipt.ID, ×tamp); err != nil {
+ if err == sql.ErrNoRows {
+ return nil, nil
+ }
+ return nil, err
+ }
+ if t, err := time.Parse(serverTimeLayout, timestamp); err != nil {
+ return nil, err
+ } else {
+ receipt.Timestamp = t
+ }
+ return receipt, nil
+}
+
+func (db *SqliteDB) StoreReadReceipt(ctx context.Context, networkID int64, receipt *ReadReceipt) error {
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ args := []interface{}{
+ sql.Named("id", receipt.ID),
+ sql.Named("timestamp", formatServerTime(receipt.Timestamp)),
+ sql.Named("network", networkID),
+ sql.Named("target", receipt.Target),
+ }
+
+ var err error
+ if receipt.ID != 0 {
+ _, err = db.db.ExecContext(ctx, `
+ UPDATE ReadReceipt SET timestamp = :timestamp WHERE id = :id`,
+ args...)
+ } else {
+ var res sql.Result
+ res, err = db.db.ExecContext(ctx, `
+ INSERT INTO
+ ReadReceipt(network, target, timestamp)
+ VALUES (:network, :target, :timestamp)`,
+ args...)
+ if err != nil {
+ return err
+ }
+ receipt.ID, err = res.LastInsertId()
+ }
+
+ return err
+}
--- /dev/null
+package suika
+
+import (
+ "database/sql"
+ "testing"
+)
+
+// SQLite version 0 schema. DO NOT EDIT.
+const sqliteV0Schema = `
+CREATE TABLE User (
+ username VARCHAR(255) NOT NULL UNIQUE,
+ password VARCHAR(255)
+);
+
+CREATE TABLE Network (
+ id INTEGER PRIMARY KEY,
+ name VARCHAR(255),
+ user VARCHAR(255) NOT NULL,
+ addr VARCHAR(255) NOT NULL,
+ nick VARCHAR(255) NOT NULL,
+ username VARCHAR(255),
+ realname VARCHAR(255),
+ pass VARCHAR(255),
+ sasl_mechanism VARCHAR(255),
+ sasl_plain_username VARCHAR(255),
+ sasl_plain_password VARCHAR(255),
+ UNIQUE(user, addr, nick),
+ UNIQUE(user, name)
+);
+
+CREATE TABLE Channel (
+ id INTEGER PRIMARY KEY,
+ network INTEGER NOT NULL,
+ name VARCHAR(255) NOT NULL,
+ key VARCHAR(255),
+ FOREIGN KEY(network) REFERENCES Network(id),
+ UNIQUE(network, name)
+);
+
+PRAGMA user_version = 1;
+`
+
+func TestSqliteMigrations(t *testing.T) {
+ sqlDB, err := sql.Open("sqlite", ":memory:")
+ if err != nil {
+ t.Fatalf("failed to create temporary SQLite database: %v", err)
+ }
+
+ if _, err := sqlDB.Exec(sqliteV0Schema); err != nil {
+ t.Fatalf("DB.Exec() failed for v0 schema: %v", err)
+ }
+
+ db := &SqliteDB{db: sqlDB}
+ defer db.Close()
+
+ if err := db.upgrade(); err != nil {
+ t.Fatalf("SqliteDB.Upgrade() failed: %v", err)
+ }
+}
--- /dev/null
+// Package suika is a hard-fork of the 0.3 series of soju, an user-friendly IRC bouncer in Go.
+//
+// # Copyright (C) 2020 The soju Contributors
+// # Copyright (C) 2023-present Izuru Yakumo et al.
+//
+// suika is covered by the AGPLv3 license:
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published
+// by the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <https://www.gnu.org/licenses/>.
+package suika
--- /dev/null
+.Dd $Mdocdate$
+.Dt SUIKA-CONFIG 5
+.Os
+.Sh NAME
+.Nm suika-config
+.Nd Configuration file for the IRC bouncer
+.Sh SYNOPSIS
+.Bk -words
+listen ircs://
+.Pp
+tls cert.pem key.pem
+.Pp
+hostname example.org
+.Ek
+.Sh DESCRIPTION
+This document describes the format of the configuration
+file used by
+.Xr suika 1
+.Sh OPTIONS
+.Bl -tag -width Ds
+.It listen Ar uri
+With this you can control on what
+ports/protocols
+.Xr suika 1
+listens on, it supports
+irc (cleartext IRC), ircs (IRC with TLS), and unix
+(IRC over Unix domain sockets)
+.It hostname Ar hostname
+Server hostname, if unset, the system one is used.
+.It title Ar title
+Server title, this will be sent as the ISUPPORT NETWORK value when
+clients don't select a specific network.
+.It tls Ar cert Ar key
+Enable TLS support, the certificate and key files must be
+PEM-encoded.
+.It db Ar driver Ar path
+Set the database driver for user, network and channel storage.
+By default a SQLite 3 database is opened in
+.Pa ./suika.db
+Supported drivers are sqlite and postgres, the former
+expects a path to the database file, and the latter
+a space-separated list of key=value parameters,
+e.g. host=localhost dbname=suika
+.It log fs Ar path
+Path to the bouncer logs directory, or empty to disable
+logging.
+By default, logging is disabled.
+.It max-user-networks Ar limit
+Maximum number of networks per user, by default
+there is no limit.
+.It motd Ar path
+Path to the MOTD file, its contents are sent to clients
+which aren't bound to a particular network.
+By default, no MOTD is sent.
+.It multi-upstream-mode Ar bool
+Globally enable or disable multi-upstream mode.
+By default, it is enabled.
+.It upstream-user-ip Ar cidr
+Enable per-user-IP addresses.
+One IPv4 and/or one IPv6 range can be specified in CIDR notation.
+One IP address per range will be assigned to each user as the
+source address when connecting to an upstream network.
+This can be useful to avoid having the whole bouncer banned from
+an upstream network because of one malicious user.
+.El
+.Sh AUTHORS
+.An Simon Ser Aq Mt contact@emersion.fr
+.An The soju Contributors
+.Sh MAINTAINERS
+.An Izuru Yakumo Aq Mt yakumo.izuru@chaotic.ninja
--- /dev/null
+.Dd $Mdocdate$
+.Dt SUIKA-ZNC-IMPORT 1
+.Os
+.Sh NAME
+.Nm suika-znc-import
+.Nd Migration utility for moving from ZNC
+.Sh SYNOPSIS
+.Nm
+.Op Fl config Ar suika config file
+.Op Fl user Ar username
+.Op Fl network Ar name
+.Sh DESCRIPTION
+Imports configuration from a ZNC file.
+Users and networks are merged if they already exist in the
+.Xr suika 1
+database.
+ZNC settings overwrite existing
+.Xr suika 1
+settings
+.Sh OPTIONS
+.Bl -tag -width Ds
+.It config Ar suika config file
+Path to
+.Xr suika-config 5
+.It user Ar username
+Limit import to username, may be specified multiple times.
+It network Ar name
+Limit import to network, may be specified multiple times.
+.El
+.Sh AUTHORS
+.An Simon Ser Aq Mt contact@emersion.fr
+.An The soju Contributors
+.Sh MAINTAINERS
+.An Izuru Yakumo Aq Mt yakumo.izuru@chaotic.ninja
--- /dev/null
+.Dd $Mdocdate$
+.Dt SUIKA 1
+.Os
+.Sh NAME
+.Nm suika
+.Nd A drunk as hell IRC bouncer, named after Suika Ibuki from Touhou Project
+.Sh SYNOPSIS
+.Nm
+.Op Fl config Ar path
+.Op Fl debug
+.Op Fl listen Ar uri
+.Sh DESCRIPTION
+.Nm
+is an user-friendly IRC bouncer, it connects to upstream
+IRC servers on behalf of the user to provide extra features.
+.Bl -tag -width 6n
+.It Multiple separate users sharing the same bouncer
+.It Clients connecting to multiple upstream servers (via a single connection)
+.It Sending the backlog with per-client buffers
+.El
+.Pp
+When joining a channel, the channel will be saved
+and automatically joined on the next connection.
+When registering or authenticating with NickServ, the credentials will be saved
+and automatically used on the next connection if the server supports SASL.
+When parting a channel with the reason "detach", the channel will be
+detached instead of being left.
+When all clients are disconnected from the bouncer,
+the user is automatically marked as away.
+.Pp
+.Nm
+supports two connection modes:
+.Bl -tag -width 6n
+.It Single upstream mode
+One downstream connection maps to one upstream connection
+.Pp
+To enable this mode, connect to the bouncer
+with the username "<username>/<network>".
+.Pp
+If the bouncer isn't connected to the upstream server,
+it will get automatically added.
+.Pp
+Then channels can be joined and parted as if
+you were directly connected to the upstream server.
+.It Multiple upstream mode
+One downstream connection maps to multiple upstream connections.
+Channels and nicks are suffixed with the network name.
+To join a channel, you need to use the suffix too: /join #channel/network.
+Same applies to messages sent to users.
+.El
+.Pp
+For per-client history to work, clients need to indicate their name.
+This can be done by adding a "@<client>" suffix to the username.
+.Pp
+.Nm
+will reload the configuration file, the TLS certificate/key and
+the MOTD file when it receives the HUP signal.
+The configuration options listen, db and log cannot be reloaded.
+.Pp
+Administrators can broadcast a message to all bouncer users via
+/notice $<hostname> <text>, or via /notice $<text> in multi-upstream mode.
+All currently connected bouncer users will receive the message
+from the special BouncerServ service.
+.Sh AUTHORS
+.An Simon Ser Aq Mt contact@emersion.fr
+.An The soju Contributors
+.Sh MAINTAINERS
+.An Izuru Yakumo Aq Mt yakumo.izuru@chaotic.ninja
--- /dev/null
+.Dd $Mdocdate$
+.Dt SUIKADB 1
+.Os
+.Sh NAME
+.Nm suikadb
+.Nd Basic user manipulation for
+.Xr suika 1
+.Sh SYNOPSIS
+.Nm
+.Op create-user
+.Op change-password
+.Sh AUTHORS
+.An Simon Ser Aq Mt contact@emersion.fr
+.An The soju Contributors
+.Sh MAINTAINERS
+.An Izuru Yakumo Aq Mt yakumo.izuru@chaotic.ninja
--- /dev/null
+package suika
+
+import (
+ "bytes"
+ "context"
+ "crypto/tls"
+ "encoding/base64"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/emersion/go-sasl"
+ "golang.org/x/crypto/bcrypt"
+ "gopkg.in/irc.v3"
+)
+
+type ircError struct {
+ Message *irc.Message
+}
+
+func (err ircError) Error() string {
+ return err.Message.String()
+}
+
+func newUnknownCommandError(cmd string) ircError {
+ return ircError{&irc.Message{
+ Command: irc.ERR_UNKNOWNCOMMAND,
+ Params: []string{
+ "*",
+ cmd,
+ "Unknown command",
+ },
+ }}
+}
+
+func newNeedMoreParamsError(cmd string) ircError {
+ return ircError{&irc.Message{
+ Command: irc.ERR_NEEDMOREPARAMS,
+ Params: []string{
+ "*",
+ cmd,
+ "Not enough parameters",
+ },
+ }}
+}
+
+func newChatHistoryError(subcommand string, target string) ircError {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"CHATHISTORY", "MESSAGE_ERROR", subcommand, target, "Messages could not be retrieved"},
+ }}
+}
+
+// authError is an authentication error.
+type authError struct {
+ // Internal error cause. This will not be revealed to the user.
+ err error
+ // Error cause which can safely be sent to the user without compromising
+ // security.
+ reason string
+}
+
+func (err *authError) Error() string {
+ return err.err.Error()
+}
+
+func (err *authError) Unwrap() error {
+ return err.err
+}
+
+// authErrorReason returns the user-friendly reason of an authentication
+// failure.
+func authErrorReason(err error) string {
+ if authErr, ok := err.(*authError); ok {
+ return authErr.reason
+ } else {
+ return "Authentication failed"
+ }
+}
+
+func newInvalidUsernameOrPasswordError(err error) error {
+ return &authError{
+ err: err,
+ reason: "Invalid username or password",
+ }
+}
+
+func parseBouncerNetID(subcommand, s string) (int64, error) {
+ id, err := strconv.ParseInt(s, 10, 64)
+ if err != nil {
+ return 0, ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"BOUNCER", "INVALID_NETID", subcommand, s, "Invalid network ID"},
+ }}
+ }
+ return id, nil
+}
+
+func fillNetworkAddrAttrs(attrs irc.Tags, network *Network) {
+ u, err := network.URL()
+ if err != nil {
+ return
+ }
+
+ hasHostPort := true
+ switch u.Scheme {
+ case "ircs":
+ attrs["tls"] = irc.TagValue("1")
+ case "irc":
+ attrs["tls"] = irc.TagValue("0")
+ default: // e.g. unix://
+ hasHostPort = false
+ }
+ if host, port, err := net.SplitHostPort(u.Host); err == nil && hasHostPort {
+ attrs["host"] = irc.TagValue(host)
+ attrs["port"] = irc.TagValue(port)
+ } else if hasHostPort {
+ attrs["host"] = irc.TagValue(u.Host)
+ }
+}
+
+func getNetworkAttrs(network *network) irc.Tags {
+ state := "disconnected"
+ if uc := network.conn; uc != nil {
+ state = "connected"
+ }
+
+ attrs := irc.Tags{
+ "name": irc.TagValue(network.GetName()),
+ "state": irc.TagValue(state),
+ "nickname": irc.TagValue(GetNick(&network.user.User, &network.Network)),
+ }
+
+ if network.Username != "" {
+ attrs["username"] = irc.TagValue(network.Username)
+ }
+ if realname := GetRealname(&network.user.User, &network.Network); realname != "" {
+ attrs["realname"] = irc.TagValue(realname)
+ }
+
+ fillNetworkAddrAttrs(attrs, &network.Network)
+
+ return attrs
+}
+
+func networkAddrFromAttrs(attrs irc.Tags) string {
+ host, ok := attrs.GetTag("host")
+ if !ok {
+ return ""
+ }
+
+ addr := host
+ if port, ok := attrs.GetTag("port"); ok {
+ addr += ":" + port
+ }
+
+ if tlsStr, ok := attrs.GetTag("tls"); ok && tlsStr == "0" {
+ addr = "irc://" + tlsStr
+ }
+
+ return addr
+}
+
+func updateNetworkAttrs(record *Network, attrs irc.Tags, subcommand string) error {
+ addrAttrs := irc.Tags{}
+ fillNetworkAddrAttrs(addrAttrs, record)
+
+ updateAddr := false
+ for k, v := range attrs {
+ s := string(v)
+ switch k {
+ case "host", "port", "tls":
+ updateAddr = true
+ addrAttrs[k] = v
+ case "name":
+ record.Name = s
+ case "nickname":
+ record.Nick = s
+ case "username":
+ record.Username = s
+ case "realname":
+ record.Realname = s
+ case "pass":
+ record.Pass = s
+ default:
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"BOUNCER", "UNKNOWN_ATTRIBUTE", subcommand, k, "Unknown attribute"},
+ }}
+ }
+ }
+
+ if updateAddr {
+ record.Addr = networkAddrFromAttrs(addrAttrs)
+ if record.Addr == "" {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"BOUNCER", "NEED_ATTRIBUTE", subcommand, "host", "Missing required host attribute"},
+ }}
+ }
+ }
+
+ return nil
+}
+
+// illegalNickChars is the list of characters forbidden in a nickname.
+//
+// ' ' and ':' break the IRC message wire format
+// '@' and '!' break prefixes
+// '*' breaks masks and is the reserved nickname for registration
+// '?' breaks masks
+// '$' breaks server masks in PRIVMSG/NOTICE
+// ',' breaks lists
+// '.' is reserved for server names
+const illegalNickChars = " :@!*?$,."
+
+// permanentDownstreamCaps is the list of always-supported downstream
+// capabilities.
+var permanentDownstreamCaps = map[string]string{
+ "batch": "",
+ "cap-notify": "",
+ "echo-message": "",
+ "invite-notify": "",
+ "message-tags": "",
+ "server-time": "",
+ "setname": "",
+
+ "soju.im/bouncer-networks": "",
+ "soju.im/bouncer-networks-notify": "",
+ "soju.im/read": "",
+}
+
+// needAllDownstreamCaps is the list of downstream capabilities that
+// require support from all upstreams to be enabled
+var needAllDownstreamCaps = map[string]string{
+ "account-notify": "",
+ "account-tag": "",
+ "away-notify": "",
+ "extended-join": "",
+ "multi-prefix": "",
+
+ "draft/extended-monitor": "",
+}
+
+// passthroughIsupport is the set of ISUPPORT tokens that are directly passed
+// through from the upstream server to downstream clients.
+//
+// This is only effective in single-upstream mode.
+var passthroughIsupport = map[string]bool{
+ "AWAYLEN": true,
+ "BOT": true,
+ "CHANLIMIT": true,
+ "CHANMODES": true,
+ "CHANNELLEN": true,
+ "CHANTYPES": true,
+ "CLIENTTAGDENY": true,
+ "ELIST": true,
+ "EXCEPTS": true,
+ "EXTBAN": true,
+ "HOSTLEN": true,
+ "INVEX": true,
+ "KICKLEN": true,
+ "MAXLIST": true,
+ "MAXTARGETS": true,
+ "MODES": true,
+ "MONITOR": true,
+ "NAMELEN": true,
+ "NETWORK": true,
+ "NICKLEN": true,
+ "PREFIX": true,
+ "SAFELIST": true,
+ "TARGMAX": true,
+ "TOPICLEN": true,
+ "USERLEN": true,
+ "UTF8ONLY": true,
+ "WHOX": true,
+}
+
+type downstreamSASL struct {
+ server sasl.Server
+ plainUsername, plainPassword string
+ pendingResp bytes.Buffer
+}
+
+type downstreamConn struct {
+ conn
+
+ id uint64
+
+ registered bool
+ user *user
+ nick string
+ nickCM string
+ rawUsername string
+ networkName string
+ clientName string
+ realname string
+ hostname string
+ account string // RPL_LOGGEDIN/OUT state
+ password string // empty after authentication
+ network *network // can be nil
+ isMultiUpstream bool
+
+ negotiatingCaps bool
+ capVersion int
+ supportedCaps map[string]string
+ caps map[string]bool
+ sasl *downstreamSASL
+
+ lastBatchRef uint64
+
+ monitored casemapMap
+}
+
+func newDownstreamConn(srv *Server, ic ircConn, id uint64) *downstreamConn {
+ remoteAddr := ic.RemoteAddr().String()
+ logger := &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", remoteAddr)}
+ options := connOptions{Logger: logger}
+ dc := &downstreamConn{
+ conn: *newConn(srv, ic, &options),
+ id: id,
+ nick: "*",
+ nickCM: "*",
+ supportedCaps: make(map[string]string),
+ caps: make(map[string]bool),
+ monitored: newCasemapMap(0),
+ }
+ dc.hostname = remoteAddr
+ if host, _, err := net.SplitHostPort(dc.hostname); err == nil {
+ dc.hostname = host
+ }
+ for k, v := range permanentDownstreamCaps {
+ dc.supportedCaps[k] = v
+ }
+ dc.supportedCaps["sasl"] = "PLAIN"
+ // TODO: this is racy, we should only enable chathistory after
+ // authentication and then check that user.msgStore implements
+ // chatHistoryMessageStore
+ if srv.Config().LogPath != "" {
+ dc.supportedCaps["draft/chathistory"] = ""
+ }
+ return dc
+}
+
+func (dc *downstreamConn) prefix() *irc.Prefix {
+ return &irc.Prefix{
+ Name: dc.nick,
+ User: dc.user.Username,
+ Host: dc.hostname,
+ }
+}
+
+func (dc *downstreamConn) forEachNetwork(f func(*network)) {
+ if dc.network != nil {
+ f(dc.network)
+ } else if dc.isMultiUpstream {
+ for _, network := range dc.user.networks {
+ f(network)
+ }
+ }
+}
+
+func (dc *downstreamConn) forEachUpstream(f func(*upstreamConn)) {
+ if dc.network == nil && !dc.isMultiUpstream {
+ return
+ }
+ dc.user.forEachUpstream(func(uc *upstreamConn) {
+ if dc.network != nil && uc.network != dc.network {
+ return
+ }
+ f(uc)
+ })
+}
+
+// upstream returns the upstream connection, if any. If there are zero or if
+// there are multiple upstream connections, it returns nil.
+func (dc *downstreamConn) upstream() *upstreamConn {
+ if dc.network == nil {
+ return nil
+ }
+ return dc.network.conn
+}
+
+func isOurNick(net *network, nick string) bool {
+ // TODO: this doesn't account for nick changes
+ if net.conn != nil {
+ return net.casemap(nick) == net.conn.nickCM
+ }
+ // We're not currently connected to the upstream connection, so we don't
+ // know whether this name is our nickname. Best-effort: use the network's
+ // configured nickname and hope it was the one being used when we were
+ // connected.
+ return net.casemap(nick) == net.casemap(GetNick(&net.user.User, &net.Network))
+}
+
+// marshalEntity converts an upstream entity name (ie. channel or nick) into a
+// downstream entity name.
+//
+// This involves adding a "/<network>" suffix if the entity isn't the current
+// user.
+func (dc *downstreamConn) marshalEntity(net *network, name string) string {
+ if isOurNick(net, name) {
+ return dc.nick
+ }
+ name = partialCasemap(net.casemap, name)
+ if dc.network != nil {
+ if dc.network != net {
+ panic("suika: tried to marshal an entity for another network")
+ }
+ return name
+ }
+ return name + "/" + net.GetName()
+}
+
+func (dc *downstreamConn) marshalUserPrefix(net *network, prefix *irc.Prefix) *irc.Prefix {
+ if isOurNick(net, prefix.Name) {
+ return dc.prefix()
+ }
+ prefix.Name = partialCasemap(net.casemap, prefix.Name)
+ if dc.network != nil {
+ if dc.network != net {
+ panic("suika: tried to marshal a user prefix for another network")
+ }
+ return prefix
+ }
+ return &irc.Prefix{
+ Name: prefix.Name + "/" + net.GetName(),
+ User: prefix.User,
+ Host: prefix.Host,
+ }
+}
+
+// unmarshalEntityNetwork converts a downstream entity name (ie. channel or
+// nick) into an upstream entity name.
+//
+// This involves removing the "/<network>" suffix.
+func (dc *downstreamConn) unmarshalEntityNetwork(name string) (*network, string, error) {
+ if dc.network != nil {
+ return dc.network, name, nil
+ }
+ if !dc.isMultiUpstream {
+ return nil, "", ircError{&irc.Message{
+ Command: irc.ERR_NOSUCHCHANNEL,
+ Params: []string{dc.nick, name, "Cannot interact with channels and users on the bouncer connection. Did you mean to use a specific network?"},
+ }}
+ }
+
+ var net *network
+ if i := strings.LastIndexByte(name, '/'); i >= 0 {
+ network := name[i+1:]
+ name = name[:i]
+
+ for _, n := range dc.user.networks {
+ if network == n.GetName() {
+ net = n
+ break
+ }
+ }
+ }
+
+ if net == nil {
+ return nil, "", ircError{&irc.Message{
+ Command: irc.ERR_NOSUCHCHANNEL,
+ Params: []string{dc.nick, name, "Missing network suffix in name"},
+ }}
+ }
+
+ return net, name, nil
+}
+
+// unmarshalEntity is the same as unmarshalEntityNetwork, but returns the
+// upstream connection and fails if the upstream is disconnected.
+func (dc *downstreamConn) unmarshalEntity(name string) (*upstreamConn, string, error) {
+ net, name, err := dc.unmarshalEntityNetwork(name)
+ if err != nil {
+ return nil, "", err
+ }
+
+ if net.conn == nil {
+ return nil, "", ircError{&irc.Message{
+ Command: irc.ERR_NOSUCHCHANNEL,
+ Params: []string{dc.nick, name, "Disconnected from upstream network"},
+ }}
+ }
+
+ return net.conn, name, nil
+}
+
+func (dc *downstreamConn) unmarshalText(uc *upstreamConn, text string) string {
+ if dc.upstream() != nil {
+ return text
+ }
+ // TODO: smarter parsing that ignores URLs
+ return strings.ReplaceAll(text, "/"+uc.network.GetName(), "")
+}
+
+func (dc *downstreamConn) ReadMessage() (*irc.Message, error) {
+ msg, err := dc.conn.ReadMessage()
+ if err != nil {
+ return nil, err
+ }
+ return msg, nil
+}
+
+func (dc *downstreamConn) readMessages(ch chan<- event) error {
+ for {
+ msg, err := dc.ReadMessage()
+ if errors.Is(err, io.EOF) {
+ break
+ } else if err != nil {
+ return fmt.Errorf("failed to read IRC command: %v", err)
+ }
+
+ ch <- eventDownstreamMessage{msg, dc}
+ }
+
+ return nil
+}
+
+// SendMessage sends an outgoing message.
+//
+// This can only called from the user goroutine.
+func (dc *downstreamConn) SendMessage(msg *irc.Message) {
+ if !dc.caps["message-tags"] {
+ if msg.Command == "TAGMSG" {
+ return
+ }
+ msg = msg.Copy()
+ for name := range msg.Tags {
+ supported := false
+ switch name {
+ case "time":
+ supported = dc.caps["server-time"]
+ case "account":
+ supported = dc.caps["account"]
+ }
+ if !supported {
+ delete(msg.Tags, name)
+ }
+ }
+ }
+ if !dc.caps["batch"] && msg.Tags["batch"] != "" {
+ msg = msg.Copy()
+ delete(msg.Tags, "batch")
+ }
+ if msg.Command == "JOIN" && !dc.caps["extended-join"] {
+ msg.Params = msg.Params[:1]
+ }
+ if msg.Command == "SETNAME" && !dc.caps["setname"] {
+ return
+ }
+ if msg.Command == "AWAY" && !dc.caps["away-notify"] {
+ return
+ }
+ if msg.Command == "ACCOUNT" && !dc.caps["account-notify"] {
+ return
+ }
+ if msg.Command == "READ" && !dc.caps["soju.im/read"] {
+ return
+ }
+
+ dc.conn.SendMessage(context.TODO(), msg)
+}
+
+func (dc *downstreamConn) SendBatch(typ string, params []string, tags irc.Tags, f func(batchRef irc.TagValue)) {
+ dc.lastBatchRef++
+ ref := fmt.Sprintf("%v", dc.lastBatchRef)
+
+ if dc.caps["batch"] {
+ dc.SendMessage(&irc.Message{
+ Tags: tags,
+ Prefix: dc.srv.prefix(),
+ Command: "BATCH",
+ Params: append([]string{"+" + ref, typ}, params...),
+ })
+ }
+
+ f(irc.TagValue(ref))
+
+ if dc.caps["batch"] {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "BATCH",
+ Params: []string{"-" + ref},
+ })
+ }
+}
+
+// sendMessageWithID sends an outgoing message with the specified internal ID.
+func (dc *downstreamConn) sendMessageWithID(msg *irc.Message, id string) {
+ dc.SendMessage(msg)
+
+ if id == "" || !dc.messageSupportsBacklog(msg) {
+ return
+ }
+
+ dc.sendPing(id)
+}
+
+// advanceMessageWithID advances history to the specified message ID without
+// sending a message. This is useful e.g. for self-messages when echo-message
+// isn't enabled.
+func (dc *downstreamConn) advanceMessageWithID(msg *irc.Message, id string) {
+ if id == "" || !dc.messageSupportsBacklog(msg) {
+ return
+ }
+
+ dc.sendPing(id)
+}
+
+// ackMsgID acknowledges that a message has been received.
+func (dc *downstreamConn) ackMsgID(id string) {
+ netID, entity, err := parseMsgID(id, nil)
+ if err != nil {
+ dc.logger.Printf("failed to ACK message ID %q: %v", id, err)
+ return
+ }
+
+ network := dc.user.getNetworkByID(netID)
+ if network == nil {
+ return
+ }
+
+ network.delivered.StoreID(entity, dc.clientName, id)
+}
+
+func (dc *downstreamConn) sendPing(msgID string) {
+ token := "suika-msgid-" + msgID
+ dc.SendMessage(&irc.Message{
+ Command: "PING",
+ Params: []string{token},
+ })
+}
+
+func (dc *downstreamConn) handlePong(token string) {
+ if !strings.HasPrefix(token, "suika-msgid-") {
+ dc.logger.Printf("received unrecognized PONG token %q", token)
+ return
+ }
+ msgID := strings.TrimPrefix(token, "suika-msgid-")
+ dc.ackMsgID(msgID)
+}
+
+// marshalMessage re-formats a message coming from an upstream connection so
+// that it's suitable for being sent on this downstream connection. Only
+// messages that may appear in logs are supported, except MODE messages which
+// may only appear in single-upstream mode.
+func (dc *downstreamConn) marshalMessage(msg *irc.Message, net *network) *irc.Message {
+ msg = msg.Copy()
+ msg.Prefix = dc.marshalUserPrefix(net, msg.Prefix)
+
+ if dc.network != nil {
+ return msg
+ }
+
+ switch msg.Command {
+ case "PRIVMSG", "NOTICE", "TAGMSG":
+ msg.Params[0] = dc.marshalEntity(net, msg.Params[0])
+ case "NICK":
+ // Nick change for another user
+ msg.Params[0] = dc.marshalEntity(net, msg.Params[0])
+ case "JOIN", "PART":
+ msg.Params[0] = dc.marshalEntity(net, msg.Params[0])
+ case "KICK":
+ msg.Params[0] = dc.marshalEntity(net, msg.Params[0])
+ msg.Params[1] = dc.marshalEntity(net, msg.Params[1])
+ case "TOPIC":
+ msg.Params[0] = dc.marshalEntity(net, msg.Params[0])
+ case "QUIT", "SETNAME":
+ // This space is intentionally left blank
+ default:
+ panic(fmt.Sprintf("unexpected %q message", msg.Command))
+ }
+
+ return msg
+}
+
+func (dc *downstreamConn) handleMessage(ctx context.Context, msg *irc.Message) error {
+ ctx, cancel := dc.conn.NewContext(ctx)
+ defer cancel()
+
+ ctx, cancel = context.WithTimeout(ctx, handleDownstreamMessageTimeout)
+ defer cancel()
+
+ switch msg.Command {
+ case "QUIT":
+ return dc.Close()
+ default:
+ if dc.registered {
+ return dc.handleMessageRegistered(ctx, msg)
+ } else {
+ return dc.handleMessageUnregistered(ctx, msg)
+ }
+ }
+}
+
+func (dc *downstreamConn) handleMessageUnregistered(ctx context.Context, msg *irc.Message) error {
+ switch msg.Command {
+ case "NICK":
+ var nick string
+ if err := parseMessageParams(msg, &nick); err != nil {
+ return err
+ }
+ if nick == "" || strings.ContainsAny(nick, illegalNickChars) {
+ return ircError{&irc.Message{
+ Command: irc.ERR_ERRONEUSNICKNAME,
+ Params: []string{dc.nick, nick, "contains illegal characters"},
+ }}
+ }
+ nickCM := casemapASCII(nick)
+ if nickCM == serviceNickCM {
+ return ircError{&irc.Message{
+ Command: irc.ERR_NICKNAMEINUSE,
+ Params: []string{dc.nick, nick, "Nickname reserved for bouncer service"},
+ }}
+ }
+ dc.nick = nick
+ dc.nickCM = nickCM
+ case "USER":
+ if err := parseMessageParams(msg, &dc.rawUsername, nil, nil, &dc.realname); err != nil {
+ return err
+ }
+ case "PASS":
+ if err := parseMessageParams(msg, &dc.password); err != nil {
+ return err
+ }
+ case "CAP":
+ var subCmd string
+ if err := parseMessageParams(msg, &subCmd); err != nil {
+ return err
+ }
+ if err := dc.handleCapCommand(subCmd, msg.Params[1:]); err != nil {
+ return err
+ }
+ case "AUTHENTICATE":
+ credentials, err := dc.handleAuthenticateCommand(msg)
+ if err != nil {
+ return err
+ } else if credentials == nil {
+ break
+ }
+
+ if err := dc.authenticate(ctx, credentials.plainUsername, credentials.plainPassword); err != nil {
+ dc.logger.Printf("SASL authentication error for user %q: %v", credentials.plainUsername, err)
+ dc.endSASL(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_SASLFAIL,
+ Params: []string{dc.nick, authErrorReason(err)},
+ })
+ break
+ }
+
+ // Technically we should send RPL_LOGGEDIN here. However we use
+ // RPL_LOGGEDIN to mirror the upstream connection status. Let's
+ // see how many clients that breaks. See:
+ // https://github.com/ircv3/ircv3-specifications/pull/476
+ dc.endSASL(nil)
+ case "BOUNCER":
+ var subcommand string
+ if err := parseMessageParams(msg, &subcommand); err != nil {
+ return err
+ }
+
+ switch strings.ToUpper(subcommand) {
+ case "BIND":
+ var idStr string
+ if err := parseMessageParams(msg, nil, &idStr); err != nil {
+ return err
+ }
+
+ if dc.user == nil {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"BOUNCER", "ACCOUNT_REQUIRED", "BIND", "Authentication needed to bind to bouncer network"},
+ }}
+ }
+
+ id, err := parseBouncerNetID(subcommand, idStr)
+ if err != nil {
+ return err
+ }
+
+ var match *network
+ for _, net := range dc.user.networks {
+ if net.ID == id {
+ match = net
+ break
+ }
+ }
+ if match == nil {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"BOUNCER", "INVALID_NETID", idStr, "Unknown network ID"},
+ }}
+ }
+
+ dc.networkName = match.GetName()
+ }
+ default:
+ dc.logger.Printf("unhandled message: %v", msg)
+ return newUnknownCommandError(msg.Command)
+ }
+ if dc.rawUsername != "" && dc.nick != "*" && !dc.negotiatingCaps {
+ return dc.register(ctx)
+ }
+ return nil
+}
+
+func (dc *downstreamConn) handleCapCommand(cmd string, args []string) error {
+ cmd = strings.ToUpper(cmd)
+
+ switch cmd {
+ case "LS":
+ if len(args) > 0 {
+ var err error
+ if dc.capVersion, err = strconv.Atoi(args[0]); err != nil {
+ return err
+ }
+ }
+ if !dc.registered && dc.capVersion >= 302 {
+ // Let downstream show everything it supports, and trim
+ // down the available capabilities when upstreams are
+ // known.
+ for k, v := range needAllDownstreamCaps {
+ dc.supportedCaps[k] = v
+ }
+ }
+
+ caps := make([]string, 0, len(dc.supportedCaps))
+ for k, v := range dc.supportedCaps {
+ if dc.capVersion >= 302 && v != "" {
+ caps = append(caps, k+"="+v)
+ } else {
+ caps = append(caps, k)
+ }
+ }
+
+ // TODO: multi-line replies
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "CAP",
+ Params: []string{dc.nick, "LS", strings.Join(caps, " ")},
+ })
+
+ if dc.capVersion >= 302 {
+ // CAP version 302 implicitly enables cap-notify
+ dc.caps["cap-notify"] = true
+ }
+
+ if !dc.registered {
+ dc.negotiatingCaps = true
+ }
+ case "LIST":
+ var caps []string
+ for name, enabled := range dc.caps {
+ if enabled {
+ caps = append(caps, name)
+ }
+ }
+
+ // TODO: multi-line replies
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "CAP",
+ Params: []string{dc.nick, "LIST", strings.Join(caps, " ")},
+ })
+ case "REQ":
+ if len(args) == 0 {
+ return ircError{&irc.Message{
+ Command: err_invalidcapcmd,
+ Params: []string{dc.nick, cmd, "Missing argument in CAP REQ command"},
+ }}
+ }
+
+ // TODO: atomically ack/nak the whole capability set
+ caps := strings.Fields(args[0])
+ ack := true
+ for _, name := range caps {
+ name = strings.ToLower(name)
+ enable := !strings.HasPrefix(name, "-")
+ if !enable {
+ name = strings.TrimPrefix(name, "-")
+ }
+
+ if enable == dc.caps[name] {
+ continue
+ }
+
+ _, ok := dc.supportedCaps[name]
+ if !ok {
+ ack = false
+ break
+ }
+
+ if name == "cap-notify" && dc.capVersion >= 302 && !enable {
+ // cap-notify cannot be disabled with CAP version 302
+ ack = false
+ break
+ }
+
+ dc.caps[name] = enable
+ }
+
+ reply := "NAK"
+ if ack {
+ reply = "ACK"
+ }
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "CAP",
+ Params: []string{dc.nick, reply, args[0]},
+ })
+
+ if !dc.registered {
+ dc.negotiatingCaps = true
+ }
+ case "END":
+ dc.negotiatingCaps = false
+ default:
+ return ircError{&irc.Message{
+ Command: err_invalidcapcmd,
+ Params: []string{dc.nick, cmd, "Unknown CAP command"},
+ }}
+ }
+ return nil
+}
+
+func (dc *downstreamConn) handleAuthenticateCommand(msg *irc.Message) (result *downstreamSASL, err error) {
+ defer func() {
+ if err != nil {
+ dc.sasl = nil
+ }
+ }()
+
+ if !dc.caps["sasl"] {
+ return nil, ircError{&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_SASLFAIL,
+ Params: []string{dc.nick, "AUTHENTICATE requires the \"sasl\" capability to be enabled"},
+ }}
+ }
+ if len(msg.Params) == 0 {
+ return nil, ircError{&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_SASLFAIL,
+ Params: []string{dc.nick, "Missing AUTHENTICATE argument"},
+ }}
+ }
+ if msg.Params[0] == "*" {
+ return nil, ircError{&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_SASLABORTED,
+ Params: []string{dc.nick, "SASL authentication aborted"},
+ }}
+ }
+
+ var resp []byte
+ if dc.sasl == nil {
+ mech := strings.ToUpper(msg.Params[0])
+ var server sasl.Server
+ switch mech {
+ case "PLAIN":
+ server = sasl.NewPlainServer(sasl.PlainAuthenticator(func(identity, username, password string) error {
+ dc.sasl.plainUsername = username
+ dc.sasl.plainPassword = password
+ return nil
+ }))
+ default:
+ return nil, ircError{&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_SASLFAIL,
+ Params: []string{dc.nick, fmt.Sprintf("Unsupported SASL mechanism %q", mech)},
+ }}
+ }
+
+ dc.sasl = &downstreamSASL{server: server}
+ } else {
+ chunk := msg.Params[0]
+ if chunk == "+" {
+ chunk = ""
+ }
+
+ if dc.sasl.pendingResp.Len()+len(chunk) > 10*1024 {
+ return nil, ircError{&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_SASLFAIL,
+ Params: []string{dc.nick, "Response too long"},
+ }}
+ }
+
+ dc.sasl.pendingResp.WriteString(chunk)
+
+ if len(chunk) == maxSASLLength {
+ return nil, nil // Multi-line response, wait for the next command
+ }
+
+ resp, err = base64.StdEncoding.DecodeString(dc.sasl.pendingResp.String())
+ if err != nil {
+ return nil, ircError{&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_SASLFAIL,
+ Params: []string{dc.nick, "Invalid base64-encoded response"},
+ }}
+ }
+
+ dc.sasl.pendingResp.Reset()
+ }
+
+ challenge, done, err := dc.sasl.server.Next(resp)
+ if err != nil {
+ return nil, err
+ } else if done {
+ return dc.sasl, nil
+ } else {
+ challengeStr := "+"
+ if len(challenge) > 0 {
+ challengeStr = base64.StdEncoding.EncodeToString(challenge)
+ }
+
+ // TODO: multi-line messages
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "AUTHENTICATE",
+ Params: []string{challengeStr},
+ })
+ return nil, nil
+ }
+}
+
+func (dc *downstreamConn) endSASL(msg *irc.Message) {
+ if dc.sasl == nil {
+ return
+ }
+
+ dc.sasl = nil
+
+ if msg != nil {
+ dc.SendMessage(msg)
+ } else {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_SASLSUCCESS,
+ Params: []string{dc.nick, "SASL authentication successful"},
+ })
+ }
+}
+
+func (dc *downstreamConn) setSupportedCap(name, value string) {
+ prevValue, hasPrev := dc.supportedCaps[name]
+ changed := !hasPrev || prevValue != value
+ dc.supportedCaps[name] = value
+
+ if !dc.caps["cap-notify"] || !changed {
+ return
+ }
+
+ cap := name
+ if value != "" && dc.capVersion >= 302 {
+ cap = name + "=" + value
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "CAP",
+ Params: []string{dc.nick, "NEW", cap},
+ })
+}
+
+func (dc *downstreamConn) unsetSupportedCap(name string) {
+ _, hasPrev := dc.supportedCaps[name]
+ delete(dc.supportedCaps, name)
+ delete(dc.caps, name)
+
+ if !dc.caps["cap-notify"] || !hasPrev {
+ return
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "CAP",
+ Params: []string{dc.nick, "DEL", name},
+ })
+}
+
+func (dc *downstreamConn) updateSupportedCaps() {
+ supportedCaps := make(map[string]bool)
+ for cap := range needAllDownstreamCaps {
+ supportedCaps[cap] = true
+ }
+ dc.forEachUpstream(func(uc *upstreamConn) {
+ for cap, supported := range supportedCaps {
+ supportedCaps[cap] = supported && uc.caps[cap]
+ }
+ })
+
+ for cap, supported := range supportedCaps {
+ if supported {
+ dc.setSupportedCap(cap, needAllDownstreamCaps[cap])
+ } else {
+ dc.unsetSupportedCap(cap)
+ }
+ }
+
+ if uc := dc.upstream(); uc != nil && uc.supportsSASL("PLAIN") {
+ dc.setSupportedCap("sasl", "PLAIN")
+ } else if dc.network != nil {
+ dc.unsetSupportedCap("sasl")
+ }
+
+ if uc := dc.upstream(); uc != nil && uc.caps["draft/account-registration"] {
+ // Strip "before-connect", because we require downstreams to be fully
+ // connected before attempting account registration.
+ values := strings.Split(uc.supportedCaps["draft/account-registration"], ",")
+ for i, v := range values {
+ if v == "before-connect" {
+ values = append(values[:i], values[i+1:]...)
+ break
+ }
+ }
+ dc.setSupportedCap("draft/account-registration", strings.Join(values, ","))
+ } else {
+ dc.unsetSupportedCap("draft/account-registration")
+ }
+
+ if _, ok := dc.user.msgStore.(chatHistoryMessageStore); ok && dc.network != nil {
+ dc.setSupportedCap("draft/event-playback", "")
+ } else {
+ dc.unsetSupportedCap("draft/event-playback")
+ }
+}
+
+func (dc *downstreamConn) updateNick() {
+ if uc := dc.upstream(); uc != nil && uc.nick != dc.nick {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.prefix(),
+ Command: "NICK",
+ Params: []string{uc.nick},
+ })
+ dc.nick = uc.nick
+ dc.nickCM = casemapASCII(dc.nick)
+ }
+}
+
+func (dc *downstreamConn) updateRealname() {
+ if uc := dc.upstream(); uc != nil && uc.realname != dc.realname && dc.caps["setname"] {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.prefix(),
+ Command: "SETNAME",
+ Params: []string{uc.realname},
+ })
+ dc.realname = uc.realname
+ }
+}
+
+func (dc *downstreamConn) updateAccount() {
+ var account string
+ if dc.network == nil {
+ account = dc.user.Username
+ } else if uc := dc.upstream(); uc != nil {
+ account = uc.account
+ } else {
+ return
+ }
+
+ if dc.account == account || !dc.caps["sasl"] {
+ return
+ }
+
+ if account != "" {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_LOGGEDIN,
+ Params: []string{dc.nick, dc.prefix().String(), account, "You are logged in as " + account},
+ })
+ } else {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_LOGGEDOUT,
+ Params: []string{dc.nick, dc.prefix().String(), "You are logged out"},
+ })
+ }
+
+ dc.account = account
+}
+
+func sanityCheckServer(ctx context.Context, addr string) error {
+ ctx, cancel := context.WithTimeout(ctx, 15*time.Second)
+ defer cancel()
+
+ conn, err := new(tls.Dialer).DialContext(ctx, "tcp", addr)
+ if err != nil {
+ return err
+ }
+
+ return conn.Close()
+}
+
+func unmarshalUsername(rawUsername string) (username, client, network string) {
+ username = rawUsername
+
+ i := strings.IndexAny(username, "/@")
+ j := strings.LastIndexAny(username, "/@")
+ if i >= 0 {
+ username = rawUsername[:i]
+ }
+ if j >= 0 {
+ if rawUsername[j] == '@' {
+ client = rawUsername[j+1:]
+ } else {
+ network = rawUsername[j+1:]
+ }
+ }
+ if i >= 0 && j >= 0 && i < j {
+ if rawUsername[i] == '@' {
+ client = rawUsername[i+1 : j]
+ } else {
+ network = rawUsername[i+1 : j]
+ }
+ }
+
+ return username, client, network
+}
+
+func (dc *downstreamConn) authenticate(ctx context.Context, username, password string) error {
+ username, clientName, networkName := unmarshalUsername(username)
+
+ u, err := dc.srv.db.GetUser(ctx, username)
+ if err != nil {
+ return newInvalidUsernameOrPasswordError(fmt.Errorf("user not found: %w", err))
+ }
+
+ // Password auth disabled
+ if u.Password == "" {
+ return newInvalidUsernameOrPasswordError(fmt.Errorf("password auth disabled"))
+ }
+
+ err = bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(password))
+ if err != nil {
+ return newInvalidUsernameOrPasswordError(fmt.Errorf("wrong password"))
+ }
+
+ dc.user = dc.srv.getUser(username)
+ if dc.user == nil {
+ return fmt.Errorf("user not active")
+ }
+ dc.clientName = clientName
+ dc.networkName = networkName
+ return nil
+}
+
+func (dc *downstreamConn) register(ctx context.Context) error {
+ if dc.registered {
+ panic("tried to register twice")
+ }
+
+ if dc.sasl != nil {
+ dc.endSASL(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_SASLABORTED,
+ Params: []string{dc.nick, "SASL authentication aborted"},
+ })
+ }
+
+ password := dc.password
+ dc.password = ""
+ if dc.user == nil {
+ if password == "" {
+ if dc.caps["sasl"] {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"*", "ACCOUNT_REQUIRED", "Authentication required"},
+ }}
+ } else {
+ return ircError{&irc.Message{
+ Command: irc.ERR_PASSWDMISMATCH,
+ Params: []string{dc.nick, "Authentication required"},
+ }}
+ }
+ }
+
+ if err := dc.authenticate(ctx, dc.rawUsername, password); err != nil {
+ dc.logger.Printf("PASS authentication error for user %q: %v", dc.rawUsername, err)
+ return ircError{&irc.Message{
+ Command: irc.ERR_PASSWDMISMATCH,
+ Params: []string{dc.nick, authErrorReason(err)},
+ }}
+ }
+ }
+
+ _, fallbackClientName, fallbackNetworkName := unmarshalUsername(dc.rawUsername)
+ if dc.clientName == "" {
+ dc.clientName = fallbackClientName
+ } else if fallbackClientName != "" && dc.clientName != fallbackClientName {
+ return ircError{&irc.Message{
+ Command: irc.ERR_ERRONEUSNICKNAME,
+ Params: []string{dc.nick, "Client name mismatch in usernames"},
+ }}
+ }
+ if dc.networkName == "" {
+ dc.networkName = fallbackNetworkName
+ } else if fallbackNetworkName != "" && dc.networkName != fallbackNetworkName {
+ return ircError{&irc.Message{
+ Command: irc.ERR_ERRONEUSNICKNAME,
+ Params: []string{dc.nick, "Network name mismatch in usernames"},
+ }}
+ }
+
+ dc.registered = true
+ dc.logger.Printf("registration complete for user %q", dc.user.Username)
+ return nil
+}
+
+func (dc *downstreamConn) loadNetwork(ctx context.Context) error {
+ if dc.networkName == "" {
+ return nil
+ }
+
+ network := dc.user.getNetwork(dc.networkName)
+ if network == nil {
+ addr := dc.networkName
+ if !strings.ContainsRune(addr, ':') {
+ addr = addr + ":6697"
+ }
+
+ dc.logger.Printf("trying to connect to new network %q", addr)
+ if err := sanityCheckServer(ctx, addr); err != nil {
+ dc.logger.Printf("failed to connect to %q: %v", addr, err)
+ return ircError{&irc.Message{
+ Command: irc.ERR_PASSWDMISMATCH,
+ Params: []string{dc.nick, fmt.Sprintf("Failed to connect to %q", dc.networkName)},
+ }}
+ }
+
+ // Some clients only allow specifying the nickname (and use the
+ // nickname as a username too). Strip the network name from the
+ // nickname when auto-saving networks.
+ nick, _, _ := unmarshalUsername(dc.nick)
+
+ dc.logger.Printf("auto-saving network %q", dc.networkName)
+ var err error
+ network, err = dc.user.createNetwork(ctx, &Network{
+ Addr: dc.networkName,
+ Nick: nick,
+ Enabled: true,
+ })
+ if err != nil {
+ return err
+ }
+ }
+
+ dc.network = network
+ return nil
+}
+
+func (dc *downstreamConn) welcome(ctx context.Context) error {
+ if dc.user == nil || !dc.registered {
+ panic("tried to welcome an unregistered connection")
+ }
+
+ remoteAddr := dc.conn.RemoteAddr().String()
+ dc.logger = &prefixLogger{dc.srv.Logger, fmt.Sprintf("user %q: downstream %q: ", dc.user.Username, remoteAddr)}
+
+ // TODO: doing this might take some time. We should do it in dc.register
+ // instead, but we'll potentially be adding a new network and this must be
+ // done in the user goroutine.
+ if err := dc.loadNetwork(ctx); err != nil {
+ return err
+ }
+
+ if dc.network == nil && !dc.caps["soju.im/bouncer-networks"] && dc.srv.Config().MultiUpstream {
+ dc.isMultiUpstream = true
+ }
+
+ dc.updateSupportedCaps()
+
+ isupport := []string{
+ fmt.Sprintf("CHATHISTORY=%v", chatHistoryLimit),
+ "CASEMAPPING=ascii",
+ }
+
+ if dc.network != nil {
+ isupport = append(isupport, fmt.Sprintf("BOUNCER_NETID=%v", dc.network.ID))
+ }
+ if title := dc.srv.Config().Title; dc.network == nil && title != "" {
+ isupport = append(isupport, "NETWORK="+encodeISUPPORT(title))
+ }
+ if dc.network == nil && !dc.isMultiUpstream {
+ isupport = append(isupport, "WHOX")
+ }
+
+ if uc := dc.upstream(); uc != nil {
+ for k := range passthroughIsupport {
+ v, ok := uc.isupport[k]
+ if !ok {
+ continue
+ }
+ if v != nil {
+ isupport = append(isupport, fmt.Sprintf("%v=%v", k, *v))
+ } else {
+ isupport = append(isupport, k)
+ }
+ }
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_WELCOME,
+ Params: []string{dc.nick, "Welcome to suika, " + dc.nick},
+ })
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_YOURHOST,
+ Params: []string{dc.nick, "Your host is " + dc.srv.Config().Hostname},
+ })
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_MYINFO,
+ Params: []string{dc.nick, dc.srv.Config().Hostname, "suika", "aiwroO", "OovaimnqpsrtklbeI"},
+ })
+ for _, msg := range generateIsupport(dc.srv.prefix(), dc.nick, isupport) {
+ dc.SendMessage(msg)
+ }
+ if uc := dc.upstream(); uc != nil {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_UMODEIS,
+ Params: []string{dc.nick, "+" + string(uc.modes)},
+ })
+ }
+ if dc.network == nil && !dc.isMultiUpstream && dc.user.Admin {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_UMODEIS,
+ Params: []string{dc.nick, "+o"},
+ })
+ }
+
+ dc.updateNick()
+ dc.updateRealname()
+ dc.updateAccount()
+
+ if motd := dc.user.srv.Config().MOTD; motd != "" && dc.network == nil {
+ for _, msg := range generateMOTD(dc.srv.prefix(), dc.nick, motd) {
+ dc.SendMessage(msg)
+ }
+ } else {
+ motdHint := "No MOTD"
+ if dc.network != nil {
+ motdHint = "Use /motd to read the message of the day"
+ }
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_NOMOTD,
+ Params: []string{dc.nick, motdHint},
+ })
+ }
+
+ if dc.caps["soju.im/bouncer-networks-notify"] {
+ dc.SendBatch("soju.im/bouncer-networks", nil, nil, func(batchRef irc.TagValue) {
+ for _, network := range dc.user.networks {
+ idStr := fmt.Sprintf("%v", network.ID)
+ attrs := getNetworkAttrs(network)
+ dc.SendMessage(&irc.Message{
+ Tags: irc.Tags{"batch": batchRef},
+ Prefix: dc.srv.prefix(),
+ Command: "BOUNCER",
+ Params: []string{"NETWORK", idStr, attrs.String()},
+ })
+ }
+ })
+ }
+
+ dc.forEachUpstream(func(uc *upstreamConn) {
+ for _, entry := range uc.channels.innerMap {
+ ch := entry.value.(*upstreamChannel)
+ if !ch.complete {
+ continue
+ }
+ record := uc.network.channels.Value(ch.Name)
+ if record != nil && record.Detached {
+ continue
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.prefix(),
+ Command: "JOIN",
+ Params: []string{dc.marshalEntity(ch.conn.network, ch.Name)},
+ })
+
+ forwardChannel(ctx, dc, ch)
+ }
+ })
+
+ dc.forEachNetwork(func(net *network) {
+ if dc.caps["draft/chathistory"] || dc.user.msgStore == nil {
+ return
+ }
+
+ // Only send history if we're the first connected client with that name
+ // for the network
+ firstClient := true
+ dc.user.forEachDownstream(func(c *downstreamConn) {
+ if c != dc && c.clientName == dc.clientName && c.network == dc.network {
+ firstClient = false
+ }
+ })
+ if firstClient {
+ net.delivered.ForEachTarget(func(target string) {
+ lastDelivered := net.delivered.LoadID(target, dc.clientName)
+ if lastDelivered == "" {
+ return
+ }
+
+ dc.sendTargetBacklog(ctx, net, target, lastDelivered)
+
+ // Fast-forward history to last message
+ targetCM := net.casemap(target)
+ lastID, err := dc.user.msgStore.LastMsgID(&net.Network, targetCM, time.Now())
+ if err != nil {
+ dc.logger.Printf("failed to get last message ID: %v", err)
+ return
+ }
+ net.delivered.StoreID(target, dc.clientName, lastID)
+ })
+ }
+ })
+
+ return nil
+}
+
+// messageSupportsBacklog checks whether the provided message can be sent as
+// part of an history batch.
+func (dc *downstreamConn) messageSupportsBacklog(msg *irc.Message) bool {
+ // Don't replay all messages, because that would mess up client
+ // state. For instance we just sent the list of users, sending
+ // PART messages for one of these users would be incorrect.
+ switch msg.Command {
+ case "PRIVMSG", "NOTICE":
+ return true
+ }
+ return false
+}
+
+func (dc *downstreamConn) sendTargetBacklog(ctx context.Context, net *network, target, msgID string) {
+ if dc.caps["draft/chathistory"] || dc.user.msgStore == nil {
+ return
+ }
+
+ ch := net.channels.Value(target)
+
+ ctx, cancel := context.WithTimeout(ctx, backlogTimeout)
+ defer cancel()
+
+ targetCM := net.casemap(target)
+ history, err := dc.user.msgStore.LoadLatestID(ctx, &net.Network, targetCM, msgID, backlogLimit)
+ if err != nil {
+ dc.logger.Printf("failed to send backlog for %q: %v", target, err)
+ return
+ }
+
+ dc.SendBatch("chathistory", []string{dc.marshalEntity(net, target)}, nil, func(batchRef irc.TagValue) {
+ for _, msg := range history {
+ if ch != nil && ch.Detached {
+ if net.detachedMessageNeedsRelay(ch, msg) {
+ dc.relayDetachedMessage(net, msg)
+ }
+ } else {
+ msg.Tags["batch"] = batchRef
+ dc.SendMessage(dc.marshalMessage(msg, net))
+ }
+ }
+ })
+}
+
+func (dc *downstreamConn) relayDetachedMessage(net *network, msg *irc.Message) {
+ if msg.Command != "PRIVMSG" && msg.Command != "NOTICE" {
+ return
+ }
+
+ sender := msg.Prefix.Name
+ target, text := msg.Params[0], msg.Params[1]
+ if net.isHighlight(msg) {
+ sendServiceNOTICE(dc, fmt.Sprintf("highlight in %v: <%v> %v", dc.marshalEntity(net, target), sender, text))
+ } else {
+ sendServiceNOTICE(dc, fmt.Sprintf("message in %v: <%v> %v", dc.marshalEntity(net, target), sender, text))
+ }
+}
+
+func (dc *downstreamConn) runUntilRegistered() error {
+ ctx, cancel := context.WithTimeout(context.TODO(), downstreamRegisterTimeout)
+ defer cancel()
+
+ // Close the connection with an error if the deadline is exceeded
+ go func() {
+ <-ctx.Done()
+ if err := ctx.Err(); err == context.DeadlineExceeded {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "ERROR",
+ Params: []string{"Connection registration timed out"},
+ })
+ dc.Close()
+ }
+ }()
+
+ for !dc.registered {
+ msg, err := dc.ReadMessage()
+ if err != nil {
+ return fmt.Errorf("failed to read IRC command: %w", err)
+ }
+
+ err = dc.handleMessage(ctx, msg)
+ if ircErr, ok := err.(ircError); ok {
+ ircErr.Message.Prefix = dc.srv.prefix()
+ dc.SendMessage(ircErr.Message)
+ } else if err != nil {
+ return fmt.Errorf("failed to handle IRC command %q: %v", msg, err)
+ }
+ }
+
+ return nil
+}
+
+func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.Message) error {
+ switch msg.Command {
+ case "CAP":
+ var subCmd string
+ if err := parseMessageParams(msg, &subCmd); err != nil {
+ return err
+ }
+ if err := dc.handleCapCommand(subCmd, msg.Params[1:]); err != nil {
+ return err
+ }
+ case "PING":
+ var source, destination string
+ if err := parseMessageParams(msg, &source); err != nil {
+ return err
+ }
+ if len(msg.Params) > 1 {
+ destination = msg.Params[1]
+ }
+ hostname := dc.srv.Config().Hostname
+ if destination != "" && destination != hostname {
+ return ircError{&irc.Message{
+ Command: irc.ERR_NOSUCHSERVER,
+ Params: []string{dc.nick, destination, "No such server"},
+ }}
+ }
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "PONG",
+ Params: []string{hostname, source},
+ })
+ return nil
+ case "PONG":
+ if len(msg.Params) == 0 {
+ return newNeedMoreParamsError(msg.Command)
+ }
+ token := msg.Params[len(msg.Params)-1]
+ dc.handlePong(token)
+ case "USER":
+ return ircError{&irc.Message{
+ Command: irc.ERR_ALREADYREGISTERED,
+ Params: []string{dc.nick, "You may not reregister"},
+ }}
+ case "NICK":
+ var rawNick string
+ if err := parseMessageParams(msg, &rawNick); err != nil {
+ return err
+ }
+
+ nick := rawNick
+ var upstream *upstreamConn
+ if dc.upstream() == nil {
+ uc, unmarshaledNick, err := dc.unmarshalEntity(nick)
+ if err == nil { // NICK nick/network: NICK only on a specific upstream
+ upstream = uc
+ nick = unmarshaledNick
+ }
+ }
+
+ if nick == "" || strings.ContainsAny(nick, illegalNickChars) {
+ return ircError{&irc.Message{
+ Command: irc.ERR_ERRONEUSNICKNAME,
+ Params: []string{dc.nick, rawNick, "contains illegal characters"},
+ }}
+ }
+ if casemapASCII(nick) == serviceNickCM {
+ return ircError{&irc.Message{
+ Command: irc.ERR_NICKNAMEINUSE,
+ Params: []string{dc.nick, rawNick, "Nickname reserved for bouncer service"},
+ }}
+ }
+
+ var err error
+ dc.forEachNetwork(func(n *network) {
+ if err != nil || (upstream != nil && upstream.network != n) {
+ return
+ }
+ n.Nick = nick
+ err = dc.srv.db.StoreNetwork(ctx, dc.user.ID, &n.Network)
+ })
+ if err != nil {
+ return err
+ }
+
+ dc.forEachUpstream(func(uc *upstreamConn) {
+ if upstream != nil && upstream != uc {
+ return
+ }
+ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
+ Command: "NICK",
+ Params: []string{nick},
+ })
+ })
+
+ if dc.upstream() == nil && upstream == nil && dc.nick != nick {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.prefix(),
+ Command: "NICK",
+ Params: []string{nick},
+ })
+ dc.nick = nick
+ dc.nickCM = casemapASCII(dc.nick)
+ }
+ case "SETNAME":
+ var realname string
+ if err := parseMessageParams(msg, &realname); err != nil {
+ return err
+ }
+
+ // If the client just resets to the default, just wipe the per-network
+ // preference
+ storeRealname := realname
+ if realname == dc.user.Realname {
+ storeRealname = ""
+ }
+
+ var storeErr error
+ var needUpdate []Network
+ dc.forEachNetwork(func(n *network) {
+ // We only need to call updateNetwork for upstreams that don't
+ // support setname
+ if uc := n.conn; uc != nil && uc.caps["setname"] {
+ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
+ Command: "SETNAME",
+ Params: []string{realname},
+ })
+
+ n.Realname = storeRealname
+ if err := dc.srv.db.StoreNetwork(ctx, dc.user.ID, &n.Network); err != nil {
+ dc.logger.Printf("failed to store network realname: %v", err)
+ storeErr = err
+ }
+ return
+ }
+
+ record := n.Network // copy network record because we'll mutate it
+ record.Realname = storeRealname
+ needUpdate = append(needUpdate, record)
+ })
+
+ // Walk the network list as a second step, because updateNetwork
+ // mutates the original list
+ for _, record := range needUpdate {
+ if _, err := dc.user.updateNetwork(ctx, &record); err != nil {
+ dc.logger.Printf("failed to update network realname: %v", err)
+ storeErr = err
+ }
+ }
+ if storeErr != nil {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"SETNAME", "CANNOT_CHANGE_REALNAME", "Failed to update realname"},
+ }}
+ }
+
+ if dc.upstream() == nil {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.prefix(),
+ Command: "SETNAME",
+ Params: []string{realname},
+ })
+ }
+ case "JOIN":
+ var namesStr string
+ if err := parseMessageParams(msg, &namesStr); err != nil {
+ return err
+ }
+
+ var keys []string
+ if len(msg.Params) > 1 {
+ keys = strings.Split(msg.Params[1], ",")
+ }
+
+ for i, name := range strings.Split(namesStr, ",") {
+ uc, upstreamName, err := dc.unmarshalEntity(name)
+ if err != nil {
+ return err
+ }
+
+ var key string
+ if len(keys) > i {
+ key = keys[i]
+ }
+
+ if !uc.isChannel(upstreamName) {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_NOSUCHCHANNEL,
+ Params: []string{name, "Not a channel name"},
+ })
+ continue
+ }
+
+ // Most servers ignore duplicate JOIN messages. We ignore them here
+ // because some clients automatically send JOIN messages in bulk
+ // when reconnecting to the bouncer. We don't want to flood the
+ // upstream connection with these.
+ if !uc.channels.Has(upstreamName) {
+ params := []string{upstreamName}
+ if key != "" {
+ params = append(params, key)
+ }
+ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
+ Command: "JOIN",
+ Params: params,
+ })
+ }
+
+ ch := uc.network.channels.Value(upstreamName)
+ if ch != nil {
+ // Don't clear the channel key if there's one set
+ // TODO: add a way to unset the channel key
+ if key != "" {
+ ch.Key = key
+ }
+ uc.network.attach(ctx, ch)
+ } else {
+ ch = &Channel{
+ Name: upstreamName,
+ Key: key,
+ }
+ uc.network.channels.SetValue(upstreamName, ch)
+ }
+ if err := dc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil {
+ dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err)
+ }
+ }
+ case "PART":
+ var namesStr string
+ if err := parseMessageParams(msg, &namesStr); err != nil {
+ return err
+ }
+
+ var reason string
+ if len(msg.Params) > 1 {
+ reason = msg.Params[1]
+ }
+
+ for _, name := range strings.Split(namesStr, ",") {
+ uc, upstreamName, err := dc.unmarshalEntity(name)
+ if err != nil {
+ return err
+ }
+
+ if strings.EqualFold(reason, "detach") {
+ ch := uc.network.channels.Value(upstreamName)
+ if ch != nil {
+ uc.network.detach(ch)
+ } else {
+ ch = &Channel{
+ Name: name,
+ Detached: true,
+ }
+ uc.network.channels.SetValue(upstreamName, ch)
+ }
+ if err := dc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil {
+ dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err)
+ }
+ } else {
+ params := []string{upstreamName}
+ if reason != "" {
+ params = append(params, reason)
+ }
+ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
+ Command: "PART",
+ Params: params,
+ })
+
+ if err := uc.network.deleteChannel(ctx, upstreamName); err != nil {
+ dc.logger.Printf("failed to delete channel %q: %v", upstreamName, err)
+ }
+ }
+ }
+ case "KICK":
+ var channelStr, userStr string
+ if err := parseMessageParams(msg, &channelStr, &userStr); err != nil {
+ return err
+ }
+
+ channels := strings.Split(channelStr, ",")
+ users := strings.Split(userStr, ",")
+
+ var reason string
+ if len(msg.Params) > 2 {
+ reason = msg.Params[2]
+ }
+
+ if len(channels) != 1 && len(channels) != len(users) {
+ return ircError{&irc.Message{
+ Command: irc.ERR_BADCHANMASK,
+ Params: []string{dc.nick, channelStr, "Bad channel mask"},
+ }}
+ }
+
+ for i, user := range users {
+ var channel string
+ if len(channels) == 1 {
+ channel = channels[0]
+ } else {
+ channel = channels[i]
+ }
+
+ ucChannel, upstreamChannel, err := dc.unmarshalEntity(channel)
+ if err != nil {
+ return err
+ }
+
+ ucUser, upstreamUser, err := dc.unmarshalEntity(user)
+ if err != nil {
+ return err
+ }
+
+ if ucChannel != ucUser {
+ return ircError{&irc.Message{
+ Command: irc.ERR_USERNOTINCHANNEL,
+ Params: []string{dc.nick, user, channel, "They are on another network"},
+ }}
+ }
+ uc := ucChannel
+
+ params := []string{upstreamChannel, upstreamUser}
+ if reason != "" {
+ params = append(params, reason)
+ }
+ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
+ Command: "KICK",
+ Params: params,
+ })
+ }
+ case "MODE":
+ var name string
+ if err := parseMessageParams(msg, &name); err != nil {
+ return err
+ }
+
+ var modeStr string
+ if len(msg.Params) > 1 {
+ modeStr = msg.Params[1]
+ }
+
+ if casemapASCII(name) == dc.nickCM {
+ if modeStr != "" {
+ if uc := dc.upstream(); uc != nil {
+ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
+ Command: "MODE",
+ Params: []string{uc.nick, modeStr},
+ })
+ } else {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_UMODEUNKNOWNFLAG,
+ Params: []string{dc.nick, "Cannot change user mode in multi-upstream mode"},
+ })
+ }
+ } else {
+ var userMode string
+ if uc := dc.upstream(); uc != nil {
+ userMode = string(uc.modes)
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_UMODEIS,
+ Params: []string{dc.nick, "+" + userMode},
+ })
+ }
+ return nil
+ }
+
+ uc, upstreamName, err := dc.unmarshalEntity(name)
+ if err != nil {
+ return err
+ }
+
+ if !uc.isChannel(upstreamName) {
+ return ircError{&irc.Message{
+ Command: irc.ERR_USERSDONTMATCH,
+ Params: []string{dc.nick, "Cannot change mode for other users"},
+ }}
+ }
+
+ if modeStr != "" {
+ params := []string{upstreamName, modeStr}
+ params = append(params, msg.Params[2:]...)
+ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
+ Command: "MODE",
+ Params: params,
+ })
+ } else {
+ ch := uc.channels.Value(upstreamName)
+ if ch == nil {
+ return ircError{&irc.Message{
+ Command: irc.ERR_NOSUCHCHANNEL,
+ Params: []string{dc.nick, name, "No such channel"},
+ }}
+ }
+
+ if ch.modes == nil {
+ // we haven't received the initial RPL_CHANNELMODEIS yet
+ // ignore the request, we will broadcast the modes later when we receive RPL_CHANNELMODEIS
+ return nil
+ }
+
+ modeStr, modeParams := ch.modes.Format()
+ params := []string{dc.nick, name, modeStr}
+ params = append(params, modeParams...)
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_CHANNELMODEIS,
+ Params: params,
+ })
+ if ch.creationTime != "" {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: rpl_creationtime,
+ Params: []string{dc.nick, name, ch.creationTime},
+ })
+ }
+ }
+ case "TOPIC":
+ var channel string
+ if err := parseMessageParams(msg, &channel); err != nil {
+ return err
+ }
+
+ uc, upstreamName, err := dc.unmarshalEntity(channel)
+ if err != nil {
+ return err
+ }
+
+ if len(msg.Params) > 1 { // setting topic
+ topic := msg.Params[1]
+ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
+ Command: "TOPIC",
+ Params: []string{upstreamName, topic},
+ })
+ } else { // getting topic
+ ch := uc.channels.Value(upstreamName)
+ if ch == nil {
+ return ircError{&irc.Message{
+ Command: irc.ERR_NOSUCHCHANNEL,
+ Params: []string{dc.nick, upstreamName, "No such channel"},
+ }}
+ }
+ sendTopic(dc, ch)
+ }
+ case "LIST":
+ network := dc.network
+ if network == nil && len(msg.Params) > 0 {
+ var err error
+ network, msg.Params[0], err = dc.unmarshalEntityNetwork(msg.Params[0])
+ if err != nil {
+ return err
+ }
+ }
+ if network == nil {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_LISTEND,
+ Params: []string{dc.nick, "LIST without a network suffix is not supported in multi-upstream mode"},
+ })
+ return nil
+ }
+
+ uc := network.conn
+ if uc == nil {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_LISTEND,
+ Params: []string{dc.nick, "Disconnected from upstream server"},
+ })
+ return nil
+ }
+
+ uc.enqueueCommand(dc, msg)
+ case "NAMES":
+ if len(msg.Params) == 0 {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_ENDOFNAMES,
+ Params: []string{dc.nick, "*", "End of /NAMES list"},
+ })
+ return nil
+ }
+
+ channels := strings.Split(msg.Params[0], ",")
+ for _, channel := range channels {
+ uc, upstreamName, err := dc.unmarshalEntity(channel)
+ if err != nil {
+ return err
+ }
+
+ ch := uc.channels.Value(upstreamName)
+ if ch != nil {
+ sendNames(dc, ch)
+ } else {
+ // NAMES on a channel we have not joined, ask upstream
+ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
+ Command: "NAMES",
+ Params: []string{upstreamName},
+ })
+ }
+ }
+ // For WHOX docs, see:
+ // - http://faerion.sourceforge.net/doc/irc/whox.var
+ // - https://github.com/quakenet/snircd/blob/master/doc/readme.who
+ // Note, many features aren't widely implemented, such as flags and mask2
+ case "WHO":
+ if len(msg.Params) == 0 {
+ // TODO: support WHO without parameters
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_ENDOFWHO,
+ Params: []string{dc.nick, "*", "End of /WHO list"},
+ })
+ return nil
+ }
+
+ // Clients will use the first mask to match RPL_ENDOFWHO
+ endOfWhoToken := msg.Params[0]
+
+ // TODO: add support for WHOX mask2
+ mask := msg.Params[0]
+ var options string
+ if len(msg.Params) > 1 {
+ options = msg.Params[1]
+ }
+
+ optionsParts := strings.SplitN(options, "%", 2)
+ // TODO: add support for WHOX flags in optionsParts[0]
+ var fields, whoxToken string
+ if len(optionsParts) == 2 {
+ optionsParts := strings.SplitN(optionsParts[1], ",", 2)
+ fields = strings.ToLower(optionsParts[0])
+ if len(optionsParts) == 2 && strings.Contains(fields, "t") {
+ whoxToken = optionsParts[1]
+ }
+ }
+
+ // TODO: support mixed bouncer/upstream WHO queries
+ maskCM := casemapASCII(mask)
+ if dc.network == nil && maskCM == dc.nickCM {
+ // TODO: support AWAY (H/G) in self WHO reply
+ flags := "H"
+ if dc.user.Admin {
+ flags += "*"
+ }
+ info := whoxInfo{
+ Token: whoxToken,
+ Username: dc.user.Username,
+ Hostname: dc.hostname,
+ Server: dc.srv.Config().Hostname,
+ Nickname: dc.nick,
+ Flags: flags,
+ Account: dc.user.Username,
+ Realname: dc.realname,
+ }
+ dc.SendMessage(generateWHOXReply(dc.srv.prefix(), dc.nick, fields, &info))
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_ENDOFWHO,
+ Params: []string{dc.nick, endOfWhoToken, "End of /WHO list"},
+ })
+ return nil
+ }
+ if maskCM == serviceNickCM {
+ info := whoxInfo{
+ Token: whoxToken,
+ Username: servicePrefix.User,
+ Hostname: servicePrefix.Host,
+ Server: dc.srv.Config().Hostname,
+ Nickname: serviceNick,
+ Flags: "H*",
+ Account: serviceNick,
+ Realname: serviceRealname,
+ }
+ dc.SendMessage(generateWHOXReply(dc.srv.prefix(), dc.nick, fields, &info))
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_ENDOFWHO,
+ Params: []string{dc.nick, endOfWhoToken, "End of /WHO list"},
+ })
+ return nil
+ }
+
+ // TODO: properly support WHO masks
+ uc, upstreamMask, err := dc.unmarshalEntity(mask)
+ if err != nil {
+ return err
+ }
+
+ params := []string{upstreamMask}
+ if options != "" {
+ params = append(params, options)
+ }
+
+ uc.enqueueCommand(dc, &irc.Message{
+ Command: "WHO",
+ Params: params,
+ })
+ case "WHOIS":
+ if len(msg.Params) == 0 {
+ return ircError{&irc.Message{
+ Command: irc.ERR_NONICKNAMEGIVEN,
+ Params: []string{dc.nick, "No nickname given"},
+ }}
+ }
+
+ var target, mask string
+ if len(msg.Params) == 1 {
+ target = ""
+ mask = msg.Params[0]
+ } else {
+ target = msg.Params[0]
+ mask = msg.Params[1]
+ }
+ // TODO: support multiple WHOIS users
+ if i := strings.IndexByte(mask, ','); i >= 0 {
+ mask = mask[:i]
+ }
+
+ if dc.network == nil && casemapASCII(mask) == dc.nickCM {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_WHOISUSER,
+ Params: []string{dc.nick, dc.nick, dc.user.Username, dc.hostname, "*", dc.realname},
+ })
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_WHOISSERVER,
+ Params: []string{dc.nick, dc.nick, dc.srv.Config().Hostname, "suika"},
+ })
+ if dc.user.Admin {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_WHOISOPERATOR,
+ Params: []string{dc.nick, dc.nick, "is a bouncer administrator"},
+ })
+ }
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: rpl_whoisaccount,
+ Params: []string{dc.nick, dc.nick, dc.user.Username, "is logged in as"},
+ })
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_ENDOFWHOIS,
+ Params: []string{dc.nick, dc.nick, "End of /WHOIS list"},
+ })
+ return nil
+ }
+ if casemapASCII(mask) == serviceNickCM {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_WHOISUSER,
+ Params: []string{dc.nick, serviceNick, servicePrefix.User, servicePrefix.Host, "*", serviceRealname},
+ })
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_WHOISSERVER,
+ Params: []string{dc.nick, serviceNick, dc.srv.Config().Hostname, "suika"},
+ })
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_WHOISOPERATOR,
+ Params: []string{dc.nick, serviceNick, "is the bouncer service"},
+ })
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: rpl_whoisaccount,
+ Params: []string{dc.nick, serviceNick, serviceNick, "is logged in as"},
+ })
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_ENDOFWHOIS,
+ Params: []string{dc.nick, serviceNick, "End of /WHOIS list"},
+ })
+ return nil
+ }
+
+ // TODO: support WHOIS masks
+ uc, upstreamNick, err := dc.unmarshalEntity(mask)
+ if err != nil {
+ return err
+ }
+
+ var params []string
+ if target != "" {
+ if target == mask { // WHOIS nick nick
+ params = []string{upstreamNick, upstreamNick}
+ } else {
+ params = []string{target, upstreamNick}
+ }
+ } else {
+ params = []string{upstreamNick}
+ }
+
+ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
+ Command: "WHOIS",
+ Params: params,
+ })
+ case "PRIVMSG", "NOTICE":
+ var targetsStr, text string
+ if err := parseMessageParams(msg, &targetsStr, &text); err != nil {
+ return err
+ }
+ tags := copyClientTags(msg.Tags)
+
+ for _, name := range strings.Split(targetsStr, ",") {
+ if name == "$"+dc.srv.Config().Hostname || (name == "$*" && dc.network == nil) {
+ // "$" means a server mask follows. If it's the bouncer's
+ // hostname, broadcast the message to all bouncer users.
+ if !dc.user.Admin {
+ return ircError{&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_BADMASK,
+ Params: []string{dc.nick, name, "Permission denied to broadcast message to all bouncer users"},
+ }}
+ }
+
+ dc.logger.Printf("broadcasting bouncer-wide %v: %v", msg.Command, text)
+
+ broadcastTags := tags.Copy()
+ broadcastTags["time"] = irc.TagValue(formatServerTime(time.Now()))
+ broadcastMsg := &irc.Message{
+ Tags: broadcastTags,
+ Prefix: servicePrefix,
+ Command: msg.Command,
+ Params: []string{name, text},
+ }
+ dc.srv.forEachUser(func(u *user) {
+ u.events <- eventBroadcast{broadcastMsg}
+ })
+ continue
+ }
+
+ if dc.network == nil && casemapASCII(name) == dc.nickCM {
+ dc.SendMessage(&irc.Message{
+ Tags: msg.Tags.Copy(),
+ Prefix: dc.prefix(),
+ Command: msg.Command,
+ Params: []string{name, text},
+ })
+ continue
+ }
+
+ if msg.Command == "PRIVMSG" && casemapASCII(name) == serviceNickCM {
+ if dc.caps["echo-message"] {
+ echoTags := tags.Copy()
+ echoTags["time"] = irc.TagValue(formatServerTime(time.Now()))
+ dc.SendMessage(&irc.Message{
+ Tags: echoTags,
+ Prefix: dc.prefix(),
+ Command: msg.Command,
+ Params: []string{name, text},
+ })
+ }
+ handleServicePRIVMSG(ctx, dc, text)
+ continue
+ }
+
+ uc, upstreamName, err := dc.unmarshalEntity(name)
+ if err != nil {
+ return err
+ }
+
+ if msg.Command == "PRIVMSG" && uc.network.casemap(upstreamName) == "nickserv" {
+ dc.handleNickServPRIVMSG(ctx, uc, text)
+ }
+
+ unmarshaledText := text
+ if uc.isChannel(upstreamName) {
+ unmarshaledText = dc.unmarshalText(uc, text)
+ }
+ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
+ Tags: tags,
+ Command: msg.Command,
+ Params: []string{upstreamName, unmarshaledText},
+ })
+
+ echoTags := tags.Copy()
+ echoTags["time"] = irc.TagValue(formatServerTime(time.Now()))
+ if uc.account != "" {
+ echoTags["account"] = irc.TagValue(uc.account)
+ }
+ echoMsg := &irc.Message{
+ Tags: echoTags,
+ Prefix: &irc.Prefix{Name: uc.nick},
+ Command: msg.Command,
+ Params: []string{upstreamName, text},
+ }
+ uc.produce(upstreamName, echoMsg, dc)
+
+ uc.updateChannelAutoDetach(upstreamName)
+ }
+ case "TAGMSG":
+ var targetsStr string
+ if err := parseMessageParams(msg, &targetsStr); err != nil {
+ return err
+ }
+ tags := copyClientTags(msg.Tags)
+
+ for _, name := range strings.Split(targetsStr, ",") {
+ if dc.network == nil && casemapASCII(name) == dc.nickCM {
+ dc.SendMessage(&irc.Message{
+ Tags: msg.Tags.Copy(),
+ Prefix: dc.prefix(),
+ Command: "TAGMSG",
+ Params: []string{name},
+ })
+ continue
+ }
+
+ if casemapASCII(name) == serviceNickCM {
+ continue
+ }
+
+ uc, upstreamName, err := dc.unmarshalEntity(name)
+ if err != nil {
+ return err
+ }
+ if _, ok := uc.caps["message-tags"]; !ok {
+ continue
+ }
+
+ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
+ Tags: tags,
+ Command: "TAGMSG",
+ Params: []string{upstreamName},
+ })
+
+ echoTags := tags.Copy()
+ echoTags["time"] = irc.TagValue(formatServerTime(time.Now()))
+ if uc.account != "" {
+ echoTags["account"] = irc.TagValue(uc.account)
+ }
+ echoMsg := &irc.Message{
+ Tags: echoTags,
+ Prefix: &irc.Prefix{Name: uc.nick},
+ Command: "TAGMSG",
+ Params: []string{upstreamName},
+ }
+ uc.produce(upstreamName, echoMsg, dc)
+
+ uc.updateChannelAutoDetach(upstreamName)
+ }
+ case "INVITE":
+ var user, channel string
+ if err := parseMessageParams(msg, &user, &channel); err != nil {
+ return err
+ }
+
+ ucChannel, upstreamChannel, err := dc.unmarshalEntity(channel)
+ if err != nil {
+ return err
+ }
+
+ ucUser, upstreamUser, err := dc.unmarshalEntity(user)
+ if err != nil {
+ return err
+ }
+
+ if ucChannel != ucUser {
+ return ircError{&irc.Message{
+ Command: irc.ERR_USERNOTINCHANNEL,
+ Params: []string{dc.nick, user, channel, "They are on another network"},
+ }}
+ }
+ uc := ucChannel
+
+ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
+ Command: "INVITE",
+ Params: []string{upstreamUser, upstreamChannel},
+ })
+ case "AUTHENTICATE":
+ // Post-connection-registration AUTHENTICATE is unsupported in
+ // multi-upstream mode, or if the upstream doesn't support SASL
+ uc := dc.upstream()
+ if uc == nil || !uc.caps["sasl"] {
+ return ircError{&irc.Message{
+ Command: irc.ERR_SASLFAIL,
+ Params: []string{dc.nick, "Upstream network authentication not supported"},
+ }}
+ }
+
+ credentials, err := dc.handleAuthenticateCommand(msg)
+ if err != nil {
+ return err
+ }
+
+ if credentials != nil {
+ if uc.saslClient != nil {
+ dc.endSASL(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_SASLFAIL,
+ Params: []string{dc.nick, "Another authentication attempt is already in progress"},
+ })
+ return nil
+ }
+
+ uc.logger.Printf("starting post-registration SASL PLAIN authentication with username %q", credentials.plainUsername)
+ uc.saslClient = sasl.NewPlainClient("", credentials.plainUsername, credentials.plainPassword)
+ uc.enqueueCommand(dc, &irc.Message{
+ Command: "AUTHENTICATE",
+ Params: []string{"PLAIN"},
+ })
+ }
+ case "REGISTER", "VERIFY":
+ // Check number of params here, since we'll use that to save the
+ // credentials on command success
+ if (msg.Command == "REGISTER" && len(msg.Params) < 3) || (msg.Command == "VERIFY" && len(msg.Params) < 2) {
+ return newNeedMoreParamsError(msg.Command)
+ }
+
+ uc := dc.upstream()
+ if uc == nil || !uc.caps["draft/account-registration"] {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{msg.Command, "TEMPORARILY_UNAVAILABLE", "*", "Upstream network account registration not supported"},
+ }}
+ }
+
+ uc.logger.Printf("starting %v with account name %v", msg.Command, msg.Params[0])
+ uc.enqueueCommand(dc, msg)
+ case "MONITOR":
+ // MONITOR is unsupported in multi-upstream mode
+ uc := dc.upstream()
+ if uc == nil {
+ return newUnknownCommandError(msg.Command)
+ }
+ if _, ok := uc.isupport["MONITOR"]; !ok {
+ return newUnknownCommandError(msg.Command)
+ }
+
+ var subcommand string
+ if err := parseMessageParams(msg, &subcommand); err != nil {
+ return err
+ }
+
+ switch strings.ToUpper(subcommand) {
+ case "+", "-":
+ var targets string
+ if err := parseMessageParams(msg, nil, &targets); err != nil {
+ return err
+ }
+ for _, target := range strings.Split(targets, ",") {
+ if subcommand == "+" {
+ // Hard limit, just to avoid having downstreams fill our map
+ if len(dc.monitored.innerMap) >= 1000 {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_MONLISTFULL,
+ Params: []string{dc.nick, "1000", target, "Bouncer monitor list is full"},
+ })
+ continue
+ }
+
+ dc.monitored.SetValue(target, nil)
+
+ if uc.monitored.Has(target) {
+ cmd := irc.RPL_MONOFFLINE
+ if online := uc.monitored.Value(target); online {
+ cmd = irc.RPL_MONONLINE
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: cmd,
+ Params: []string{dc.nick, target},
+ })
+ }
+ } else {
+ dc.monitored.Delete(target)
+ }
+ }
+ uc.updateMonitor()
+ case "C": // clear
+ dc.monitored = newCasemapMap(0)
+ uc.updateMonitor()
+ case "L": // list
+ // TODO: be less lazy and pack the list
+ for _, entry := range dc.monitored.innerMap {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_MONLIST,
+ Params: []string{dc.nick, entry.originalKey},
+ })
+ }
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_ENDOFMONLIST,
+ Params: []string{dc.nick, "End of MONITOR list"},
+ })
+ case "S": // status
+ // TODO: be less lazy and pack the lists
+ for _, entry := range dc.monitored.innerMap {
+ target := entry.originalKey
+
+ cmd := irc.RPL_MONOFFLINE
+ if online := uc.monitored.Value(target); online {
+ cmd = irc.RPL_MONONLINE
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: cmd,
+ Params: []string{dc.nick, target},
+ })
+ }
+ }
+ case "CHATHISTORY":
+ var subcommand string
+ if err := parseMessageParams(msg, &subcommand); err != nil {
+ return err
+ }
+ var target, limitStr string
+ var boundsStr [2]string
+ switch subcommand {
+ case "AFTER", "BEFORE", "LATEST":
+ if err := parseMessageParams(msg, nil, &target, &boundsStr[0], &limitStr); err != nil {
+ return err
+ }
+ case "BETWEEN":
+ if err := parseMessageParams(msg, nil, &target, &boundsStr[0], &boundsStr[1], &limitStr); err != nil {
+ return err
+ }
+ case "TARGETS":
+ if dc.network == nil {
+ // Either an unbound bouncer network, in which case we should return no targets,
+ // or a multi-upstream downstream, but we don't support CHATHISTORY TARGETS for those yet.
+ dc.SendBatch("draft/chathistory-targets", nil, nil, func(batchRef irc.TagValue) {})
+ return nil
+ }
+ if err := parseMessageParams(msg, nil, &boundsStr[0], &boundsStr[1], &limitStr); err != nil {
+ return err
+ }
+ default:
+ // TODO: support AROUND
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"CHATHISTORY", "INVALID_PARAMS", subcommand, "Unknown command"},
+ }}
+ }
+
+ // We don't save history for our service
+ if casemapASCII(target) == serviceNickCM {
+ dc.SendBatch("chathistory", []string{target}, nil, func(batchRef irc.TagValue) {})
+ return nil
+ }
+
+ store, ok := dc.user.msgStore.(chatHistoryMessageStore)
+ if !ok {
+ return ircError{&irc.Message{
+ Command: irc.ERR_UNKNOWNCOMMAND,
+ Params: []string{dc.nick, "CHATHISTORY", "Unknown command"},
+ }}
+ }
+
+ network, entity, err := dc.unmarshalEntityNetwork(target)
+ if err != nil {
+ return err
+ }
+ entity = network.casemap(entity)
+
+ // TODO: support msgid criteria
+ var bounds [2]time.Time
+ bounds[0] = parseChatHistoryBound(boundsStr[0])
+ if subcommand == "LATEST" && boundsStr[0] == "*" {
+ bounds[0] = time.Now()
+ } else if bounds[0].IsZero() {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"CHATHISTORY", "INVALID_PARAMS", subcommand, boundsStr[0], "Invalid first bound"},
+ }}
+ }
+
+ if boundsStr[1] != "" {
+ bounds[1] = parseChatHistoryBound(boundsStr[1])
+ if bounds[1].IsZero() {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"CHATHISTORY", "INVALID_PARAMS", subcommand, boundsStr[1], "Invalid second bound"},
+ }}
+ }
+ }
+
+ limit, err := strconv.Atoi(limitStr)
+ if err != nil || limit < 0 || limit > chatHistoryLimit {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"CHATHISTORY", "INVALID_PARAMS", subcommand, limitStr, "Invalid limit"},
+ }}
+ }
+
+ eventPlayback := dc.caps["draft/event-playback"]
+
+ var history []*irc.Message
+ switch subcommand {
+ case "BEFORE", "LATEST":
+ history, err = store.LoadBeforeTime(ctx, &network.Network, entity, bounds[0], time.Time{}, limit, eventPlayback)
+ case "AFTER":
+ history, err = store.LoadAfterTime(ctx, &network.Network, entity, bounds[0], time.Now(), limit, eventPlayback)
+ case "BETWEEN":
+ if bounds[0].Before(bounds[1]) {
+ history, err = store.LoadAfterTime(ctx, &network.Network, entity, bounds[0], bounds[1], limit, eventPlayback)
+ } else {
+ history, err = store.LoadBeforeTime(ctx, &network.Network, entity, bounds[0], bounds[1], limit, eventPlayback)
+ }
+ case "TARGETS":
+ // TODO: support TARGETS in multi-upstream mode
+ targets, err := store.ListTargets(ctx, &network.Network, bounds[0], bounds[1], limit, eventPlayback)
+ if err != nil {
+ dc.logger.Printf("failed fetching targets for chathistory: %v", err)
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"CHATHISTORY", "MESSAGE_ERROR", subcommand, "Failed to retrieve targets"},
+ }}
+ }
+
+ dc.SendBatch("draft/chathistory-targets", nil, nil, func(batchRef irc.TagValue) {
+ for _, target := range targets {
+ if ch := network.channels.Value(target.Name); ch != nil && ch.Detached {
+ continue
+ }
+
+ dc.SendMessage(&irc.Message{
+ Tags: irc.Tags{"batch": batchRef},
+ Prefix: dc.srv.prefix(),
+ Command: "CHATHISTORY",
+ Params: []string{"TARGETS", target.Name, formatServerTime(target.LatestMessage)},
+ })
+ }
+ })
+
+ return nil
+ }
+ if err != nil {
+ dc.logger.Printf("failed fetching %q messages for chathistory: %v", target, err)
+ return newChatHistoryError(subcommand, target)
+ }
+
+ dc.SendBatch("chathistory", []string{target}, nil, func(batchRef irc.TagValue) {
+ for _, msg := range history {
+ msg.Tags["batch"] = batchRef
+ dc.SendMessage(dc.marshalMessage(msg, network))
+ }
+ })
+ case "READ":
+ var target, criteria string
+ if err := parseMessageParams(msg, &target); err != nil {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"READ", "NEED_MORE_PARAMS", "Missing parameters"},
+ }}
+ }
+ if len(msg.Params) > 1 {
+ criteria = msg.Params[1]
+ }
+
+ // We don't save read receipts for our service
+ if casemapASCII(target) == serviceNickCM {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.prefix(),
+ Command: "READ",
+ Params: []string{target, "*"},
+ })
+ return nil
+ }
+
+ uc, entity, err := dc.unmarshalEntity(target)
+ if err != nil {
+ return err
+ }
+ entityCM := uc.network.casemap(entity)
+
+ r, err := dc.srv.db.GetReadReceipt(ctx, uc.network.ID, entityCM)
+ if err != nil {
+ dc.logger.Printf("failed to get the read receipt for %q: %v", entity, err)
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"READ", "INTERNAL_ERROR", target, "Internal error"},
+ }}
+ } else if r == nil {
+ r = &ReadReceipt{
+ Target: entityCM,
+ }
+ }
+
+ broadcast := false
+ if len(criteria) > 0 {
+ // TODO: support msgid criteria
+ criteriaParts := strings.SplitN(criteria, "=", 2)
+ if len(criteriaParts) != 2 || criteriaParts[0] != "timestamp" {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"READ", "INVALID_PARAMS", criteria, "Unknown criteria"},
+ }}
+ }
+
+ timestamp, err := time.Parse(serverTimeLayout, criteriaParts[1])
+ if err != nil {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"READ", "INVALID_PARAMS", criteria, "Invalid criteria"},
+ }}
+ }
+ now := time.Now()
+ if timestamp.After(now) {
+ timestamp = now
+ }
+ if r.Timestamp.Before(timestamp) {
+ r.Timestamp = timestamp
+ if err := dc.srv.db.StoreReadReceipt(ctx, uc.network.ID, r); err != nil {
+ dc.logger.Printf("failed to store receipt for %q: %v", entity, err)
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"READ", "INTERNAL_ERROR", target, "Internal error"},
+ }}
+ }
+ broadcast = true
+ }
+ }
+
+ timestampStr := "*"
+ if !r.Timestamp.IsZero() {
+ timestampStr = fmt.Sprintf("timestamp=%s", formatServerTime(r.Timestamp))
+ }
+ uc.forEachDownstream(func(d *downstreamConn) {
+ if broadcast || dc.id == d.id {
+ d.SendMessage(&irc.Message{
+ Prefix: d.prefix(),
+ Command: "READ",
+ Params: []string{d.marshalEntity(uc.network, entity), timestampStr},
+ })
+ }
+ })
+ case "BOUNCER":
+ var subcommand string
+ if err := parseMessageParams(msg, &subcommand); err != nil {
+ return err
+ }
+
+ switch strings.ToUpper(subcommand) {
+ case "BIND":
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"BOUNCER", "REGISTRATION_IS_COMPLETED", "BIND", "Cannot bind to a network after registration"},
+ }}
+ case "LISTNETWORKS":
+ dc.SendBatch("soju.im/bouncer-networks", nil, nil, func(batchRef irc.TagValue) {
+ for _, network := range dc.user.networks {
+ idStr := fmt.Sprintf("%v", network.ID)
+ attrs := getNetworkAttrs(network)
+ dc.SendMessage(&irc.Message{
+ Tags: irc.Tags{"batch": batchRef},
+ Prefix: dc.srv.prefix(),
+ Command: "BOUNCER",
+ Params: []string{"NETWORK", idStr, attrs.String()},
+ })
+ }
+ })
+ case "ADDNETWORK":
+ var attrsStr string
+ if err := parseMessageParams(msg, nil, &attrsStr); err != nil {
+ return err
+ }
+ attrs := irc.ParseTags(attrsStr)
+
+ record := &Network{Nick: dc.nick, Enabled: true}
+ if err := updateNetworkAttrs(record, attrs, subcommand); err != nil {
+ return err
+ }
+
+ if record.Nick == dc.user.Username {
+ record.Nick = ""
+ }
+ if record.Realname == dc.user.Realname {
+ record.Realname = ""
+ }
+
+ network, err := dc.user.createNetwork(ctx, record)
+ if err != nil {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"BOUNCER", "UNKNOWN_ERROR", subcommand, fmt.Sprintf("Failed to create network: %v", err)},
+ }}
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "BOUNCER",
+ Params: []string{"ADDNETWORK", fmt.Sprintf("%v", network.ID)},
+ })
+ case "CHANGENETWORK":
+ var idStr, attrsStr string
+ if err := parseMessageParams(msg, nil, &idStr, &attrsStr); err != nil {
+ return err
+ }
+ id, err := parseBouncerNetID(subcommand, idStr)
+ if err != nil {
+ return err
+ }
+ attrs := irc.ParseTags(attrsStr)
+
+ net := dc.user.getNetworkByID(id)
+ if net == nil {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"BOUNCER", "INVALID_NETID", subcommand, idStr, "Invalid network ID"},
+ }}
+ }
+
+ record := net.Network // copy network record because we'll mutate it
+ if err := updateNetworkAttrs(&record, attrs, subcommand); err != nil {
+ return err
+ }
+
+ if record.Nick == dc.user.Username {
+ record.Nick = ""
+ }
+ if record.Realname == dc.user.Realname {
+ record.Realname = ""
+ }
+
+ _, err = dc.user.updateNetwork(ctx, &record)
+ if err != nil {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"BOUNCER", "UNKNOWN_ERROR", subcommand, fmt.Sprintf("Failed to update network: %v", err)},
+ }}
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "BOUNCER",
+ Params: []string{"CHANGENETWORK", idStr},
+ })
+ case "DELNETWORK":
+ var idStr string
+ if err := parseMessageParams(msg, nil, &idStr); err != nil {
+ return err
+ }
+ id, err := parseBouncerNetID(subcommand, idStr)
+ if err != nil {
+ return err
+ }
+
+ net := dc.user.getNetworkByID(id)
+ if net == nil {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"BOUNCER", "INVALID_NETID", subcommand, idStr, "Invalid network ID"},
+ }}
+ }
+
+ if err := dc.user.deleteNetwork(ctx, net.ID); err != nil {
+ return err
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "BOUNCER",
+ Params: []string{"DELNETWORK", idStr},
+ })
+ default:
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"BOUNCER", "UNKNOWN_COMMAND", subcommand, "Unknown subcommand"},
+ }}
+ }
+ default:
+ dc.logger.Printf("unhandled message: %v", msg)
+
+ // Only forward unknown commands in single-upstream mode
+ uc := dc.upstream()
+ if uc == nil {
+ return newUnknownCommandError(msg.Command)
+ }
+
+ uc.SendMessageLabeled(ctx, dc.id, msg)
+ }
+ return nil
+}
+
+func (dc *downstreamConn) handleNickServPRIVMSG(ctx context.Context, uc *upstreamConn, text string) {
+ username, password, ok := parseNickServCredentials(text, uc.nick)
+ if ok {
+ uc.network.autoSaveSASLPlain(ctx, username, password)
+ }
+}
+
+func parseNickServCredentials(text, nick string) (username, password string, ok bool) {
+ fields := strings.Fields(text)
+ if len(fields) < 2 {
+ return "", "", false
+ }
+ cmd := strings.ToUpper(fields[0])
+ params := fields[1:]
+ switch cmd {
+ case "REGISTER":
+ username = nick
+ password = params[0]
+ case "IDENTIFY":
+ if len(params) == 1 {
+ username = nick
+ password = params[0]
+ } else {
+ username = params[0]
+ password = params[1]
+ }
+ case "SET":
+ if len(params) == 2 && strings.EqualFold(params[0], "PASSWORD") {
+ username = nick
+ password = params[1]
+ }
+ default:
+ return "", "", false
+ }
+ return username, password, true
+}
--- /dev/null
+module marisa.chaotic.ninja/suika
+
+go 1.20
+
+require (
+ git.sr.ht/~emersion/go-scfg v0.0.0-20211215104734-c2c7a15d6c99
+ git.sr.ht/~sircmpwn/go-bare v0.0.0-20210406120253-ab86bc2846d9
+ github.com/emersion/go-sasl v0.0.0-20220912192320-0145f2c60ead
+ github.com/lib/pq v1.10.7
+ golang.org/x/crypto v0.7.0
+ golang.org/x/term v0.6.0
+ golang.org/x/time v0.3.0
+ gopkg.in/irc.v3 v3.1.4
+ modernc.org/sqlite v1.21.0
+)
+
+require (
+ github.com/dustin/go-humanize v1.0.0 // indirect
+ github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect
+ github.com/google/uuid v1.3.0 // indirect
+ github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect
+ github.com/mattn/go-isatty v0.0.16 // indirect
+ github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
+ github.com/stretchr/testify v1.8.0 // indirect
+ golang.org/x/mod v0.3.0 // indirect
+ golang.org/x/sys v0.6.0 // indirect
+ golang.org/x/tools v0.0.0-20201124115921-2c860bdd6e78 // indirect
+ golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
+ gopkg.in/yaml.v2 v2.4.0 // indirect
+ lukechampine.com/uint128 v1.2.0 // indirect
+ modernc.org/cc/v3 v3.40.0 // indirect
+ modernc.org/ccgo/v3 v3.16.13 // indirect
+ modernc.org/libc v1.22.3 // indirect
+ modernc.org/mathutil v1.5.0 // indirect
+ modernc.org/memory v1.5.0 // indirect
+ modernc.org/opt v0.1.3 // indirect
+ modernc.org/strutil v1.1.3 // indirect
+ modernc.org/token v1.0.1 // indirect
+)
--- /dev/null
+git.sr.ht/~emersion/go-scfg v0.0.0-20211215104734-c2c7a15d6c99 h1:1s8n5uisqkR+BzPgaum6xxIjKmzGrTykJdh+Y3f5Xao=
+git.sr.ht/~emersion/go-scfg v0.0.0-20211215104734-c2c7a15d6c99/go.mod h1:t+Ww6SR24yYnXzEWiNlOY0AFo5E9B73X++10lrSpp4U=
+git.sr.ht/~sircmpwn/getopt v0.0.0-20191230200459-23622cc906b3/go.mod h1:wMEGFFFNuPos7vHmWXfszqImLppbc0wEhh6JBfJIUgw=
+git.sr.ht/~sircmpwn/go-bare v0.0.0-20210406120253-ab86bc2846d9 h1:Ahny8Ud1LjVMMAlt8utUFKhhxJtwBAualvsbc/Sk7cE=
+git.sr.ht/~sircmpwn/go-bare v0.0.0-20210406120253-ab86bc2846d9/go.mod h1:BVJwbDfVjCjoFiKrhkei6NdGcZYpkDkdyCdg1ukytRA=
+github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
+github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo=
+github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
+github.com/emersion/go-sasl v0.0.0-20220912192320-0145f2c60ead h1:fI1Jck0vUrXT8bnphprS1EoVRe2Q5CKCX8iDlpqjQ/Y=
+github.com/emersion/go-sasl v0.0.0-20220912192320-0145f2c60ead/go.mod h1:iL2twTeMvZnrg54ZoPDNfJaJaqy0xIQFuBdrLsmspwQ=
+github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
+github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ=
+github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4=
+github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ=
+github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
+github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
+github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs=
+github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8=
+github.com/lib/pq v1.10.7 h1:p7ZhMD+KsSRozJr34udlUrhboJwWAgCg34+/ZZNvZZw=
+github.com/lib/pq v1.10.7/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
+github.com/mattn/go-isatty v0.0.16 h1:bq3VjFmv/sOjHtdEhmkEV4x1AJtvUvOJ2PFAZ5+peKQ=
+github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
+github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y=
+github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
+github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
+github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
+github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
+github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
+github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
+github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
+github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
+github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
+github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
+github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk=
+github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
+github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
+golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
+golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
+golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
+golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A=
+golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU=
+golang.org/x/mod v0.3.0 h1:RM4zey1++hCTbCVQfnWeKs9/IEsaBLA8vTkd0WVtmH4=
+golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
+golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
+golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
+golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
+golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
+golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ=
+golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/term v0.6.0 h1:clScbb1cHjoCkyRbWwBEUZ5H/tIFu5TAXIqaZD0Gcjw=
+golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U=
+golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
+golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
+golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4=
+golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
+golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
+golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
+golang.org/x/tools v0.0.0-20201124115921-2c860bdd6e78 h1:M8tBwCtWD/cZV9DZpFYRUgaymAYAr+aIUTWzDaM3uPs=
+golang.org/x/tools v0.0.0-20201124115921-2c860bdd6e78/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
+golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
+golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/irc.v3 v3.1.4 h1:DYGMRFbtseXEh+NadmMUFzMraqyuUj4I3iWYFEzDZPc=
+gopkg.in/irc.v3 v3.1.4/go.mod h1:shO2gz8+PVeS+4E6GAny88Z0YVVQSxQghdrMVGQsR9s=
+gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
+gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
+gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
+gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
+gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
+gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+lukechampine.com/uint128 v1.2.0 h1:mBi/5l91vocEN8otkC5bDLhi2KdCticRiwbdB0O+rjI=
+lukechampine.com/uint128 v1.2.0/go.mod h1:c4eWIwlEGaxC/+H1VguhU4PHXNWDCDMUlWdIWl2j1gk=
+modernc.org/cc/v3 v3.40.0 h1:P3g79IUS/93SYhtoeaHW+kRCIrYaxJ27MFPv+7kaTOw=
+modernc.org/cc/v3 v3.40.0/go.mod h1:/bTg4dnWkSXowUO6ssQKnOV0yMVxDYNIsIrzqTFDGH0=
+modernc.org/ccgo/v3 v3.16.13 h1:Mkgdzl46i5F/CNR/Kj80Ri59hC8TKAhZrYSaqvkwzUw=
+modernc.org/ccgo/v3 v3.16.13/go.mod h1:2Quk+5YgpImhPjv2Qsob1DnZ/4som1lJTodubIcoUkY=
+modernc.org/ccorpus v1.11.6 h1:J16RXiiqiCgua6+ZvQot4yUuUy8zxgqbqEEUuGPlISk=
+modernc.org/httpfs v1.0.6 h1:AAgIpFZRXuYnkjftxTAZwMIiwEqAfk8aVB2/oA6nAeM=
+modernc.org/libc v1.22.3 h1:D/g6O5ftAfavceqlLOFwaZuA5KYafKwmr30A6iSqoyY=
+modernc.org/libc v1.22.3/go.mod h1:MQrloYP209xa2zHome2a8HLiLm6k0UT8CoHpV74tOFw=
+modernc.org/mathutil v1.5.0 h1:rV0Ko/6SfM+8G+yKiyI830l3Wuz1zRutdslNoQ0kfiQ=
+modernc.org/mathutil v1.5.0/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E=
+modernc.org/memory v1.5.0 h1:N+/8c5rE6EqugZwHii4IFsaJ7MUhoWX07J5tC/iI5Ds=
+modernc.org/memory v1.5.0/go.mod h1:PkUhL0Mugw21sHPeskwZW4D6VscE/GQJOnIpCnW6pSU=
+modernc.org/opt v0.1.3 h1:3XOZf2yznlhC+ibLltsDGzABUGVx8J6pnFMS3E4dcq4=
+modernc.org/opt v0.1.3/go.mod h1:WdSiB5evDcignE70guQKxYUl14mgWtbClRi5wmkkTX0=
+modernc.org/sqlite v1.21.0 h1:4aP4MdUf15i3R3M2mx6Q90WHKz3nZLoz96zlB6tNdow=
+modernc.org/sqlite v1.21.0/go.mod h1:XwQ0wZPIh1iKb5mkvCJ3szzbhk+tykC8ZWqTRTgYRwI=
+modernc.org/strutil v1.1.3 h1:fNMm+oJklMGYfU9Ylcywl0CO5O6nTfaowNsh2wpPjzY=
+modernc.org/strutil v1.1.3/go.mod h1:MEHNA7PdEnEwLvspRMtWTNnp2nnyvMfkimT1NKNAGbw=
+modernc.org/tcl v1.15.1 h1:mOQwiEK4p7HruMZcwKTZPw/aqtGM4aY00uzWhlKKYws=
+modernc.org/token v1.0.1 h1:A3qvTqOwexpfZZeyI0FeGPDlSWX5pjZu9hF4lU+EKWg=
+modernc.org/token v1.0.1/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=
+modernc.org/z v1.7.0 h1:xkDw/KepgEjeizO2sNco+hqYkU12taxQFqPEmgm1GWE=
--- /dev/null
+package suika
+
+import (
+ "fmt"
+ "sort"
+ "strings"
+ "time"
+ "unicode"
+ "unicode/utf8"
+
+ "gopkg.in/irc.v3"
+)
+
+const (
+ rpl_statsping = "246"
+ rpl_localusers = "265"
+ rpl_globalusers = "266"
+ rpl_creationtime = "329"
+ rpl_topicwhotime = "333"
+ rpl_whospcrpl = "354"
+ rpl_whoisaccount = "330"
+ err_invalidcapcmd = "410"
+)
+
+const (
+ maxMessageLength = 512
+ maxMessageParams = 15
+ maxSASLLength = 400
+)
+
+// The server-time layout, as defined in the IRCv3 spec.
+const serverTimeLayout = "2006-01-02T15:04:05.000Z"
+
+func formatServerTime(t time.Time) string {
+ return t.UTC().Format(serverTimeLayout)
+}
+
+type userModes string
+
+func (ms userModes) Has(c byte) bool {
+ return strings.IndexByte(string(ms), c) >= 0
+}
+
+func (ms *userModes) Add(c byte) {
+ if !ms.Has(c) {
+ *ms += userModes(c)
+ }
+}
+
+func (ms *userModes) Del(c byte) {
+ i := strings.IndexByte(string(*ms), c)
+ if i >= 0 {
+ *ms = (*ms)[:i] + (*ms)[i+1:]
+ }
+}
+
+func (ms *userModes) Apply(s string) error {
+ var plusMinus byte
+ for i := 0; i < len(s); i++ {
+ switch c := s[i]; c {
+ case '+', '-':
+ plusMinus = c
+ default:
+ switch plusMinus {
+ case '+':
+ ms.Add(c)
+ case '-':
+ ms.Del(c)
+ default:
+ return fmt.Errorf("malformed modestring %q: missing plus/minus", s)
+ }
+ }
+ }
+ return nil
+}
+
+type channelModeType byte
+
+// standard channel mode types, as explained in https://modern.ircdocs.horse/#mode-message
+const (
+ // modes that add or remove an address to or from a list
+ modeTypeA channelModeType = iota
+ // modes that change a setting on a channel, and must always have a parameter
+ modeTypeB
+ // modes that change a setting on a channel, and must have a parameter when being set, and no parameter when being unset
+ modeTypeC
+ // modes that change a setting on a channel, and must not have a parameter
+ modeTypeD
+)
+
+var stdChannelModes = map[byte]channelModeType{
+ 'b': modeTypeA, // ban list
+ 'e': modeTypeA, // ban exception list
+ 'I': modeTypeA, // invite exception list
+ 'k': modeTypeB, // channel key
+ 'l': modeTypeC, // channel user limit
+ 'i': modeTypeD, // channel is invite-only
+ 'm': modeTypeD, // channel is moderated
+ 'n': modeTypeD, // channel has no external messages
+ 's': modeTypeD, // channel is secret
+ 't': modeTypeD, // channel has protected topic
+}
+
+type channelModes map[byte]string
+
+// applyChannelModes parses a mode string and mode arguments from a MODE message,
+// and applies the corresponding channel mode and user membership changes on that channel.
+//
+// If ch.modes is nil, channel modes are not updated.
+//
+// needMarshaling is a list of indexes of mode arguments that represent entities
+// that must be marshaled when sent downstream.
+func applyChannelModes(ch *upstreamChannel, modeStr string, arguments []string) (needMarshaling map[int]struct{}, err error) {
+ needMarshaling = make(map[int]struct{}, len(arguments))
+ nextArgument := 0
+ var plusMinus byte
+outer:
+ for i := 0; i < len(modeStr); i++ {
+ mode := modeStr[i]
+ if mode == '+' || mode == '-' {
+ plusMinus = mode
+ continue
+ }
+ if plusMinus != '+' && plusMinus != '-' {
+ return nil, fmt.Errorf("malformed modestring %q: missing plus/minus", modeStr)
+ }
+
+ for _, membership := range ch.conn.availableMemberships {
+ if membership.Mode == mode {
+ if nextArgument >= len(arguments) {
+ return nil, fmt.Errorf("malformed modestring %q: missing mode argument for %c%c", modeStr, plusMinus, mode)
+ }
+ member := arguments[nextArgument]
+ m := ch.Members.Value(member)
+ if m != nil {
+ if plusMinus == '+' {
+ m.Add(ch.conn.availableMemberships, membership)
+ } else {
+ // TODO: for upstreams without multi-prefix, query the user modes again
+ m.Remove(membership)
+ }
+ }
+ needMarshaling[nextArgument] = struct{}{}
+ nextArgument++
+ continue outer
+ }
+ }
+
+ mt, ok := ch.conn.availableChannelModes[mode]
+ if !ok {
+ continue
+ }
+ if mt == modeTypeA {
+ nextArgument++
+ } else if mt == modeTypeB || (mt == modeTypeC && plusMinus == '+') {
+ if plusMinus == '+' {
+ var argument string
+ // some sentitive arguments (such as channel keys) can be omitted for privacy
+ // (this will only happen for RPL_CHANNELMODEIS, never for MODE messages)
+ if nextArgument < len(arguments) {
+ argument = arguments[nextArgument]
+ }
+ if ch.modes != nil {
+ ch.modes[mode] = argument
+ }
+ } else {
+ delete(ch.modes, mode)
+ }
+ nextArgument++
+ } else if mt == modeTypeC || mt == modeTypeD {
+ if plusMinus == '+' {
+ if ch.modes != nil {
+ ch.modes[mode] = ""
+ }
+ } else {
+ delete(ch.modes, mode)
+ }
+ }
+ }
+ return needMarshaling, nil
+}
+
+func (cm channelModes) Format() (modeString string, parameters []string) {
+ var modesWithValues strings.Builder
+ var modesWithoutValues strings.Builder
+ parameters = make([]string, 0, 16)
+ for mode, value := range cm {
+ if value != "" {
+ modesWithValues.WriteString(string(mode))
+ parameters = append(parameters, value)
+ } else {
+ modesWithoutValues.WriteString(string(mode))
+ }
+ }
+ modeString = "+" + modesWithValues.String() + modesWithoutValues.String()
+ return
+}
+
+const stdChannelTypes = "#&+!"
+
+type channelStatus byte
+
+const (
+ channelPublic channelStatus = '='
+ channelSecret channelStatus = '@'
+ channelPrivate channelStatus = '*'
+)
+
+func parseChannelStatus(s string) (channelStatus, error) {
+ if len(s) > 1 {
+ return 0, fmt.Errorf("invalid channel status %q: more than one character", s)
+ }
+ switch cs := channelStatus(s[0]); cs {
+ case channelPublic, channelSecret, channelPrivate:
+ return cs, nil
+ default:
+ return 0, fmt.Errorf("invalid channel status %q: unknown status", s)
+ }
+}
+
+type membership struct {
+ Mode byte
+ Prefix byte
+}
+
+var stdMemberships = []membership{
+ {'q', '~'}, // founder
+ {'a', '&'}, // protected
+ {'o', '@'}, // operator
+ {'h', '%'}, // halfop
+ {'v', '+'}, // voice
+}
+
+// memberships always sorted by descending membership rank
+type memberships []membership
+
+func (m *memberships) Add(availableMemberships []membership, newMembership membership) {
+ l := *m
+ i := 0
+ for _, availableMembership := range availableMemberships {
+ if i >= len(l) {
+ break
+ }
+ if l[i] == availableMembership {
+ if availableMembership == newMembership {
+ // we already have this membership
+ return
+ }
+ i++
+ continue
+ }
+ if availableMembership == newMembership {
+ break
+ }
+ }
+ // insert newMembership at i
+ l = append(l, membership{})
+ copy(l[i+1:], l[i:])
+ l[i] = newMembership
+ *m = l
+}
+
+func (m *memberships) Remove(oldMembership membership) {
+ l := *m
+ for i, currentMembership := range l {
+ if currentMembership == oldMembership {
+ *m = append(l[:i], l[i+1:]...)
+ return
+ }
+ }
+}
+
+func (m memberships) Format(dc *downstreamConn) string {
+ if !dc.caps["multi-prefix"] {
+ if len(m) == 0 {
+ return ""
+ }
+ return string(m[0].Prefix)
+ }
+ prefixes := make([]byte, len(m))
+ for i, membership := range m {
+ prefixes[i] = membership.Prefix
+ }
+ return string(prefixes)
+}
+
+func parseMessageParams(msg *irc.Message, out ...*string) error {
+ if len(msg.Params) < len(out) {
+ return newNeedMoreParamsError(msg.Command)
+ }
+ for i := range out {
+ if out[i] != nil {
+ *out[i] = msg.Params[i]
+ }
+ }
+ return nil
+}
+
+func copyClientTags(tags irc.Tags) irc.Tags {
+ t := make(irc.Tags, len(tags))
+ for k, v := range tags {
+ if strings.HasPrefix(k, "+") {
+ t[k] = v
+ }
+ }
+ return t
+}
+
+type batch struct {
+ Type string
+ Params []string
+ Outer *batch // if not-nil, this batch is nested in Outer
+ Label string
+}
+
+func join(channels, keys []string) []*irc.Message {
+ // Put channels with a key first
+ js := joinSorter{channels, keys}
+ sort.Sort(&js)
+
+ // Two spaces because there are three words (JOIN, channels and keys)
+ maxLength := maxMessageLength - (len("JOIN") + 2)
+
+ var msgs []*irc.Message
+ var channelsBuf, keysBuf strings.Builder
+ for i, channel := range channels {
+ key := keys[i]
+
+ n := channelsBuf.Len() + keysBuf.Len() + 1 + len(channel)
+ if key != "" {
+ n += 1 + len(key)
+ }
+
+ if channelsBuf.Len() > 0 && n > maxLength {
+ // No room for the new channel in this message
+ params := []string{channelsBuf.String()}
+ if keysBuf.Len() > 0 {
+ params = append(params, keysBuf.String())
+ }
+ msgs = append(msgs, &irc.Message{Command: "JOIN", Params: params})
+ channelsBuf.Reset()
+ keysBuf.Reset()
+ }
+
+ if channelsBuf.Len() > 0 {
+ channelsBuf.WriteByte(',')
+ }
+ channelsBuf.WriteString(channel)
+ if key != "" {
+ if keysBuf.Len() > 0 {
+ keysBuf.WriteByte(',')
+ }
+ keysBuf.WriteString(key)
+ }
+ }
+ if channelsBuf.Len() > 0 {
+ params := []string{channelsBuf.String()}
+ if keysBuf.Len() > 0 {
+ params = append(params, keysBuf.String())
+ }
+ msgs = append(msgs, &irc.Message{Command: "JOIN", Params: params})
+ }
+
+ return msgs
+}
+
+func generateIsupport(prefix *irc.Prefix, nick string, tokens []string) []*irc.Message {
+ maxTokens := maxMessageParams - 2 // 2 reserved params: nick + text
+
+ var msgs []*irc.Message
+ for len(tokens) > 0 {
+ var msgTokens []string
+ if len(tokens) > maxTokens {
+ msgTokens = tokens[:maxTokens]
+ tokens = tokens[maxTokens:]
+ } else {
+ msgTokens = tokens
+ tokens = nil
+ }
+
+ msgs = append(msgs, &irc.Message{
+ Prefix: prefix,
+ Command: irc.RPL_ISUPPORT,
+ Params: append(append([]string{nick}, msgTokens...), "are supported"),
+ })
+ }
+
+ return msgs
+}
+
+func generateMOTD(prefix *irc.Prefix, nick string, motd string) []*irc.Message {
+ var msgs []*irc.Message
+ msgs = append(msgs, &irc.Message{
+ Prefix: prefix,
+ Command: irc.RPL_MOTDSTART,
+ Params: []string{nick, fmt.Sprintf("- Message of the Day -")},
+ })
+
+ for _, l := range strings.Split(motd, "\n") {
+ msgs = append(msgs, &irc.Message{
+ Prefix: prefix,
+ Command: irc.RPL_MOTD,
+ Params: []string{nick, l},
+ })
+ }
+
+ msgs = append(msgs, &irc.Message{
+ Prefix: prefix,
+ Command: irc.RPL_ENDOFMOTD,
+ Params: []string{nick, "End of /MOTD command."},
+ })
+
+ return msgs
+}
+
+func generateMonitor(subcmd string, targets []string) []*irc.Message {
+ maxLength := maxMessageLength - len("MONITOR "+subcmd+" ")
+
+ var msgs []*irc.Message
+ var buf []string
+ n := 0
+ for _, target := range targets {
+ if n+len(target)+1 > maxLength {
+ msgs = append(msgs, &irc.Message{
+ Command: "MONITOR",
+ Params: []string{subcmd, strings.Join(buf, ",")},
+ })
+ buf = buf[:0]
+ n = 0
+ }
+
+ buf = append(buf, target)
+ n += len(target) + 1
+ }
+
+ if len(buf) > 0 {
+ msgs = append(msgs, &irc.Message{
+ Command: "MONITOR",
+ Params: []string{subcmd, strings.Join(buf, ",")},
+ })
+ }
+
+ return msgs
+}
+
+type joinSorter struct {
+ channels []string
+ keys []string
+}
+
+func (js *joinSorter) Len() int {
+ return len(js.channels)
+}
+
+func (js *joinSorter) Less(i, j int) bool {
+ if (js.keys[i] != "") != (js.keys[j] != "") {
+ // Only one of the channels has a key
+ return js.keys[i] != ""
+ }
+ return js.channels[i] < js.channels[j]
+}
+
+func (js *joinSorter) Swap(i, j int) {
+ js.channels[i], js.channels[j] = js.channels[j], js.channels[i]
+ js.keys[i], js.keys[j] = js.keys[j], js.keys[i]
+}
+
+// parseCTCPMessage parses a CTCP message. CTCP is defined in
+// https://tools.ietf.org/html/draft-oakley-irc-ctcp-02
+func parseCTCPMessage(msg *irc.Message) (cmd string, params string, ok bool) {
+ if (msg.Command != "PRIVMSG" && msg.Command != "NOTICE") || len(msg.Params) < 2 {
+ return "", "", false
+ }
+ text := msg.Params[1]
+
+ if !strings.HasPrefix(text, "\x01") {
+ return "", "", false
+ }
+ text = strings.Trim(text, "\x01")
+
+ words := strings.SplitN(text, " ", 2)
+ cmd = strings.ToUpper(words[0])
+ if len(words) > 1 {
+ params = words[1]
+ }
+
+ return cmd, params, true
+}
+
+type casemapping func(string) string
+
+func casemapNone(name string) string {
+ return name
+}
+
+// CasemapASCII of name is the canonical representation of name according to the
+// ascii casemapping.
+func casemapASCII(name string) string {
+ nameBytes := []byte(name)
+ for i, r := range nameBytes {
+ if 'A' <= r && r <= 'Z' {
+ nameBytes[i] = r + 'a' - 'A'
+ }
+ }
+ return string(nameBytes)
+}
+
+// casemapRFC1459 of name is the canonical representation of name according to the
+// rfc1459 casemapping.
+func casemapRFC1459(name string) string {
+ nameBytes := []byte(name)
+ for i, r := range nameBytes {
+ if 'A' <= r && r <= 'Z' {
+ nameBytes[i] = r + 'a' - 'A'
+ } else if r == '{' {
+ nameBytes[i] = '['
+ } else if r == '}' {
+ nameBytes[i] = ']'
+ } else if r == '\\' {
+ nameBytes[i] = '|'
+ } else if r == '~' {
+ nameBytes[i] = '^'
+ }
+ }
+ return string(nameBytes)
+}
+
+// casemapRFC1459Strict of name is the canonical representation of name
+// according to the rfc1459-strict casemapping.
+func casemapRFC1459Strict(name string) string {
+ nameBytes := []byte(name)
+ for i, r := range nameBytes {
+ if 'A' <= r && r <= 'Z' {
+ nameBytes[i] = r + 'a' - 'A'
+ } else if r == '{' {
+ nameBytes[i] = '['
+ } else if r == '}' {
+ nameBytes[i] = ']'
+ } else if r == '\\' {
+ nameBytes[i] = '|'
+ }
+ }
+ return string(nameBytes)
+}
+
+func parseCasemappingToken(tokenValue string) (casemap casemapping, ok bool) {
+ switch tokenValue {
+ case "ascii":
+ casemap = casemapASCII
+ case "rfc1459":
+ casemap = casemapRFC1459
+ case "rfc1459-strict":
+ casemap = casemapRFC1459Strict
+ default:
+ return nil, false
+ }
+ return casemap, true
+}
+
+func partialCasemap(higher casemapping, name string) string {
+ nameFullyCM := []byte(higher(name))
+ nameBytes := []byte(name)
+ for i, r := range nameBytes {
+ if !('A' <= r && r <= 'Z') && !('a' <= r && r <= 'z') {
+ nameBytes[i] = nameFullyCM[i]
+ }
+ }
+ return string(nameBytes)
+}
+
+type casemapMap struct {
+ innerMap map[string]casemapEntry
+ casemap casemapping
+}
+
+type casemapEntry struct {
+ originalKey string
+ value interface{}
+}
+
+func newCasemapMap(size int) casemapMap {
+ return casemapMap{
+ innerMap: make(map[string]casemapEntry, size),
+ casemap: casemapNone,
+ }
+}
+
+func (cm *casemapMap) OriginalKey(name string) (key string, ok bool) {
+ entry, ok := cm.innerMap[cm.casemap(name)]
+ if !ok {
+ return "", false
+ }
+ return entry.originalKey, true
+}
+
+func (cm *casemapMap) Has(name string) bool {
+ _, ok := cm.innerMap[cm.casemap(name)]
+ return ok
+}
+
+func (cm *casemapMap) Len() int {
+ return len(cm.innerMap)
+}
+
+func (cm *casemapMap) SetValue(name string, value interface{}) {
+ nameCM := cm.casemap(name)
+ entry, ok := cm.innerMap[nameCM]
+ if !ok {
+ cm.innerMap[nameCM] = casemapEntry{
+ originalKey: name,
+ value: value,
+ }
+ return
+ }
+ entry.value = value
+ cm.innerMap[nameCM] = entry
+}
+
+func (cm *casemapMap) Delete(name string) {
+ delete(cm.innerMap, cm.casemap(name))
+}
+
+func (cm *casemapMap) SetCasemapping(newCasemap casemapping) {
+ cm.casemap = newCasemap
+ newInnerMap := make(map[string]casemapEntry, len(cm.innerMap))
+ for _, entry := range cm.innerMap {
+ newInnerMap[cm.casemap(entry.originalKey)] = entry
+ }
+ cm.innerMap = newInnerMap
+}
+
+type upstreamChannelCasemapMap struct{ casemapMap }
+
+func (cm *upstreamChannelCasemapMap) Value(name string) *upstreamChannel {
+ entry, ok := cm.innerMap[cm.casemap(name)]
+ if !ok {
+ return nil
+ }
+ return entry.value.(*upstreamChannel)
+}
+
+type channelCasemapMap struct{ casemapMap }
+
+func (cm *channelCasemapMap) Value(name string) *Channel {
+ entry, ok := cm.innerMap[cm.casemap(name)]
+ if !ok {
+ return nil
+ }
+ return entry.value.(*Channel)
+}
+
+type membershipsCasemapMap struct{ casemapMap }
+
+func (cm *membershipsCasemapMap) Value(name string) *memberships {
+ entry, ok := cm.innerMap[cm.casemap(name)]
+ if !ok {
+ return nil
+ }
+ return entry.value.(*memberships)
+}
+
+type deliveredCasemapMap struct{ casemapMap }
+
+func (cm *deliveredCasemapMap) Value(name string) deliveredClientMap {
+ entry, ok := cm.innerMap[cm.casemap(name)]
+ if !ok {
+ return nil
+ }
+ return entry.value.(deliveredClientMap)
+}
+
+type monitorCasemapMap struct{ casemapMap }
+
+func (cm *monitorCasemapMap) Value(name string) (online bool) {
+ entry, ok := cm.innerMap[cm.casemap(name)]
+ if !ok {
+ return false
+ }
+ return entry.value.(bool)
+}
+
+func isWordBoundary(r rune) bool {
+ switch r {
+ case '-', '_', '|': // inspired from weechat.look.highlight_regex
+ return false
+ default:
+ return !unicode.IsLetter(r) && !unicode.IsNumber(r)
+ }
+}
+
+func isHighlight(text, nick string) bool {
+ for {
+ i := strings.Index(text, nick)
+ if i < 0 {
+ return false
+ }
+
+ left, _ := utf8.DecodeLastRuneInString(text[:i])
+ right, _ := utf8.DecodeRuneInString(text[i+len(nick):])
+ if isWordBoundary(left) && isWordBoundary(right) {
+ return true
+ }
+
+ text = text[i+len(nick):]
+ }
+}
+
+// parseChatHistoryBound parses the given CHATHISTORY parameter as a bound.
+// The zero time is returned on error.
+func parseChatHistoryBound(param string) time.Time {
+ parts := strings.SplitN(param, "=", 2)
+ if len(parts) != 2 {
+ return time.Time{}
+ }
+ switch parts[0] {
+ case "timestamp":
+ timestamp, err := time.Parse(serverTimeLayout, parts[1])
+ if err != nil {
+ return time.Time{}
+ }
+ return timestamp
+ default:
+ return time.Time{}
+ }
+}
+
+// whoxFields is the list of all WHOX field letters, by order of appearance in
+// RPL_WHOSPCRPL messages.
+var whoxFields = []byte("tcuihsnfdlaor")
+
+type whoxInfo struct {
+ Token string
+ Username string
+ Hostname string
+ Server string
+ Nickname string
+ Flags string
+ Account string
+ Realname string
+}
+
+func (info *whoxInfo) get(field byte) string {
+ switch field {
+ case 't':
+ return info.Token
+ case 'c':
+ return "*"
+ case 'u':
+ return info.Username
+ case 'i':
+ return "255.255.255.255"
+ case 'h':
+ return info.Hostname
+ case 's':
+ return info.Server
+ case 'n':
+ return info.Nickname
+ case 'f':
+ return info.Flags
+ case 'd':
+ return "0"
+ case 'l': // idle time
+ return "0"
+ case 'a':
+ account := "0" // WHOX uses "0" to mean "no account"
+ if info.Account != "" && info.Account != "*" {
+ account = info.Account
+ }
+ return account
+ case 'o':
+ return "0"
+ case 'r':
+ return info.Realname
+ }
+ return ""
+}
+
+func generateWHOXReply(prefix *irc.Prefix, nick, fields string, info *whoxInfo) *irc.Message {
+ if fields == "" {
+ return &irc.Message{
+ Prefix: prefix,
+ Command: irc.RPL_WHOREPLY,
+ Params: []string{nick, "*", info.Username, info.Hostname, info.Server, info.Nickname, info.Flags, "0 " + info.Realname},
+ }
+ }
+
+ fieldSet := make(map[byte]bool)
+ for i := 0; i < len(fields); i++ {
+ fieldSet[fields[i]] = true
+ }
+
+ var values []string
+ for _, field := range whoxFields {
+ if !fieldSet[field] {
+ continue
+ }
+ values = append(values, info.get(field))
+ }
+
+ return &irc.Message{
+ Prefix: prefix,
+ Command: rpl_whospcrpl,
+ Params: append([]string{nick}, values...),
+ }
+}
+
+var isupportEncoder = strings.NewReplacer(" ", "\\x20", "\\", "\\x5C")
+
+func encodeISUPPORT(s string) string {
+ return isupportEncoder.Replace(s)
+}
--- /dev/null
+package suika
+
+import (
+ "testing"
+)
+
+func TestIsHighlight(t *testing.T) {
+ nick := "SojuUser"
+ testCases := []struct {
+ name string
+ text string
+ hl bool
+ }{
+ {"noContains", "hi there Soju User!", false},
+ {"middle", "hi there SojuUser!", true},
+ {"start", "SojuUser: how are you doing?", true},
+ {"end", "maybe ask SojuUser", true},
+ {"inWord", "but OtherSojuUserSan is a different nick", false},
+ {"startWord", "and OtherSojuUser is another different nick", false},
+ {"endWord", "and SojuUserSan is yet a different nick", false},
+ {"underscore", "and SojuUser_san has nothing to do with me", false},
+ {"zeroWidthSpace", "writing S\u200BojuUser shouldn't trigger a highlight", false},
+ }
+
+ for _, tc := range testCases {
+ tc := tc // capture range variable
+ t.Run(tc.name, func(t *testing.T) {
+ hl := isHighlight(tc.text, nick)
+ if hl != tc.hl {
+ t.Errorf("isHighlight(%q, %q) = %v, but want %v", tc.text, nick, hl, tc.hl)
+ }
+ })
+ }
+}
--- /dev/null
+package suika
+
+import (
+ "bytes"
+ "context"
+ "encoding/base64"
+ "fmt"
+ "time"
+
+ "git.sr.ht/~sircmpwn/go-bare"
+ "gopkg.in/irc.v3"
+)
+
+// messageStore is a per-user store for IRC messages.
+type messageStore interface {
+ Close() error
+ // LastMsgID queries the last message ID for the given network, entity and
+ // date. The message ID returned may not refer to a valid message, but can be
+ // used in history queries.
+ LastMsgID(network *Network, entity string, t time.Time) (string, error)
+ // LoadLatestID queries the latest non-event messages for the given network,
+ // entity and date, up to a count of limit messages, sorted from oldest to newest.
+ LoadLatestID(ctx context.Context, network *Network, entity, id string, limit int) ([]*irc.Message, error)
+ Append(network *Network, entity string, msg *irc.Message) (id string, err error)
+}
+
+type chatHistoryTarget struct {
+ Name string
+ LatestMessage time.Time
+}
+
+// chatHistoryMessageStore is a message store that supports chat history
+// operations.
+type chatHistoryMessageStore interface {
+ messageStore
+
+ // ListTargets lists channels and nicknames by time of the latest message.
+ // It returns up to limit targets, starting from start and ending on end,
+ // both excluded. end may be before or after start.
+ // If events is false, only PRIVMSG/NOTICE messages are considered.
+ ListTargets(ctx context.Context, network *Network, start, end time.Time, limit int, events bool) ([]chatHistoryTarget, error)
+ // LoadBeforeTime loads up to limit messages before start down to end. The
+ // returned messages must be between and excluding the provided bounds.
+ // end is before start.
+ // If events is false, only PRIVMSG/NOTICE messages are considered.
+ LoadBeforeTime(ctx context.Context, network *Network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error)
+ // LoadBeforeTime loads up to limit messages after start up to end. The
+ // returned messages must be between and excluding the provided bounds.
+ // end is after start.
+ // If events is false, only PRIVMSG/NOTICE messages are considered.
+ LoadAfterTime(ctx context.Context, network *Network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error)
+}
+
+type msgIDType uint
+
+const (
+ msgIDNone msgIDType = iota
+ msgIDMemory
+ msgIDFS
+)
+
+const msgIDVersion uint = 0
+
+type msgIDHeader struct {
+ Version uint
+ Network bare.Int
+ Target string
+ Type msgIDType
+}
+
+type msgIDBody interface {
+ msgIDType() msgIDType
+}
+
+func formatMsgID(netID int64, target string, body msgIDBody) string {
+ var buf bytes.Buffer
+ w := bare.NewWriter(&buf)
+
+ header := msgIDHeader{
+ Version: msgIDVersion,
+ Network: bare.Int(netID),
+ Target: target,
+ Type: body.msgIDType(),
+ }
+ if err := bare.MarshalWriter(w, &header); err != nil {
+ panic(err)
+ }
+ if err := bare.MarshalWriter(w, body); err != nil {
+ panic(err)
+ }
+ return base64.RawURLEncoding.EncodeToString(buf.Bytes())
+}
+
+func parseMsgID(s string, body msgIDBody) (netID int64, target string, err error) {
+ b, err := base64.RawURLEncoding.DecodeString(s)
+ if err != nil {
+ return 0, "", fmt.Errorf("invalid internal message ID: %v", err)
+ }
+
+ r := bare.NewReader(bytes.NewReader(b))
+
+ var header msgIDHeader
+ if err := bare.UnmarshalBareReader(r, &header); err != nil {
+ return 0, "", fmt.Errorf("invalid internal message ID: %v", err)
+ }
+
+ if header.Version != msgIDVersion {
+ return 0, "", fmt.Errorf("invalid internal message ID: got version %v, want %v", header.Version, msgIDVersion)
+ }
+
+ if body != nil {
+ typ := body.msgIDType()
+ if header.Type != typ {
+ return 0, "", fmt.Errorf("invalid internal message ID: got type %v, want %v", header.Type, typ)
+ }
+
+ if err := bare.UnmarshalBareReader(r, body); err != nil {
+ return 0, "", fmt.Errorf("invalid internal message ID: %v", err)
+ }
+ }
+
+ return int64(header.Network), header.Target, nil
+}
--- /dev/null
+package suika
+
+import (
+ "bufio"
+ "context"
+ "fmt"
+ "io"
+ "os"
+ "path/filepath"
+ "sort"
+ "strings"
+ "time"
+
+ "git.sr.ht/~sircmpwn/go-bare"
+ "gopkg.in/irc.v3"
+)
+
+const (
+ fsMessageStoreMaxFiles = 20
+ fsMessageStoreMaxTries = 100
+)
+
+func escapeFilename(unsafe string) (safe string) {
+ if unsafe == "." {
+ return "-"
+ } else if unsafe == ".." {
+ return "--"
+ } else {
+ return strings.NewReplacer("/", "-", "\\", "-").Replace(unsafe)
+ }
+}
+
+type date struct {
+ Year, Month, Day int
+}
+
+func newDate(t time.Time) date {
+ year, month, day := t.Date()
+ return date{year, int(month), day}
+}
+
+func (d date) Time() time.Time {
+ return time.Date(d.Year, time.Month(d.Month), d.Day, 0, 0, 0, 0, time.Local)
+}
+
+type fsMsgID struct {
+ Date date
+ Offset bare.Int
+}
+
+func (fsMsgID) msgIDType() msgIDType {
+ return msgIDFS
+}
+
+func parseFSMsgID(s string) (netID int64, entity string, t time.Time, offset int64, err error) {
+ var id fsMsgID
+ netID, entity, err = parseMsgID(s, &id)
+ if err != nil {
+ return 0, "", time.Time{}, 0, err
+ }
+ return netID, entity, id.Date.Time(), int64(id.Offset), nil
+}
+
+func formatFSMsgID(netID int64, entity string, t time.Time, offset int64) string {
+ id := fsMsgID{
+ Date: newDate(t),
+ Offset: bare.Int(offset),
+ }
+ return formatMsgID(netID, entity, &id)
+}
+
+type fsMessageStoreFile struct {
+ *os.File
+ lastUse time.Time
+}
+
+// fsMessageStore is a per-user on-disk store for IRC messages.
+//
+// It mimicks the ZNC log layout and format. See the ZNC source:
+// https://github.com/znc/znc/blob/master/modules/log.cpp
+type fsMessageStore struct {
+ root string
+ user *User
+
+ // Write-only files used by Append
+ files map[string]*fsMessageStoreFile // indexed by entity
+}
+
+var _ messageStore = (*fsMessageStore)(nil)
+var _ chatHistoryMessageStore = (*fsMessageStore)(nil)
+
+func newFSMessageStore(root string, user *User) *fsMessageStore {
+ return &fsMessageStore{
+ root: filepath.Join(root, escapeFilename(user.Username)),
+ user: user,
+ files: make(map[string]*fsMessageStoreFile),
+ }
+}
+
+func (ms *fsMessageStore) logPath(network *Network, entity string, t time.Time) string {
+ year, month, day := t.Date()
+ filename := fmt.Sprintf("%04d-%02d-%02d.log", year, month, day)
+ return filepath.Join(ms.root, escapeFilename(network.GetName()), escapeFilename(entity), filename)
+}
+
+// nextMsgID queries the message ID for the next message to be written to f.
+func nextFSMsgID(network *Network, entity string, t time.Time, f *os.File) (string, error) {
+ offset, err := f.Seek(0, io.SeekEnd)
+ if err != nil {
+ return "", fmt.Errorf("failed to query next FS message ID: %v", err)
+ }
+ return formatFSMsgID(network.ID, entity, t, offset), nil
+}
+
+func (ms *fsMessageStore) LastMsgID(network *Network, entity string, t time.Time) (string, error) {
+ p := ms.logPath(network, entity, t)
+ fi, err := os.Stat(p)
+ if os.IsNotExist(err) {
+ return formatFSMsgID(network.ID, entity, t, -1), nil
+ } else if err != nil {
+ return "", fmt.Errorf("failed to query last FS message ID: %v", err)
+ }
+ return formatFSMsgID(network.ID, entity, t, fi.Size()-1), nil
+}
+
+func (ms *fsMessageStore) Append(network *Network, entity string, msg *irc.Message) (string, error) {
+ s := formatMessage(msg)
+ if s == "" {
+ return "", nil
+ }
+
+ var t time.Time
+ if tag, ok := msg.Tags["time"]; ok {
+ var err error
+ t, err = time.Parse(serverTimeLayout, string(tag))
+ if err != nil {
+ return "", fmt.Errorf("failed to parse message time tag: %v", err)
+ }
+ t = t.In(time.Local)
+ } else {
+ t = time.Now()
+ }
+
+ f := ms.files[entity]
+
+ // TODO: handle non-monotonic clock behaviour
+ path := ms.logPath(network, entity, t)
+ if f == nil || f.Name() != path {
+ dir := filepath.Dir(path)
+ if err := os.MkdirAll(dir, 0750); err != nil {
+ return "", fmt.Errorf("failed to create message logs directory %q: %v", dir, err)
+ }
+
+ ff, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0640)
+ if err != nil {
+ return "", fmt.Errorf("failed to open message log file %q: %v", path, err)
+ }
+
+ if f != nil {
+ f.Close()
+ }
+ f = &fsMessageStoreFile{File: ff}
+ ms.files[entity] = f
+ }
+
+ f.lastUse = time.Now()
+
+ if len(ms.files) > fsMessageStoreMaxFiles {
+ entities := make([]string, 0, len(ms.files))
+ for name := range ms.files {
+ entities = append(entities, name)
+ }
+ sort.Slice(entities, func(i, j int) bool {
+ a, b := entities[i], entities[j]
+ return ms.files[a].lastUse.Before(ms.files[b].lastUse)
+ })
+ entities = entities[0 : len(entities)-fsMessageStoreMaxFiles]
+ for _, name := range entities {
+ ms.files[name].Close()
+ delete(ms.files, name)
+ }
+ }
+
+ msgID, err := nextFSMsgID(network, entity, t, f.File)
+ if err != nil {
+ return "", fmt.Errorf("failed to generate message ID: %v", err)
+ }
+
+ _, err = fmt.Fprintf(f, "[%02d:%02d:%02d] %s\n", t.Hour(), t.Minute(), t.Second(), s)
+ if err != nil {
+ return "", fmt.Errorf("failed to log message to %q: %v", f.Name(), err)
+ }
+
+ return msgID, nil
+}
+
+func (ms *fsMessageStore) Close() error {
+ var closeErr error
+ for _, f := range ms.files {
+ if err := f.Close(); err != nil {
+ closeErr = fmt.Errorf("failed to close message store: %v", err)
+ }
+ }
+ return closeErr
+}
+
+// formatMessage formats a message log line. It assumes a well-formed IRC
+// message.
+func formatMessage(msg *irc.Message) string {
+ switch strings.ToUpper(msg.Command) {
+ case "NICK":
+ return fmt.Sprintf("*** %s is now known as %s", msg.Prefix.Name, msg.Params[0])
+ case "JOIN":
+ return fmt.Sprintf("*** Joins: %s (%s@%s)", msg.Prefix.Name, msg.Prefix.User, msg.Prefix.Host)
+ case "PART":
+ var reason string
+ if len(msg.Params) > 1 {
+ reason = msg.Params[1]
+ }
+ return fmt.Sprintf("*** Parts: %s (%s@%s) (%s)", msg.Prefix.Name, msg.Prefix.User, msg.Prefix.Host, reason)
+ case "KICK":
+ nick := msg.Params[1]
+ var reason string
+ if len(msg.Params) > 2 {
+ reason = msg.Params[2]
+ }
+ return fmt.Sprintf("*** %s was kicked by %s (%s)", nick, msg.Prefix.Name, reason)
+ case "QUIT":
+ var reason string
+ if len(msg.Params) > 0 {
+ reason = msg.Params[0]
+ }
+ return fmt.Sprintf("*** Quits: %s (%s@%s) (%s)", msg.Prefix.Name, msg.Prefix.User, msg.Prefix.Host, reason)
+ case "TOPIC":
+ var topic string
+ if len(msg.Params) > 1 {
+ topic = msg.Params[1]
+ }
+ return fmt.Sprintf("*** %s changes topic to '%s'", msg.Prefix.Name, topic)
+ case "MODE":
+ return fmt.Sprintf("*** %s sets mode: %s", msg.Prefix.Name, strings.Join(msg.Params[1:], " "))
+ case "NOTICE":
+ return fmt.Sprintf("-%s- %s", msg.Prefix.Name, msg.Params[1])
+ case "PRIVMSG":
+ if cmd, params, ok := parseCTCPMessage(msg); ok && cmd == "ACTION" {
+ return fmt.Sprintf("* %s %s", msg.Prefix.Name, params)
+ } else {
+ return fmt.Sprintf("<%s> %s", msg.Prefix.Name, msg.Params[1])
+ }
+ default:
+ return ""
+ }
+}
+
+func (ms *fsMessageStore) parseMessage(line string, network *Network, entity string, ref time.Time, events bool) (*irc.Message, time.Time, error) {
+ var hour, minute, second int
+ _, err := fmt.Sscanf(line, "[%02d:%02d:%02d] ", &hour, &minute, &second)
+ if err != nil {
+ return nil, time.Time{}, fmt.Errorf("malformed timestamp prefix: %v", err)
+ }
+ line = line[11:]
+
+ var cmd string
+ var prefix *irc.Prefix
+ var params []string
+ if events && strings.HasPrefix(line, "*** ") {
+ parts := strings.SplitN(line[4:], " ", 2)
+ if len(parts) != 2 {
+ return nil, time.Time{}, nil
+ }
+ switch parts[0] {
+ case "Joins:", "Parts:", "Quits:":
+ args := strings.SplitN(parts[1], " ", 3)
+ if len(args) < 2 {
+ return nil, time.Time{}, nil
+ }
+ nick := args[0]
+ mask := strings.TrimSuffix(strings.TrimPrefix(args[1], "("), ")")
+ maskParts := strings.SplitN(mask, "@", 2)
+ if len(maskParts) != 2 {
+ return nil, time.Time{}, nil
+ }
+ prefix = &irc.Prefix{
+ Name: nick,
+ User: maskParts[0],
+ Host: maskParts[1],
+ }
+ var reason string
+ if len(args) > 2 {
+ reason = strings.TrimSuffix(strings.TrimPrefix(args[2], "("), ")")
+ }
+ switch parts[0] {
+ case "Joins:":
+ cmd = "JOIN"
+ params = []string{entity}
+ case "Parts:":
+ cmd = "PART"
+ if reason != "" {
+ params = []string{entity, reason}
+ } else {
+ params = []string{entity}
+ }
+ case "Quits:":
+ cmd = "QUIT"
+ if reason != "" {
+ params = []string{reason}
+ }
+ }
+ default:
+ nick := parts[0]
+ rem := parts[1]
+ if r := strings.TrimPrefix(rem, "is now known as "); r != rem {
+ cmd = "NICK"
+ prefix = &irc.Prefix{
+ Name: nick,
+ }
+ params = []string{r}
+ } else if r := strings.TrimPrefix(rem, "was kicked by "); r != rem {
+ args := strings.SplitN(r, " ", 2)
+ if len(args) != 2 {
+ return nil, time.Time{}, nil
+ }
+ cmd = "KICK"
+ prefix = &irc.Prefix{
+ Name: args[0],
+ }
+ reason := strings.TrimSuffix(strings.TrimPrefix(args[1], "("), ")")
+ params = []string{entity, nick}
+ if reason != "" {
+ params = append(params, reason)
+ }
+ } else if r := strings.TrimPrefix(rem, "changes topic to "); r != rem {
+ cmd = "TOPIC"
+ prefix = &irc.Prefix{
+ Name: nick,
+ }
+ topic := strings.TrimSuffix(strings.TrimPrefix(r, "'"), "'")
+ params = []string{entity, topic}
+ } else if r := strings.TrimPrefix(rem, "sets mode: "); r != rem {
+ cmd = "MODE"
+ prefix = &irc.Prefix{
+ Name: nick,
+ }
+ params = append([]string{entity}, strings.Split(r, " ")...)
+ } else {
+ return nil, time.Time{}, nil
+ }
+ }
+ } else {
+ var sender, text string
+ if strings.HasPrefix(line, "<") {
+ cmd = "PRIVMSG"
+ parts := strings.SplitN(line[1:], "> ", 2)
+ if len(parts) != 2 {
+ return nil, time.Time{}, nil
+ }
+ sender, text = parts[0], parts[1]
+ } else if strings.HasPrefix(line, "-") {
+ cmd = "NOTICE"
+ parts := strings.SplitN(line[1:], "- ", 2)
+ if len(parts) != 2 {
+ return nil, time.Time{}, nil
+ }
+ sender, text = parts[0], parts[1]
+ } else if strings.HasPrefix(line, "* ") {
+ cmd = "PRIVMSG"
+ parts := strings.SplitN(line[2:], " ", 2)
+ if len(parts) != 2 {
+ return nil, time.Time{}, nil
+ }
+ sender, text = parts[0], "\x01ACTION "+parts[1]+"\x01"
+ } else {
+ return nil, time.Time{}, nil
+ }
+
+ prefix = &irc.Prefix{Name: sender}
+ if entity == sender {
+ // This is a direct message from a user to us. We don't store own
+ // our nickname in the logs, so grab it from the network settings.
+ // Not very accurate since this may not match our nick at the time
+ // the message was received, but we can't do a lot better.
+ entity = GetNick(ms.user, network)
+ }
+ params = []string{entity, text}
+ }
+
+ year, month, day := ref.Date()
+ t := time.Date(year, month, day, hour, minute, second, 0, time.Local)
+
+ msg := &irc.Message{
+ Tags: map[string]irc.TagValue{
+ "time": irc.TagValue(formatServerTime(t)),
+ },
+ Prefix: prefix,
+ Command: cmd,
+ Params: params,
+ }
+ return msg, t, nil
+}
+
+func (ms *fsMessageStore) parseMessagesBefore(network *Network, entity string, ref time.Time, end time.Time, events bool, limit int, afterOffset int64) ([]*irc.Message, error) {
+ path := ms.logPath(network, entity, ref)
+ f, err := os.Open(path)
+ if err != nil {
+ if os.IsNotExist(err) {
+ return nil, nil
+ }
+ return nil, fmt.Errorf("failed to parse messages before ref: %v", err)
+ }
+ defer f.Close()
+
+ historyRing := make([]*irc.Message, limit)
+ cur := 0
+
+ sc := bufio.NewScanner(f)
+
+ if afterOffset >= 0 {
+ if _, err := f.Seek(afterOffset, io.SeekStart); err != nil {
+ return nil, nil
+ }
+ sc.Scan() // skip till next newline
+ }
+
+ for sc.Scan() {
+ msg, t, err := ms.parseMessage(sc.Text(), network, entity, ref, events)
+ if err != nil {
+ return nil, err
+ } else if msg == nil || !t.After(end) {
+ continue
+ } else if !t.Before(ref) {
+ break
+ }
+
+ historyRing[cur%limit] = msg
+ cur++
+ }
+ if sc.Err() != nil {
+ return nil, fmt.Errorf("failed to parse messages before ref: scanner error: %v", sc.Err())
+ }
+
+ n := limit
+ if cur < limit {
+ n = cur
+ }
+ start := (cur - n + limit) % limit
+
+ if start+n <= limit { // ring doesnt wrap
+ return historyRing[start : start+n], nil
+ } else { // ring wraps
+ history := make([]*irc.Message, n)
+ r := copy(history, historyRing[start:])
+ copy(history[r:], historyRing[:n-r])
+ return history, nil
+ }
+}
+
+func (ms *fsMessageStore) parseMessagesAfter(network *Network, entity string, ref time.Time, end time.Time, events bool, limit int) ([]*irc.Message, error) {
+ path := ms.logPath(network, entity, ref)
+ f, err := os.Open(path)
+ if err != nil {
+ if os.IsNotExist(err) {
+ return nil, nil
+ }
+ return nil, fmt.Errorf("failed to parse messages after ref: %v", err)
+ }
+ defer f.Close()
+
+ var history []*irc.Message
+ sc := bufio.NewScanner(f)
+ for sc.Scan() && len(history) < limit {
+ msg, t, err := ms.parseMessage(sc.Text(), network, entity, ref, events)
+ if err != nil {
+ return nil, err
+ } else if msg == nil || !t.After(ref) {
+ continue
+ } else if !t.Before(end) {
+ break
+ }
+
+ history = append(history, msg)
+ }
+ if sc.Err() != nil {
+ return nil, fmt.Errorf("failed to parse messages after ref: scanner error: %v", sc.Err())
+ }
+
+ return history, nil
+}
+
+func (ms *fsMessageStore) LoadBeforeTime(ctx context.Context, network *Network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) {
+ start = start.In(time.Local)
+ end = end.In(time.Local)
+ history := make([]*irc.Message, limit)
+ remaining := limit
+ tries := 0
+ for remaining > 0 && tries < fsMessageStoreMaxTries && end.Before(start) {
+ buf, err := ms.parseMessagesBefore(network, entity, start, end, events, remaining, -1)
+ if err != nil {
+ return nil, err
+ }
+ if len(buf) == 0 {
+ tries++
+ } else {
+ tries = 0
+ }
+ copy(history[remaining-len(buf):], buf)
+ remaining -= len(buf)
+ year, month, day := start.Date()
+ start = time.Date(year, month, day, 0, 0, 0, 0, start.Location()).Add(-1)
+
+ if err := ctx.Err(); err != nil {
+ return nil, err
+ }
+ }
+
+ return history[remaining:], nil
+}
+
+func (ms *fsMessageStore) LoadAfterTime(ctx context.Context, network *Network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) {
+ start = start.In(time.Local)
+ end = end.In(time.Local)
+ var history []*irc.Message
+ remaining := limit
+ tries := 0
+ for remaining > 0 && tries < fsMessageStoreMaxTries && start.Before(end) {
+ buf, err := ms.parseMessagesAfter(network, entity, start, end, events, remaining)
+ if err != nil {
+ return nil, err
+ }
+ if len(buf) == 0 {
+ tries++
+ } else {
+ tries = 0
+ }
+ history = append(history, buf...)
+ remaining -= len(buf)
+ year, month, day := start.Date()
+ start = time.Date(year, month, day+1, 0, 0, 0, 0, start.Location())
+
+ if err := ctx.Err(); err != nil {
+ return nil, err
+ }
+ }
+ return history, nil
+}
+
+func (ms *fsMessageStore) LoadLatestID(ctx context.Context, network *Network, entity, id string, limit int) ([]*irc.Message, error) {
+ var afterTime time.Time
+ var afterOffset int64
+ if id != "" {
+ var idNet int64
+ var idEntity string
+ var err error
+ idNet, idEntity, afterTime, afterOffset, err = parseFSMsgID(id)
+ if err != nil {
+ return nil, err
+ }
+ if idNet != network.ID || idEntity != entity {
+ return nil, fmt.Errorf("cannot find message ID: message ID doesn't match network/entity")
+ }
+ }
+
+ history := make([]*irc.Message, limit)
+ t := time.Now()
+ remaining := limit
+ tries := 0
+ for remaining > 0 && tries < fsMessageStoreMaxTries && !truncateDay(t).Before(afterTime) {
+ var offset int64 = -1
+ if afterOffset >= 0 && truncateDay(t).Equal(afterTime) {
+ offset = afterOffset
+ }
+
+ buf, err := ms.parseMessagesBefore(network, entity, t, time.Time{}, false, remaining, offset)
+ if err != nil {
+ return nil, err
+ }
+ if len(buf) == 0 {
+ tries++
+ } else {
+ tries = 0
+ }
+ copy(history[remaining-len(buf):], buf)
+ remaining -= len(buf)
+ year, month, day := t.Date()
+ t = time.Date(year, month, day, 0, 0, 0, 0, t.Location()).Add(-1)
+
+ if err := ctx.Err(); err != nil {
+ return nil, err
+ }
+ }
+
+ return history[remaining:], nil
+}
+
+func (ms *fsMessageStore) ListTargets(ctx context.Context, network *Network, start, end time.Time, limit int, events bool) ([]chatHistoryTarget, error) {
+ start = start.In(time.Local)
+ end = end.In(time.Local)
+ rootPath := filepath.Join(ms.root, escapeFilename(network.GetName()))
+ root, err := os.Open(rootPath)
+ if os.IsNotExist(err) {
+ return nil, nil
+ } else if err != nil {
+ return nil, err
+ }
+
+ // The returned targets are escaped, and there is no way to un-escape
+ // TODO: switch to ReadDir (Go 1.16+)
+ targetNames, err := root.Readdirnames(0)
+ root.Close()
+ if err != nil {
+ return nil, err
+ }
+
+ var targets []chatHistoryTarget
+ for _, target := range targetNames {
+ // target is already escaped here
+ targetPath := filepath.Join(rootPath, target)
+ targetDir, err := os.Open(targetPath)
+ if err != nil {
+ return nil, err
+ }
+
+ entries, err := targetDir.Readdir(0)
+ targetDir.Close()
+ if err != nil {
+ return nil, err
+ }
+
+ // We use mtime here, which may give imprecise or incorrect results
+ var t time.Time
+ for _, entry := range entries {
+ if entry.ModTime().After(t) {
+ t = entry.ModTime()
+ }
+ }
+
+ // The timestamps we get from logs have second granularity
+ t = truncateSecond(t)
+
+ // Filter out targets that don't fullfil the time bounds
+ if !isTimeBetween(t, start, end) {
+ continue
+ }
+
+ targets = append(targets, chatHistoryTarget{
+ Name: target,
+ LatestMessage: t,
+ })
+
+ if err := ctx.Err(); err != nil {
+ return nil, err
+ }
+ }
+
+ // Sort targets by latest message time, backwards or forwards depending on
+ // the order of the time bounds
+ sort.Slice(targets, func(i, j int) bool {
+ t1, t2 := targets[i].LatestMessage, targets[j].LatestMessage
+ if start.Before(end) {
+ return t1.Before(t2)
+ } else {
+ return !t1.Before(t2)
+ }
+ })
+
+ // Truncate the result if necessary
+ if len(targets) > limit {
+ targets = targets[:limit]
+ }
+
+ return targets, nil
+}
+
+func (ms *fsMessageStore) RenameNetwork(oldNet, newNet *Network) error {
+ oldDir := filepath.Join(ms.root, escapeFilename(oldNet.GetName()))
+ newDir := filepath.Join(ms.root, escapeFilename(newNet.GetName()))
+ // Avoid loosing data by overwriting an existing directory
+ if _, err := os.Stat(newDir); err == nil {
+ return fmt.Errorf("destination %q already exists", newDir)
+ }
+ return os.Rename(oldDir, newDir)
+}
+
+func truncateDay(t time.Time) time.Time {
+ year, month, day := t.Date()
+ return time.Date(year, month, day, 0, 0, 0, 0, t.Location())
+}
+
+func truncateSecond(t time.Time) time.Time {
+ year, month, day := t.Date()
+ return time.Date(year, month, day, t.Hour(), t.Minute(), t.Second(), 0, t.Location())
+}
+
+func isTimeBetween(t, start, end time.Time) bool {
+ if end.Before(start) {
+ end, start = start, end
+ }
+ return start.Before(t) && t.Before(end)
+}
--- /dev/null
+package suika
+
+import (
+ "context"
+ "fmt"
+ "time"
+
+ "git.sr.ht/~sircmpwn/go-bare"
+ "gopkg.in/irc.v3"
+)
+
+const messageRingBufferCap = 4096
+
+type memoryMsgID struct {
+ Seq bare.Uint
+}
+
+func (memoryMsgID) msgIDType() msgIDType {
+ return msgIDMemory
+}
+
+func parseMemoryMsgID(s string) (netID int64, entity string, seq uint64, err error) {
+ var id memoryMsgID
+ netID, entity, err = parseMsgID(s, &id)
+ if err != nil {
+ return 0, "", 0, err
+ }
+ return netID, entity, uint64(id.Seq), nil
+}
+
+func formatMemoryMsgID(netID int64, entity string, seq uint64) string {
+ id := memoryMsgID{bare.Uint(seq)}
+ return formatMsgID(netID, entity, &id)
+}
+
+type ringBufferKey struct {
+ networkID int64
+ entity string
+}
+
+type memoryMessageStore struct {
+ buffers map[ringBufferKey]*messageRingBuffer
+}
+
+var _ messageStore = (*memoryMessageStore)(nil)
+
+func newMemoryMessageStore() *memoryMessageStore {
+ return &memoryMessageStore{
+ buffers: make(map[ringBufferKey]*messageRingBuffer),
+ }
+}
+
+func (ms *memoryMessageStore) Close() error {
+ ms.buffers = nil
+ return nil
+}
+
+func (ms *memoryMessageStore) get(network *Network, entity string) *messageRingBuffer {
+ k := ringBufferKey{networkID: network.ID, entity: entity}
+ if rb, ok := ms.buffers[k]; ok {
+ return rb
+ }
+ rb := newMessageRingBuffer(messageRingBufferCap)
+ ms.buffers[k] = rb
+ return rb
+}
+
+func (ms *memoryMessageStore) LastMsgID(network *Network, entity string, t time.Time) (string, error) {
+ var seq uint64
+ k := ringBufferKey{networkID: network.ID, entity: entity}
+ if rb, ok := ms.buffers[k]; ok {
+ seq = rb.cur
+ }
+ return formatMemoryMsgID(network.ID, entity, seq), nil
+}
+
+func (ms *memoryMessageStore) Append(network *Network, entity string, msg *irc.Message) (string, error) {
+ switch msg.Command {
+ case "PRIVMSG", "NOTICE":
+ // Only append these messages, because LoadLatestID shouldn't return
+ // other kinds of message.
+ default:
+ return "", nil
+ }
+
+ k := ringBufferKey{networkID: network.ID, entity: entity}
+ rb, ok := ms.buffers[k]
+ if !ok {
+ rb = newMessageRingBuffer(messageRingBufferCap)
+ ms.buffers[k] = rb
+ }
+
+ seq := rb.Append(msg)
+ return formatMemoryMsgID(network.ID, entity, seq), nil
+}
+
+func (ms *memoryMessageStore) LoadLatestID(ctx context.Context, network *Network, entity, id string, limit int) ([]*irc.Message, error) {
+ _, _, seq, err := parseMemoryMsgID(id)
+ if err != nil {
+ return nil, err
+ }
+
+ k := ringBufferKey{networkID: network.ID, entity: entity}
+ rb, ok := ms.buffers[k]
+ if !ok {
+ return nil, nil
+ }
+
+ return rb.LoadLatestSeq(seq, limit)
+}
+
+type messageRingBuffer struct {
+ buf []*irc.Message
+ cur uint64
+}
+
+func newMessageRingBuffer(capacity int) *messageRingBuffer {
+ return &messageRingBuffer{
+ buf: make([]*irc.Message, capacity),
+ cur: 1,
+ }
+}
+
+func (rb *messageRingBuffer) cap() uint64 {
+ return uint64(len(rb.buf))
+}
+
+func (rb *messageRingBuffer) Append(msg *irc.Message) uint64 {
+ seq := rb.cur
+ i := int(seq % rb.cap())
+ rb.buf[i] = msg
+ rb.cur++
+ return seq
+}
+
+func (rb *messageRingBuffer) LoadLatestSeq(seq uint64, limit int) ([]*irc.Message, error) {
+ if seq > rb.cur {
+ return nil, fmt.Errorf("loading messages from sequence number (%v) greater than current (%v)", seq, rb.cur)
+ } else if seq == rb.cur {
+ return nil, nil
+ }
+
+ // The query excludes the message with the sequence number seq
+ diff := rb.cur - seq - 1
+ if diff > rb.cap() {
+ // We dropped diff - cap entries
+ diff = rb.cap()
+ }
+ if int(diff) > limit {
+ diff = uint64(limit)
+ }
+
+ l := make([]*irc.Message, int(diff))
+ for i := 0; i < int(diff); i++ {
+ j := int((rb.cur - diff + uint64(i)) % rb.cap())
+ l[i] = rb.buf[j]
+ }
+
+ return l, nil
+}
--- /dev/null
+//go:build !go1.16
+// +build !go1.16
+
+package suika
+
+import (
+ "strings"
+)
+
+func isErrClosed(err error) bool {
+ return err != nil && strings.Contains(err.Error(), "use of closed network connection")
+}
--- /dev/null
+//go:build go1.16
+// +build go1.16
+
+package suika
+
+import (
+ "errors"
+ "net"
+)
+
+func isErrClosed(err error) bool {
+ return errors.Is(err, net.ErrClosed)
+}
--- /dev/null
+package suika
+
+import (
+ "math/rand"
+ "time"
+)
+
+// backoffer implements a simple exponential backoff.
+type backoffer struct {
+ min, max, jitter time.Duration
+ n int64
+}
+
+func newBackoffer(min, max, jitter time.Duration) *backoffer {
+ return &backoffer{min: min, max: max, jitter: jitter}
+}
+
+func (b *backoffer) Reset() {
+ b.n = 0
+}
+
+func (b *backoffer) Next() time.Duration {
+ if b.n == 0 {
+ b.n = 1
+ return 0
+ }
+
+ d := time.Duration(b.n) * b.min
+ if d > b.max {
+ d = b.max
+ } else {
+ b.n *= 2
+ }
+
+ if b.jitter != 0 {
+ d += time.Duration(rand.Int63n(int64(b.jitter)))
+ }
+
+ return d
+}
--- /dev/null
+#!/bin/sh
+# $TheSupernovaDuo$
+# vim: ft=sh
+
+# PROVIDE: suika
+# REQUIRE: DAEMON
+# BEFORE: LOGIN
+# KEYWORD: shutdown
+
+. /etc/rc.subr
+
+name="suika"
+desc="A drunk IRC bouncer"
+rcvar="suika_enable"
+
+: ${suika_user="ircd"}
+
+command="%%PREFIX%%/bin/suika"
+pidfile="/var/run/suika.pid"
+required_files="%%PREFIX%%/etc/suika/config"
+
+start_cmd="suika_start"
+
+suika_start() {
+ /usr/sbin/daemon -f -p ${pidfile} -u ${suika_user} -l daemon ${command} --config ${required_files}
+}
+
+load_rc_config "$name"
+run_rc_command "$1"
--- /dev/null
+# $TheSupernovaDuo$
+cmd: %%PREFIX%%/bin/suika --config %%PREFIX%%/etc/suika/config
+user: ircd
--- /dev/null
+#!/bin/sh
+# $TheSupernovaDuo$
+# vim: ft=sh
+
+# PROVIDE: suika
+# REQUIRE: DAEMON
+# BEFORE: LOGIN
+# KEYWORD: shutdown
+
+. /etc/rc.subr
+
+name="suika"
+rcvar="${name}"
+command="%%PREFIX/bin/${name}"
+command_args="--config %%PREFIX%%/etc/suika/config"
+pidfile="/var/run/${name}.pid"
+start_cmd="${name}_start"
+
+suika_start() {
+ printf "Starting %s..." "${name}"
+ ${command} ${command_args}
+ pgrep -n ${name} > ${pidfile}
+}
+
+load_rc_config ${name}
+run_rc_command "$1"
+
+
--- /dev/null
+#!/bin/ksh
+# $TheSupernovaDuo$
+# vim: ft=sh
+
+daemon="%%PREFIX%%/bin/suika"
+daemon_args="--config %%PREFIX%%/etc/suika/config"
+
+. /etc/rc.d/rc.subr
+
+rc_bg=YES
+
+rc_cmd "$1"
--- /dev/null
+# $TheSupernovaDuo$
+# vim: ft=confini
+[Unit]
+Description=A drunk IRC bouncer
+After=network.target
+Wants=network.target
+StartLimitBurst=5
+StartLimitIntervalSec=1
+[Service]
+Type=simple
+Restart=on-abnormal
+RestartSec=1
+User=suika
+ExecStart=%%PREFIX%%/bin/suika --config %%PREFIX%%/etc/suika/config
+[Install]
+WantedBy=multi-user.target
--- /dev/null
+package suika
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "log"
+ "net"
+ "runtime/debug"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "gopkg.in/irc.v3"
+)
+
+// TODO: make configurable
+var (
+ retryConnectMinDelay = time.Minute
+ retryConnectMaxDelay = 10 * time.Minute
+ retryConnectJitter = time.Minute
+ connectTimeout = 15 * time.Second
+ writeTimeout = 10 * time.Second
+ upstreamMessageDelay = 2 * time.Second
+ upstreamMessageBurst = 10
+ backlogTimeout = 10 * time.Second
+ handleDownstreamMessageTimeout = 10 * time.Second
+ downstreamRegisterTimeout = 30 * time.Second
+ chatHistoryLimit = 1000
+ backlogLimit = 4000
+)
+
+type Logger interface {
+ Printf(format string, v ...interface{})
+ Debugf(format string, v ...interface{})
+}
+
+type logger struct {
+ *log.Logger
+ debug bool
+}
+
+func (l logger) Debugf(format string, v ...interface{}) {
+ if !l.debug {
+ return
+ }
+ l.Logger.Printf(format, v...)
+}
+
+func NewLogger(out io.Writer, debug bool) Logger {
+ return logger{
+ Logger: log.New(log.Writer(), "", log.LstdFlags),
+ debug: debug,
+ }
+}
+
+type prefixLogger struct {
+ logger Logger
+ prefix string
+}
+
+var _ Logger = (*prefixLogger)(nil)
+
+func (l *prefixLogger) Printf(format string, v ...interface{}) {
+ v = append([]interface{}{l.prefix}, v...)
+ l.logger.Printf("%v"+format, v...)
+}
+
+func (l *prefixLogger) Debugf(format string, v ...interface{}) {
+ v = append([]interface{}{l.prefix}, v...)
+ l.logger.Debugf("%v"+format, v...)
+}
+
+type int64Gauge struct {
+ v int64 // atomic
+}
+
+func (g *int64Gauge) Add(delta int64) {
+ atomic.AddInt64(&g.v, delta)
+}
+
+func (g *int64Gauge) Value() int64 {
+ return atomic.LoadInt64(&g.v)
+}
+
+func (g *int64Gauge) Float64() float64 {
+ return float64(g.Value())
+}
+
+type retryListener struct {
+ net.Listener
+ Logger Logger
+
+ delay time.Duration
+}
+
+func (ln *retryListener) Accept() (net.Conn, error) {
+ for {
+ conn, err := ln.Listener.Accept()
+ if ne, ok := err.(net.Error); ok && ne.Temporary() {
+ if ln.delay == 0 {
+ ln.delay = 5 * time.Millisecond
+ } else {
+ ln.delay *= 2
+ }
+ if max := 1 * time.Second; ln.delay > max {
+ ln.delay = max
+ }
+ if ln.Logger != nil {
+ ln.Logger.Printf("accept error (retrying in %v): %v", ln.delay, err)
+ }
+ time.Sleep(ln.delay)
+ } else {
+ ln.delay = 0
+ return conn, err
+ }
+ }
+}
+
+type Config struct {
+ Hostname string
+ Title string
+ LogPath string
+ MaxUserNetworks int
+ MultiUpstream bool
+ MOTD string
+ UpstreamUserIPs []*net.IPNet
+}
+
+type Server struct {
+ Logger Logger
+
+ config atomic.Value // *Config
+ db Database
+ stopWG sync.WaitGroup
+
+ lock sync.Mutex
+ listeners map[net.Listener]struct{}
+ users map[string]*user
+}
+
+func NewServer(db Database) *Server {
+ srv := &Server{
+ Logger: NewLogger(log.Writer(), true),
+ db: db,
+ listeners: make(map[net.Listener]struct{}),
+ users: make(map[string]*user),
+ }
+ srv.config.Store(&Config{
+ Hostname: "localhost",
+ MaxUserNetworks: -1,
+ MultiUpstream: true,
+ })
+ return srv
+}
+
+func (s *Server) prefix() *irc.Prefix {
+ return &irc.Prefix{Name: s.Config().Hostname}
+}
+
+func (s *Server) Config() *Config {
+ return s.config.Load().(*Config)
+}
+
+func (s *Server) SetConfig(cfg *Config) {
+ s.config.Store(cfg)
+}
+
+func (s *Server) Start() error {
+ users, err := s.db.ListUsers(context.TODO())
+ if err != nil {
+ return err
+ }
+
+ s.lock.Lock()
+ for i := range users {
+ s.addUserLocked(&users[i])
+ }
+ s.lock.Unlock()
+
+ return nil
+}
+
+func (s *Server) Shutdown() {
+ s.lock.Lock()
+ for ln := range s.listeners {
+ if err := ln.Close(); err != nil {
+ s.Logger.Printf("failed to stop listener: %v", err)
+ }
+ }
+ for _, u := range s.users {
+ u.events <- eventStop{}
+ }
+ s.lock.Unlock()
+
+ s.stopWG.Wait()
+
+ if err := s.db.Close(); err != nil {
+ s.Logger.Printf("failed to close DB: %v", err)
+ }
+}
+
+func (s *Server) createUser(ctx context.Context, user *User) (*user, error) {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+
+ if _, ok := s.users[user.Username]; ok {
+ return nil, fmt.Errorf("user %q already exists", user.Username)
+ }
+
+ err := s.db.StoreUser(ctx, user)
+ if err != nil {
+ return nil, fmt.Errorf("could not create user in db: %v", err)
+ }
+
+ return s.addUserLocked(user), nil
+}
+
+func (s *Server) forEachUser(f func(*user)) {
+ s.lock.Lock()
+ for _, u := range s.users {
+ f(u)
+ }
+ s.lock.Unlock()
+}
+
+func (s *Server) getUser(name string) *user {
+ s.lock.Lock()
+ u := s.users[name]
+ s.lock.Unlock()
+ return u
+}
+
+func (s *Server) addUserLocked(user *User) *user {
+ s.Logger.Printf("starting bouncer for user %q", user.Username)
+ u := newUser(s, user)
+ s.users[u.Username] = u
+
+ s.stopWG.Add(1)
+
+ go func() {
+ defer func() {
+ if err := recover(); err != nil {
+ s.Logger.Printf("panic serving user %q: %v\n%v", user.Username, err, debug.Stack())
+ }
+
+ s.lock.Lock()
+ delete(s.users, u.Username)
+ s.lock.Unlock()
+
+ s.stopWG.Done()
+ }()
+
+ u.run()
+ }()
+
+ return u
+}
+
+var lastDownstreamID uint64 = 0
+
+func (s *Server) handle(ic ircConn) {
+ defer func() {
+ if err := recover(); err != nil {
+ s.Logger.Printf("panic serving downstream %q: %v\n%v", ic.RemoteAddr(), err, debug.Stack())
+ }
+ }()
+
+ id := atomic.AddUint64(&lastDownstreamID, 1)
+ dc := newDownstreamConn(s, ic, id)
+ if err := dc.runUntilRegistered(); err != nil {
+ if !errors.Is(err, io.EOF) {
+ dc.logger.Printf("%v", err)
+ }
+ } else {
+ dc.user.events <- eventDownstreamConnected{dc}
+ if err := dc.readMessages(dc.user.events); err != nil {
+ dc.logger.Printf("%v", err)
+ }
+ dc.user.events <- eventDownstreamDisconnected{dc}
+ }
+ dc.Close()
+}
+
+func (s *Server) Serve(ln net.Listener) error {
+ ln = &retryListener{
+ Listener: ln,
+ Logger: &prefixLogger{logger: s.Logger, prefix: fmt.Sprintf("listener %v: ", ln.Addr())},
+ }
+
+ s.lock.Lock()
+ s.listeners[ln] = struct{}{}
+ s.lock.Unlock()
+
+ s.stopWG.Add(1)
+
+ defer func() {
+ s.lock.Lock()
+ delete(s.listeners, ln)
+ s.lock.Unlock()
+
+ s.stopWG.Done()
+ }()
+
+ for {
+ conn, err := ln.Accept()
+ if isErrClosed(err) {
+ return nil
+ } else if err != nil {
+ return fmt.Errorf("failed to accept connection: %v", err)
+ }
+
+ go s.handle(newNetIRCConn(conn))
+ }
+}
+
+type ServerStats struct {
+ Users int
+ Downstreams int64
+ Upstreams int64
+}
+
+func (s *Server) Stats() *ServerStats {
+ var stats ServerStats
+ s.lock.Lock()
+ stats.Users = len(s.users)
+ s.lock.Unlock()
+ return &stats
+}
--- /dev/null
+package suika
+
+import (
+ "context"
+ "net"
+ "testing"
+
+ "golang.org/x/crypto/bcrypt"
+ "gopkg.in/irc.v3"
+)
+
+var testServerPrefix = &irc.Prefix{Name: "suika-test-server"}
+
+const (
+ testUsername = "suika-test-user"
+ testPassword = testUsername
+)
+
+func createTempSqliteDB(t *testing.T) Database {
+ db, err := OpenDB("sqlite3", ":memory:")
+ if err != nil {
+ t.Fatalf("failed to create temporary SQLite database: %v", err)
+ }
+ // :memory: will open a separate database for each new connection. Make
+ // sure the sql package only uses a single connection. An alternative
+ // solution is to use "file::memory:?cache=shared".
+ db.(*SqliteDB).db.SetMaxOpenConns(1)
+ return db
+}
+
+func createTempPostgresDB(t *testing.T) Database {
+ db := &PostgresDB{db: openTempPostgresDB(t)}
+ if err := db.upgrade(); err != nil {
+ t.Fatalf("failed to upgrade PostgreSQL database: %v", err)
+ }
+
+ return db
+}
+
+func createTestUser(t *testing.T, db Database) *User {
+ hashed, err := bcrypt.GenerateFromPassword([]byte(testPassword), bcrypt.DefaultCost)
+ if err != nil {
+ t.Fatalf("failed to generate bcrypt hash: %v", err)
+ }
+
+ record := &User{Username: testUsername, Password: string(hashed)}
+ if err := db.StoreUser(context.Background(), record); err != nil {
+ t.Fatalf("failed to store test user: %v", err)
+ }
+
+ return record
+}
+
+func createTestDownstream(t *testing.T, srv *Server) ircConn {
+ c1, c2 := net.Pipe()
+ go srv.handle(newNetIRCConn(c1))
+ return newNetIRCConn(c2)
+}
+
+func createTestUpstream(t *testing.T, db Database, user *User) (*Network, net.Listener) {
+ ln, err := net.Listen("tcp", "localhost:0")
+ if err != nil {
+ t.Fatalf("failed to create TCP listener: %v", err)
+ }
+
+ network := &Network{
+ Name: "testnet",
+ Addr: "irc://" + ln.Addr().String(),
+ Nick: user.Username,
+ Enabled: true,
+ }
+ if err := db.StoreNetwork(context.Background(), user.ID, network); err != nil {
+ t.Fatalf("failed to store test network: %v", err)
+ }
+
+ return network, ln
+}
+
+func mustAccept(t *testing.T, ln net.Listener) ircConn {
+ c, err := ln.Accept()
+ if err != nil {
+ t.Fatalf("failed accepting connection: %v", err)
+ }
+ return newNetIRCConn(c)
+}
+
+func expectMessage(t *testing.T, c ircConn, cmd string) *irc.Message {
+ msg, err := c.ReadMessage()
+ if err != nil {
+ t.Fatalf("failed to read IRC message (want %q): %v", cmd, err)
+ }
+ if msg.Command != cmd {
+ t.Fatalf("invalid message received: want %q, got: %v", cmd, msg)
+ }
+ return msg
+}
+
+func registerDownstreamConn(t *testing.T, c ircConn, network *Network) {
+ c.WriteMessage(&irc.Message{
+ Command: "PASS",
+ Params: []string{testPassword},
+ })
+ c.WriteMessage(&irc.Message{
+ Command: "NICK",
+ Params: []string{testUsername},
+ })
+ c.WriteMessage(&irc.Message{
+ Command: "USER",
+ Params: []string{testUsername + "/" + network.Name, "0", "*", testUsername},
+ })
+
+ expectMessage(t, c, irc.RPL_WELCOME)
+}
+
+func registerUpstreamConn(t *testing.T, c ircConn) {
+ msg := expectMessage(t, c, "CAP")
+ if msg.Params[0] != "LS" {
+ t.Fatalf("invalid CAP LS: got: %v", msg)
+ }
+ msg = expectMessage(t, c, "NICK")
+ nick := msg.Params[0]
+ if nick != testUsername {
+ t.Fatalf("invalid NICK: want %q, got: %v", testUsername, msg)
+ }
+ expectMessage(t, c, "USER")
+
+ c.WriteMessage(&irc.Message{
+ Prefix: testServerPrefix,
+ Command: irc.RPL_WELCOME,
+ Params: []string{nick, "Welcome!"},
+ })
+ c.WriteMessage(&irc.Message{
+ Prefix: testServerPrefix,
+ Command: irc.RPL_YOURHOST,
+ Params: []string{nick, "Your host is suika-test-server"},
+ })
+ c.WriteMessage(&irc.Message{
+ Prefix: testServerPrefix,
+ Command: irc.RPL_CREATED,
+ Params: []string{nick, "Who cares when the server was created?"},
+ })
+ c.WriteMessage(&irc.Message{
+ Prefix: testServerPrefix,
+ Command: irc.RPL_MYINFO,
+ Params: []string{nick, testServerPrefix.Name, "suika", "aiwroO", "OovaimnqpsrtklbeI"},
+ })
+ c.WriteMessage(&irc.Message{
+ Prefix: testServerPrefix,
+ Command: irc.ERR_NOMOTD,
+ Params: []string{nick, "No MOTD"},
+ })
+}
+
+func testServer(t *testing.T, db Database) {
+ user := createTestUser(t, db)
+ network, upstream := createTestUpstream(t, db, user)
+ defer upstream.Close()
+
+ srv := NewServer(db)
+ if err := srv.Start(); err != nil {
+ t.Fatalf("failed to start server: %v", err)
+ }
+ defer srv.Shutdown()
+
+ uc := mustAccept(t, upstream)
+ defer uc.Close()
+ registerUpstreamConn(t, uc)
+
+ dc := createTestDownstream(t, srv)
+ defer dc.Close()
+ registerDownstreamConn(t, dc, network)
+
+ noticeText := "This is a very important server notice."
+ uc.WriteMessage(&irc.Message{
+ Prefix: testServerPrefix,
+ Command: "NOTICE",
+ Params: []string{testUsername, noticeText},
+ })
+
+ var msg *irc.Message
+ for {
+ var err error
+ msg, err = dc.ReadMessage()
+ if err != nil {
+ t.Fatalf("failed to read IRC message: %v", err)
+ }
+ if msg.Command == "NOTICE" {
+ break
+ }
+ }
+
+ if msg.Params[1] != noticeText {
+ t.Fatalf("invalid NOTICE text: want %q, got: %v", noticeText, msg)
+ }
+}
+
+func TestServer(t *testing.T) {
+ t.Run("sqlite", func(t *testing.T) {
+ db := createTempSqliteDB(t)
+ testServer(t, db)
+ })
+
+ t.Run("postgres", func(t *testing.T) {
+ db := createTempPostgresDB(t)
+ testServer(t, db)
+ })
+}
--- /dev/null
+package suika
+
+import (
+ "context"
+ "crypto/sha1"
+ "crypto/sha256"
+ "crypto/sha512"
+ "encoding/hex"
+ "flag"
+ "fmt"
+ "io/ioutil"
+ "sort"
+ "strconv"
+ "strings"
+ "time"
+ "unicode"
+
+ "golang.org/x/crypto/bcrypt"
+ "gopkg.in/irc.v3"
+)
+
+const (
+ serviceNick = "BouncerServ"
+ serviceNickCM = "bouncerserv"
+ serviceRealname = "suika bouncer service"
+)
+
+// maxRSABits is the maximum number of RSA key bits used when generating a new
+// private key.
+const maxRSABits = 8192
+
+var servicePrefix = &irc.Prefix{
+ Name: serviceNick,
+ User: serviceNick,
+ Host: serviceNick,
+}
+
+type serviceCommandSet map[string]*serviceCommand
+
+type serviceCommand struct {
+ usage string
+ desc string
+ handle func(ctx context.Context, dc *downstreamConn, params []string) error
+ children serviceCommandSet
+ admin bool
+}
+
+func sendServiceNOTICE(dc *downstreamConn, text string) {
+ dc.SendMessage(&irc.Message{
+ Prefix: servicePrefix,
+ Command: "NOTICE",
+ Params: []string{dc.nick, text},
+ })
+}
+
+func sendServicePRIVMSG(dc *downstreamConn, text string) {
+ dc.SendMessage(&irc.Message{
+ Prefix: servicePrefix,
+ Command: "PRIVMSG",
+ Params: []string{dc.nick, text},
+ })
+}
+
+func splitWords(s string) ([]string, error) {
+ var words []string
+ var lastWord strings.Builder
+ escape := false
+ prev := ' '
+ wordDelim := ' '
+
+ for _, r := range s {
+ if escape {
+ // last char was a backslash, write the byte as-is.
+ lastWord.WriteRune(r)
+ escape = false
+ } else if r == '\\' {
+ escape = true
+ } else if wordDelim == ' ' && unicode.IsSpace(r) {
+ // end of last word
+ if !unicode.IsSpace(prev) {
+ words = append(words, lastWord.String())
+ lastWord.Reset()
+ }
+ } else if r == wordDelim {
+ // wordDelim is either " or ', switch back to
+ // space-delimited words.
+ wordDelim = ' '
+ } else if r == '"' || r == '\'' {
+ if wordDelim == ' ' {
+ // start of (double-)quoted word
+ wordDelim = r
+ } else {
+ // either wordDelim is " and r is ' or vice-versa
+ lastWord.WriteRune(r)
+ }
+ } else {
+ lastWord.WriteRune(r)
+ }
+
+ prev = r
+ }
+
+ if !unicode.IsSpace(prev) {
+ words = append(words, lastWord.String())
+ }
+
+ if wordDelim != ' ' {
+ return nil, fmt.Errorf("unterminated quoted string")
+ }
+ if escape {
+ return nil, fmt.Errorf("unterminated backslash sequence")
+ }
+
+ return words, nil
+}
+
+func handleServicePRIVMSG(ctx context.Context, dc *downstreamConn, text string) {
+ words, err := splitWords(text)
+ if err != nil {
+ sendServicePRIVMSG(dc, fmt.Sprintf(`error: failed to parse command: %v`, err))
+ return
+ }
+
+ cmd, params, err := serviceCommands.Get(words)
+ if err != nil {
+ sendServicePRIVMSG(dc, fmt.Sprintf(`error: %v (type "help" for a list of commands)`, err))
+ return
+ }
+ if cmd.admin && !dc.user.Admin {
+ sendServicePRIVMSG(dc, "error: you must be an admin to use this command")
+ return
+ }
+
+ if cmd.handle == nil {
+ if len(cmd.children) > 0 {
+ var l []string
+ appendServiceCommandSetHelp(cmd.children, words, dc.user.Admin, &l)
+ sendServicePRIVMSG(dc, "available commands: "+strings.Join(l, ", "))
+ } else {
+ // Pretend the command does not exist if it has neither children nor handler.
+ // This is obviously a bug but it is better to not die anyway.
+ dc.logger.Printf("command without handler and subcommands invoked:", words[0])
+ sendServicePRIVMSG(dc, fmt.Sprintf("command %q not found", words[0]))
+ }
+ return
+ }
+
+ if err := cmd.handle(ctx, dc, params); err != nil {
+ sendServicePRIVMSG(dc, fmt.Sprintf("error: %v", err))
+ }
+}
+
+func (cmds serviceCommandSet) Get(params []string) (*serviceCommand, []string, error) {
+ if len(params) == 0 {
+ return nil, nil, fmt.Errorf("no command specified")
+ }
+
+ name := params[0]
+ params = params[1:]
+
+ cmd, ok := cmds[name]
+ if !ok {
+ for k := range cmds {
+ if !strings.HasPrefix(k, name) {
+ continue
+ }
+ if cmd != nil {
+ return nil, params, fmt.Errorf("command %q is ambiguous", name)
+ }
+ cmd = cmds[k]
+ }
+ }
+ if cmd == nil {
+ return nil, params, fmt.Errorf("command %q not found", name)
+ }
+
+ if len(params) == 0 || len(cmd.children) == 0 {
+ return cmd, params, nil
+ }
+ return cmd.children.Get(params)
+}
+
+func (cmds serviceCommandSet) Names() []string {
+ l := make([]string, 0, len(cmds))
+ for name := range cmds {
+ l = append(l, name)
+ }
+ sort.Strings(l)
+ return l
+}
+
+var serviceCommands serviceCommandSet
+
+func init() {
+ serviceCommands = serviceCommandSet{
+ "help": {
+ usage: "[command]",
+ desc: "print help message",
+ handle: handleServiceHelp,
+ },
+ "network": {
+ children: serviceCommandSet{
+ "create": {
+ usage: "-addr <addr> [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick] [-enabled enabled] [-connect-command command]...",
+ desc: "add a new network",
+ handle: handleServiceNetworkCreate,
+ },
+ "status": {
+ desc: "show a list of saved networks and their current status",
+ handle: handleServiceNetworkStatus,
+ },
+ "update": {
+ usage: "[name] [-addr addr] [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick] [-enabled enabled] [-connect-command command]...",
+ desc: "update a network",
+ handle: handleServiceNetworkUpdate,
+ },
+ "delete": {
+ usage: "[name]",
+ desc: "delete a network",
+ handle: handleServiceNetworkDelete,
+ },
+ "quote": {
+ usage: "[name] <command>",
+ desc: "send a raw line to a network",
+ handle: handleServiceNetworkQuote,
+ },
+ },
+ },
+ "certfp": {
+ children: serviceCommandSet{
+ "generate": {
+ usage: "[-key-type rsa|ecdsa|ed25519] [-bits N] [-network name]",
+ desc: "generate a new self-signed certificate, defaults to using RSA-3072 key",
+ handle: handleServiceCertFPGenerate,
+ },
+ "fingerprint": {
+ usage: "[-network name]",
+ desc: "show fingerprints of certificate",
+ handle: handleServiceCertFPFingerprints,
+ },
+ },
+ },
+ "sasl": {
+ children: serviceCommandSet{
+ "status": {
+ usage: "[-network name]",
+ desc: "show SASL status",
+ handle: handleServiceSASLStatus,
+ },
+ "set-plain": {
+ usage: "[-network name] <username> <password>",
+ desc: "set SASL PLAIN credentials",
+ handle: handleServiceSASLSetPlain,
+ },
+ "reset": {
+ usage: "[-network name]",
+ desc: "disable SASL authentication and remove stored credentials",
+ handle: handleServiceSASLReset,
+ },
+ },
+ },
+ "user": {
+ children: serviceCommandSet{
+ "create": {
+ usage: "-username <username> -password <password> [-realname <realname>] [-admin]",
+ desc: "create a new suika user",
+ handle: handleUserCreate,
+ admin: true,
+ },
+ "update": {
+ usage: "[-password <password>] [-realname <realname>]",
+ desc: "update the current user",
+ handle: handleUserUpdate,
+ },
+ "delete": {
+ usage: "<username>",
+ desc: "delete a user",
+ handle: handleUserDelete,
+ admin: true,
+ },
+ },
+ },
+ "channel": {
+ children: serviceCommandSet{
+ "status": {
+ usage: "[-network name]",
+ desc: "show a list of saved channels and their current status",
+ handle: handleServiceChannelStatus,
+ },
+ "update": {
+ usage: "<name> [-relay-detached <default|none|highlight|message>] [-reattach-on <default|none|highlight|message>] [-detach-after <duration>] [-detach-on <default|none|highlight|message>]",
+ desc: "update a channel",
+ handle: handleServiceChannelUpdate,
+ },
+ },
+ },
+ "server": {
+ children: serviceCommandSet{
+ "status": {
+ desc: "show server statistics",
+ handle: handleServiceServerStatus,
+ admin: true,
+ },
+ "notice": {
+ desc: "broadcast a notice to all connected bouncer users",
+ handle: handleServiceServerNotice,
+ admin: true,
+ },
+ },
+ admin: true,
+ },
+ }
+}
+
+func appendServiceCommandSetHelp(cmds serviceCommandSet, prefix []string, admin bool, l *[]string) {
+ for _, name := range cmds.Names() {
+ cmd := cmds[name]
+ if cmd.admin && !admin {
+ continue
+ }
+ words := append(prefix, name)
+ if len(cmd.children) == 0 {
+ s := strings.Join(words, " ")
+ *l = append(*l, s)
+ } else {
+ appendServiceCommandSetHelp(cmd.children, words, admin, l)
+ }
+ }
+}
+
+func handleServiceHelp(ctx context.Context, dc *downstreamConn, params []string) error {
+ if len(params) > 0 {
+ cmd, rest, err := serviceCommands.Get(params)
+ if err != nil {
+ return err
+ }
+ words := params[:len(params)-len(rest)]
+
+ if len(cmd.children) > 0 {
+ var l []string
+ appendServiceCommandSetHelp(cmd.children, words, dc.user.Admin, &l)
+ sendServicePRIVMSG(dc, "available commands: "+strings.Join(l, ", "))
+ } else {
+ text := strings.Join(words, " ")
+ if cmd.usage != "" {
+ text += " " + cmd.usage
+ }
+ text += ": " + cmd.desc
+
+ sendServicePRIVMSG(dc, text)
+ }
+ } else {
+ var l []string
+ appendServiceCommandSetHelp(serviceCommands, nil, dc.user.Admin, &l)
+ sendServicePRIVMSG(dc, "available commands: "+strings.Join(l, ", "))
+ }
+ return nil
+}
+
+func newFlagSet() *flag.FlagSet {
+ fs := flag.NewFlagSet("", flag.ContinueOnError)
+ fs.SetOutput(ioutil.Discard)
+ return fs
+}
+
+type stringSliceFlag []string
+
+func (v *stringSliceFlag) String() string {
+ return fmt.Sprint([]string(*v))
+}
+
+func (v *stringSliceFlag) Set(s string) error {
+ *v = append(*v, s)
+ return nil
+}
+
+// stringPtrFlag is a flag value populating a string pointer. This allows to
+// disambiguate between a flag that hasn't been set and a flag that has been
+// set to an empty string.
+type stringPtrFlag struct {
+ ptr **string
+}
+
+func (f stringPtrFlag) String() string {
+ if f.ptr == nil || *f.ptr == nil {
+ return ""
+ }
+ return **f.ptr
+}
+
+func (f stringPtrFlag) Set(s string) error {
+ *f.ptr = &s
+ return nil
+}
+
+type boolPtrFlag struct {
+ ptr **bool
+}
+
+func (f boolPtrFlag) String() string {
+ if f.ptr == nil || *f.ptr == nil {
+ return "<nil>"
+ }
+ return strconv.FormatBool(**f.ptr)
+}
+
+func (f boolPtrFlag) Set(s string) error {
+ v, err := strconv.ParseBool(s)
+ if err != nil {
+ return err
+ }
+ *f.ptr = &v
+ return nil
+}
+
+func getNetworkFromArg(dc *downstreamConn, params []string) (*network, []string, error) {
+ name, params := popArg(params)
+ if name == "" {
+ if dc.network == nil {
+ return nil, params, fmt.Errorf("no network selected, a name argument is required")
+ }
+ return dc.network, params, nil
+ } else {
+ net := dc.user.getNetwork(name)
+ if net == nil {
+ return nil, params, fmt.Errorf("unknown network %q", name)
+ }
+ return net, params, nil
+ }
+}
+
+type networkFlagSet struct {
+ *flag.FlagSet
+ Addr, Name, Nick, Username, Pass, Realname *string
+ Enabled *bool
+ ConnectCommands []string
+}
+
+func newNetworkFlagSet() *networkFlagSet {
+ fs := &networkFlagSet{FlagSet: newFlagSet()}
+ fs.Var(stringPtrFlag{&fs.Addr}, "addr", "")
+ fs.Var(stringPtrFlag{&fs.Name}, "name", "")
+ fs.Var(stringPtrFlag{&fs.Nick}, "nick", "")
+ fs.Var(stringPtrFlag{&fs.Username}, "username", "")
+ fs.Var(stringPtrFlag{&fs.Pass}, "pass", "")
+ fs.Var(stringPtrFlag{&fs.Realname}, "realname", "")
+ fs.Var(boolPtrFlag{&fs.Enabled}, "enabled", "")
+ fs.Var((*stringSliceFlag)(&fs.ConnectCommands), "connect-command", "")
+ return fs
+}
+
+func (fs *networkFlagSet) update(network *Network) error {
+ if fs.Addr != nil {
+ if addrParts := strings.SplitN(*fs.Addr, "://", 2); len(addrParts) == 2 {
+ scheme := addrParts[0]
+ switch scheme {
+ case "ircs", "irc", "unix":
+ default:
+ return fmt.Errorf("unknown scheme %q (supported schemes: ircs, irc, unix)", scheme)
+ }
+ }
+ network.Addr = *fs.Addr
+ }
+ if fs.Name != nil {
+ network.Name = *fs.Name
+ }
+ if fs.Nick != nil {
+ network.Nick = *fs.Nick
+ }
+ if fs.Username != nil {
+ network.Username = *fs.Username
+ }
+ if fs.Pass != nil {
+ network.Pass = *fs.Pass
+ }
+ if fs.Realname != nil {
+ network.Realname = *fs.Realname
+ }
+ if fs.Enabled != nil {
+ network.Enabled = *fs.Enabled
+ }
+ if fs.ConnectCommands != nil {
+ if len(fs.ConnectCommands) == 1 && fs.ConnectCommands[0] == "" {
+ network.ConnectCommands = nil
+ } else {
+ for _, command := range fs.ConnectCommands {
+ _, err := irc.ParseMessage(command)
+ if err != nil {
+ return fmt.Errorf("flag -connect-command must be a valid raw irc command string: %q: %v", command, err)
+ }
+ }
+ network.ConnectCommands = fs.ConnectCommands
+ }
+ }
+ return nil
+}
+
+func handleServiceNetworkCreate(ctx context.Context, dc *downstreamConn, params []string) error {
+ fs := newNetworkFlagSet()
+ if err := fs.Parse(params); err != nil {
+ return err
+ }
+ if fs.Addr == nil {
+ return fmt.Errorf("flag -addr is required")
+ }
+
+ record := &Network{
+ Addr: *fs.Addr,
+ Enabled: true,
+ }
+ if err := fs.update(record); err != nil {
+ return err
+ }
+
+ network, err := dc.user.createNetwork(ctx, record)
+ if err != nil {
+ return fmt.Errorf("could not create network: %v", err)
+ }
+
+ sendServicePRIVMSG(dc, fmt.Sprintf("created network %q", network.GetName()))
+ return nil
+}
+
+func handleServiceNetworkStatus(ctx context.Context, dc *downstreamConn, params []string) error {
+ n := 0
+ for _, net := range dc.user.networks {
+ var statuses []string
+ var details string
+ if uc := net.conn; uc != nil {
+ if dc.nick != uc.nick {
+ statuses = append(statuses, "connected as "+uc.nick)
+ } else {
+ statuses = append(statuses, "connected")
+ }
+ details = fmt.Sprintf("%v channels", uc.channels.Len())
+ } else if !net.Enabled {
+ statuses = append(statuses, "disabled")
+ } else {
+ statuses = append(statuses, "disconnected")
+ if net.lastError != nil {
+ details = net.lastError.Error()
+ }
+ }
+
+ if net == dc.network {
+ statuses = append(statuses, "current")
+ }
+
+ name := net.GetName()
+ if name != net.Addr {
+ name = fmt.Sprintf("%v (%v)", name, net.Addr)
+ }
+
+ s := fmt.Sprintf("%v [%v]", name, strings.Join(statuses, ", "))
+ if details != "" {
+ s += ": " + details
+ }
+ sendServicePRIVMSG(dc, s)
+
+ n++
+ }
+
+ if n == 0 {
+ sendServicePRIVMSG(dc, `No network configured, add one with "network create".`)
+ }
+
+ return nil
+}
+
+func handleServiceNetworkUpdate(ctx context.Context, dc *downstreamConn, params []string) error {
+ net, params, err := getNetworkFromArg(dc, params)
+ if err != nil {
+ return err
+ }
+
+ fs := newNetworkFlagSet()
+ if err := fs.Parse(params); err != nil {
+ return err
+ }
+
+ record := net.Network // copy network record because we'll mutate it
+ if err := fs.update(&record); err != nil {
+ return err
+ }
+
+ network, err := dc.user.updateNetwork(ctx, &record)
+ if err != nil {
+ return fmt.Errorf("could not update network: %v", err)
+ }
+
+ sendServicePRIVMSG(dc, fmt.Sprintf("updated network %q", network.GetName()))
+ return nil
+}
+
+func handleServiceNetworkDelete(ctx context.Context, dc *downstreamConn, params []string) error {
+ net, params, err := getNetworkFromArg(dc, params)
+ if err != nil {
+ return err
+ }
+
+ if err := dc.user.deleteNetwork(ctx, net.ID); err != nil {
+ return err
+ }
+
+ sendServicePRIVMSG(dc, fmt.Sprintf("deleted network %q", net.GetName()))
+ return nil
+}
+
+func handleServiceNetworkQuote(ctx context.Context, dc *downstreamConn, params []string) error {
+ if len(params) != 1 && len(params) != 2 {
+ return fmt.Errorf("expected one or two arguments")
+ }
+
+ raw := params[len(params)-1]
+ params = params[:len(params)-1]
+
+ net, params, err := getNetworkFromArg(dc, params)
+ if err != nil {
+ return err
+ }
+
+ uc := net.conn
+ if uc == nil {
+ return fmt.Errorf("network %q is not currently connected", net.GetName())
+ }
+
+ m, err := irc.ParseMessage(raw)
+ if err != nil {
+ return fmt.Errorf("failed to parse command %q: %v", raw, err)
+ }
+ uc.SendMessage(ctx, m)
+
+ sendServicePRIVMSG(dc, fmt.Sprintf("sent command to %q", net.GetName()))
+ return nil
+}
+
+func sendCertfpFingerprints(dc *downstreamConn, cert []byte) {
+ sha1Sum := sha1.Sum(cert)
+ sendServicePRIVMSG(dc, "SHA-1 fingerprint: "+hex.EncodeToString(sha1Sum[:]))
+ sha256Sum := sha256.Sum256(cert)
+ sendServicePRIVMSG(dc, "SHA-256 fingerprint: "+hex.EncodeToString(sha256Sum[:]))
+ sha512Sum := sha512.Sum512(cert)
+ sendServicePRIVMSG(dc, "SHA-512 fingerprint: "+hex.EncodeToString(sha512Sum[:]))
+}
+
+func getNetworkFromFlag(dc *downstreamConn, name string) (*network, error) {
+ if name == "" {
+ if dc.network == nil {
+ return nil, fmt.Errorf("no network selected, -network is required")
+ }
+ return dc.network, nil
+ } else {
+ net := dc.user.getNetwork(name)
+ if net == nil {
+ return nil, fmt.Errorf("unknown network %q", name)
+ }
+ return net, nil
+ }
+}
+
+func handleServiceCertFPGenerate(ctx context.Context, dc *downstreamConn, params []string) error {
+ fs := newFlagSet()
+ netName := fs.String("network", "", "select a network")
+ keyType := fs.String("key-type", "rsa", "key type to generate (rsa, ecdsa, ed25519)")
+ bits := fs.Int("bits", 3072, "size of key to generate, meaningful only for RSA")
+
+ if err := fs.Parse(params); err != nil {
+ return err
+ }
+
+ if *bits <= 0 || *bits > maxRSABits {
+ return fmt.Errorf("invalid value for -bits")
+ }
+
+ net, err := getNetworkFromFlag(dc, *netName)
+ if err != nil {
+ return err
+ }
+
+ privKey, cert, err := generateCertFP(*keyType, *bits)
+ if err != nil {
+ return err
+ }
+
+ net.SASL.External.CertBlob = cert
+ net.SASL.External.PrivKeyBlob = privKey
+ net.SASL.Mechanism = "EXTERNAL"
+
+ if err := dc.srv.db.StoreNetwork(ctx, dc.user.ID, &net.Network); err != nil {
+ return err
+ }
+
+ sendServicePRIVMSG(dc, "certificate generated")
+ sendCertfpFingerprints(dc, cert)
+ return nil
+}
+
+func handleServiceCertFPFingerprints(ctx context.Context, dc *downstreamConn, params []string) error {
+ fs := newFlagSet()
+ netName := fs.String("network", "", "select a network")
+
+ if err := fs.Parse(params); err != nil {
+ return err
+ }
+
+ net, err := getNetworkFromFlag(dc, *netName)
+ if err != nil {
+ return err
+ }
+
+ if net.SASL.Mechanism != "EXTERNAL" {
+ return fmt.Errorf("CertFP not set up")
+ }
+
+ sendCertfpFingerprints(dc, net.SASL.External.CertBlob)
+ return nil
+}
+
+func handleServiceSASLStatus(ctx context.Context, dc *downstreamConn, params []string) error {
+ fs := newFlagSet()
+ netName := fs.String("network", "", "select a network")
+
+ if err := fs.Parse(params); err != nil {
+ return err
+ }
+
+ net, err := getNetworkFromFlag(dc, *netName)
+ if err != nil {
+ return err
+ }
+
+ switch net.SASL.Mechanism {
+ case "PLAIN":
+ sendServicePRIVMSG(dc, fmt.Sprintf("SASL PLAIN enabled with username %q", net.SASL.Plain.Username))
+ case "EXTERNAL":
+ sendServicePRIVMSG(dc, "SASL EXTERNAL (CertFP) enabled")
+ case "":
+ sendServicePRIVMSG(dc, "SASL is disabled")
+ }
+
+ if uc := net.conn; uc != nil {
+ if uc.account != "" {
+ sendServicePRIVMSG(dc, fmt.Sprintf("Authenticated on upstream network with account %q", uc.account))
+ } else {
+ sendServicePRIVMSG(dc, "Unauthenticated on upstream network")
+ }
+ } else {
+ sendServicePRIVMSG(dc, "Disconnected from upstream network")
+ }
+
+ return nil
+}
+
+func handleServiceSASLSetPlain(ctx context.Context, dc *downstreamConn, params []string) error {
+ fs := newFlagSet()
+ netName := fs.String("network", "", "select a network")
+
+ if err := fs.Parse(params); err != nil {
+ return err
+ }
+
+ if len(fs.Args()) != 2 {
+ return fmt.Errorf("expected exactly 2 arguments")
+ }
+
+ net, err := getNetworkFromFlag(dc, *netName)
+ if err != nil {
+ return err
+ }
+
+ net.SASL.Plain.Username = fs.Arg(0)
+ net.SASL.Plain.Password = fs.Arg(1)
+ net.SASL.Mechanism = "PLAIN"
+
+ if err := dc.srv.db.StoreNetwork(ctx, dc.user.ID, &net.Network); err != nil {
+ return err
+ }
+
+ sendServicePRIVMSG(dc, "credentials saved")
+ return nil
+}
+
+func handleServiceSASLReset(ctx context.Context, dc *downstreamConn, params []string) error {
+ fs := newFlagSet()
+ netName := fs.String("network", "", "select a network")
+
+ if err := fs.Parse(params); err != nil {
+ return err
+ }
+
+ net, err := getNetworkFromFlag(dc, *netName)
+ if err != nil {
+ return err
+ }
+
+ net.SASL.Plain.Username = ""
+ net.SASL.Plain.Password = ""
+ net.SASL.External.CertBlob = nil
+ net.SASL.External.PrivKeyBlob = nil
+ net.SASL.Mechanism = ""
+
+ if err := dc.srv.db.StoreNetwork(ctx, dc.user.ID, &net.Network); err != nil {
+ return err
+ }
+
+ sendServicePRIVMSG(dc, "credentials reset")
+ return nil
+}
+
+func handleUserCreate(ctx context.Context, dc *downstreamConn, params []string) error {
+ fs := newFlagSet()
+ username := fs.String("username", "", "")
+ password := fs.String("password", "", "")
+ realname := fs.String("realname", "", "")
+ admin := fs.Bool("admin", false, "")
+
+ if err := fs.Parse(params); err != nil {
+ return err
+ }
+ if *username == "" {
+ return fmt.Errorf("flag -username is required")
+ }
+ if *password == "" {
+ return fmt.Errorf("flag -password is required")
+ }
+
+ hashed, err := bcrypt.GenerateFromPassword([]byte(*password), bcrypt.DefaultCost)
+ if err != nil {
+ return fmt.Errorf("failed to hash password: %v", err)
+ }
+
+ user := &User{
+ Username: *username,
+ Password: string(hashed),
+ Realname: *realname,
+ Admin: *admin,
+ }
+ if _, err := dc.srv.createUser(ctx, user); err != nil {
+ return fmt.Errorf("could not create user: %v", err)
+ }
+
+ sendServicePRIVMSG(dc, fmt.Sprintf("created user %q", *username))
+ return nil
+}
+
+func popArg(params []string) (string, []string) {
+ if len(params) > 0 && !strings.HasPrefix(params[0], "-") {
+ return params[0], params[1:]
+ }
+ return "", params
+}
+
+func handleUserUpdate(ctx context.Context, dc *downstreamConn, params []string) error {
+ var password, realname *string
+ var admin *bool
+ fs := newFlagSet()
+ fs.Var(stringPtrFlag{&password}, "password", "")
+ fs.Var(stringPtrFlag{&realname}, "realname", "")
+ fs.Var(boolPtrFlag{&admin}, "admin", "")
+
+ username, params := popArg(params)
+ if err := fs.Parse(params); err != nil {
+ return err
+ }
+ if len(fs.Args()) > 0 {
+ return fmt.Errorf("unexpected argument")
+ }
+
+ var hashed *string
+ if password != nil {
+ hashedBytes, err := bcrypt.GenerateFromPassword([]byte(*password), bcrypt.DefaultCost)
+ if err != nil {
+ return fmt.Errorf("failed to hash password: %v", err)
+ }
+ hashedStr := string(hashedBytes)
+ hashed = &hashedStr
+ }
+
+ if username != "" && username != dc.user.Username {
+ if !dc.user.Admin {
+ return fmt.Errorf("you must be an admin to update other users")
+ }
+ if realname != nil {
+ return fmt.Errorf("cannot update -realname of other user")
+ }
+
+ u := dc.srv.getUser(username)
+ if u == nil {
+ return fmt.Errorf("unknown username %q", username)
+ }
+
+ done := make(chan error, 1)
+ event := eventUserUpdate{
+ password: hashed,
+ admin: admin,
+ done: done,
+ }
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ case u.events <- event:
+ }
+ // TODO: send context to the other side
+ if err := <-done; err != nil {
+ return err
+ }
+
+ sendServicePRIVMSG(dc, fmt.Sprintf("updated user %q", username))
+ } else {
+ // copy the user record because we'll mutate it
+ record := dc.user.User
+
+ if hashed != nil {
+ record.Password = *hashed
+ }
+ if realname != nil {
+ record.Realname = *realname
+ }
+ if admin != nil {
+ return fmt.Errorf("cannot update -admin of own user")
+ }
+
+ if err := dc.user.updateUser(ctx, &record); err != nil {
+ return err
+ }
+
+ sendServicePRIVMSG(dc, fmt.Sprintf("updated user %q", dc.user.Username))
+ }
+
+ return nil
+}
+
+func handleUserDelete(ctx context.Context, dc *downstreamConn, params []string) error {
+ if len(params) != 1 {
+ return fmt.Errorf("expected exactly one argument")
+ }
+ username := params[0]
+
+ u := dc.srv.getUser(username)
+ if u == nil {
+ return fmt.Errorf("unknown username %q", username)
+ }
+
+ u.stop()
+
+ if err := dc.srv.db.DeleteUser(ctx, u.ID); err != nil {
+ return fmt.Errorf("failed to delete user: %v", err)
+ }
+
+ sendServicePRIVMSG(dc, fmt.Sprintf("deleted user %q", username))
+ return nil
+}
+
+func handleServiceChannelStatus(ctx context.Context, dc *downstreamConn, params []string) error {
+ var defaultNetworkName string
+ if dc.network != nil {
+ defaultNetworkName = dc.network.GetName()
+ }
+
+ fs := newFlagSet()
+ networkName := fs.String("network", defaultNetworkName, "")
+
+ if err := fs.Parse(params); err != nil {
+ return err
+ }
+
+ n := 0
+
+ sendNetwork := func(net *network) {
+ var channels []*Channel
+ for _, entry := range net.channels.innerMap {
+ channels = append(channels, entry.value.(*Channel))
+ }
+
+ sort.Slice(channels, func(i, j int) bool {
+ return strings.ReplaceAll(channels[i].Name, "#", "") <
+ strings.ReplaceAll(channels[j].Name, "#", "")
+ })
+
+ for _, ch := range channels {
+ var uch *upstreamChannel
+ if net.conn != nil {
+ uch = net.conn.channels.Value(ch.Name)
+ }
+
+ name := ch.Name
+ if *networkName == "" {
+ name += "/" + net.GetName()
+ }
+
+ var status string
+ if uch != nil {
+ status = "joined"
+ } else if net.conn != nil {
+ status = "parted"
+ } else {
+ status = "disconnected"
+ }
+
+ if ch.Detached {
+ status += ", detached"
+ }
+
+ s := fmt.Sprintf("%v [%v]", name, status)
+ sendServicePRIVMSG(dc, s)
+
+ n++
+ }
+ }
+
+ if *networkName == "" {
+ for _, net := range dc.user.networks {
+ sendNetwork(net)
+ }
+ } else {
+ net := dc.user.getNetwork(*networkName)
+ if net == nil {
+ return fmt.Errorf("unknown network %q", *networkName)
+ }
+ sendNetwork(net)
+ }
+
+ if n == 0 {
+ sendServicePRIVMSG(dc, "No channel configured.")
+ }
+
+ return nil
+}
+
+type channelFlagSet struct {
+ *flag.FlagSet
+ RelayDetached, ReattachOn, DetachAfter, DetachOn *string
+}
+
+func newChannelFlagSet() *channelFlagSet {
+ fs := &channelFlagSet{FlagSet: newFlagSet()}
+ fs.Var(stringPtrFlag{&fs.RelayDetached}, "relay-detached", "")
+ fs.Var(stringPtrFlag{&fs.ReattachOn}, "reattach-on", "")
+ fs.Var(stringPtrFlag{&fs.DetachAfter}, "detach-after", "")
+ fs.Var(stringPtrFlag{&fs.DetachOn}, "detach-on", "")
+ return fs
+}
+
+func (fs *channelFlagSet) update(channel *Channel) error {
+ if fs.RelayDetached != nil {
+ filter, err := parseFilter(*fs.RelayDetached)
+ if err != nil {
+ return err
+ }
+ channel.RelayDetached = filter
+ }
+ if fs.ReattachOn != nil {
+ filter, err := parseFilter(*fs.ReattachOn)
+ if err != nil {
+ return err
+ }
+ channel.ReattachOn = filter
+ }
+ if fs.DetachAfter != nil {
+ dur, err := time.ParseDuration(*fs.DetachAfter)
+ if err != nil || dur < 0 {
+ return fmt.Errorf("unknown duration for -detach-after %q (duration format: 0, 300s, 22h30m, ...)", *fs.DetachAfter)
+ }
+ channel.DetachAfter = dur
+ }
+ if fs.DetachOn != nil {
+ filter, err := parseFilter(*fs.DetachOn)
+ if err != nil {
+ return err
+ }
+ channel.DetachOn = filter
+ }
+ return nil
+}
+
+func handleServiceChannelUpdate(ctx context.Context, dc *downstreamConn, params []string) error {
+ if len(params) < 1 {
+ return fmt.Errorf("expected at least one argument")
+ }
+ name := params[0]
+
+ fs := newChannelFlagSet()
+ if err := fs.Parse(params[1:]); err != nil {
+ return err
+ }
+
+ uc, upstreamName, err := dc.unmarshalEntity(name)
+ if err != nil {
+ return fmt.Errorf("unknown channel %q", name)
+ }
+
+ ch := uc.network.channels.Value(upstreamName)
+ if ch == nil {
+ return fmt.Errorf("unknown channel %q", name)
+ }
+
+ if err := fs.update(ch); err != nil {
+ return err
+ }
+
+ uc.updateChannelAutoDetach(upstreamName)
+
+ if err := dc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil {
+ return fmt.Errorf("failed to update channel: %v", err)
+ }
+
+ sendServicePRIVMSG(dc, fmt.Sprintf("updated channel %q", name))
+ return nil
+}
+func handleServiceServerStatus(ctx context.Context, dc *downstreamConn, params []string) error {
+ dbStats, err := dc.user.srv.db.Stats(ctx)
+ if err != nil {
+ return err
+ }
+ serverStats := dc.user.srv.Stats()
+ sendServicePRIVMSG(dc, fmt.Sprintf("%v/%v users, %v downstreams, %v upstreams, %v networks, %v channels", serverStats.Users, dbStats.Users, serverStats.Downstreams, serverStats.Upstreams, dbStats.Networks, dbStats.Channels))
+ return nil
+}
+
+func handleServiceServerNotice(ctx context.Context, dc *downstreamConn, params []string) error {
+ if len(params) != 1 {
+ return fmt.Errorf("expected exactly one argument")
+ }
+ text := params[0]
+
+ dc.logger.Printf("broadcasting bouncer-wide NOTICE: %v", text)
+
+ broadcastMsg := &irc.Message{
+ Prefix: servicePrefix,
+ Command: "NOTICE",
+ Params: []string{"$" + dc.srv.Config().Hostname, text},
+ }
+ var err error
+ sent := 0
+ total := 0
+ dc.srv.forEachUser(func(u *user) {
+ total++
+ select {
+ case <-ctx.Done():
+ err = ctx.Err()
+ case u.events <- eventBroadcast{broadcastMsg}:
+ sent++
+ }
+ })
+
+ dc.logger.Printf("broadcast bouncer-wide NOTICE to %v/%v downstreams", sent, total)
+ sendServicePRIVMSG(dc, fmt.Sprintf("sent to %v/%v downstream connections", sent, total))
+
+ return err
+}
--- /dev/null
+package suika
+
+import (
+ "testing"
+)
+
+func assertSplit(t *testing.T, input string, expected []string) {
+ actual, err := splitWords(input)
+ if err != nil {
+ t.Errorf("%q: %v", input, err)
+ return
+ }
+ if len(actual) != len(expected) {
+ t.Errorf("%q: expected %d words, got %d\nexpected: %v\ngot: %v", input, len(expected), len(actual), expected, actual)
+ return
+ }
+ for i := 0; i < len(actual); i++ {
+ if actual[i] != expected[i] {
+ t.Errorf("%q: expected word #%d to be %q, got %q\nexpected: %v\ngot: %v", input, i, expected[i], actual[i], expected, actual)
+ }
+ }
+}
+
+func TestSplit(t *testing.T) {
+ assertSplit(t, " ch 'up' #suika 'relay'-det\"ache\"d message ", []string{
+ "ch",
+ "up",
+ "#suika",
+ "relay-detached",
+ "message",
+ })
+ assertSplit(t, "net update \\\"free\\\"node -pass 'political \"stance\" desu!' -realname '' -nick lee", []string{
+ "net",
+ "update",
+ "\"free\"node",
+ "-pass",
+ "political \"stance\" desu!",
+ "-realname",
+ "",
+ "-nick",
+ "lee",
+ })
+ assertSplit(t, "Omedeto,\\ Yui! ''", []string{
+ "Omedeto, Yui!",
+ "",
+ })
+
+ if _, err := splitWords("end of 'file"); err == nil {
+ t.Errorf("expected error on unterminated single quote")
+ }
+ if _, err := splitWords("end of backquote \\"); err == nil {
+ t.Errorf("expected error on unterminated backquote sequence")
+ }
+}
--- /dev/null
+CREATE TABLE IF NOT EXISTS "User" (
+ id SERIAL PRIMARY KEY,
+ username VARCHAR(255) NOT NULL UNIQUE,
+ password VARCHAR(255),
+ admin BOOLEAN NOT NULL DEFAULT FALSE,
+ realname VARCHAR(255)
+);
+
+CREATE TYPE sasl_mechanism AS ENUM ('PLAIN', 'EXTERNAL');
+
+CREATE TABLE IF NOT EXISTS "Network" (
+ id SERIAL PRIMARY KEY,
+ name VARCHAR(255),
+ "user" INTEGER NOT NULL REFERENCES "User"(id) ON DELETE CASCADE,
+ addr VARCHAR(255) NOT NULL,
+ nick VARCHAR(255),
+ username VARCHAR(255),
+ realname VARCHAR(255),
+ pass VARCHAR(255),
+ connect_commands VARCHAR(1023),
+ sasl_mechanism sasl_mechanism,
+ sasl_plain_username VARCHAR(255),
+ sasl_plain_password VARCHAR(255),
+ sasl_external_cert BYTEA,
+ sasl_external_key BYTEA,
+ enabled BOOLEAN NOT NULL DEFAULT TRUE,
+ UNIQUE("user", addr, nick),
+ UNIQUE("user", name)
+);
+CREATE TABLE IF NOT EXISTS "Channel" (
+ id SERIAL PRIMARY KEY,
+ network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
+ name VARCHAR(255) NOT NULL,
+ key VARCHAR(255),
+ detached BOOLEAN NOT NULL DEFAULT FALSE,
+ detached_internal_msgid VARCHAR(255),
+ relay_detached INTEGER NOT NULL DEFAULT 0,
+ reattach_on INTEGER NOT NULL DEFAULT 0,
+ detach_after INTEGER NOT NULL DEFAULT 0,
+ detach_on INTEGER NOT NULL DEFAULT 0,
+ UNIQUE(network, name)
+);
+CREATE TABLE IF NOT EXISTS "DeliveryReceipt" (
+ id SERIAL PRIMARY KEY,
+ network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
+ target VARCHAR(255) NOT NULL,
+ client VARCHAR(255) NOT NULL DEFAULT '',
+ internal_msgid VARCHAR(255) NOT NULL,
+ UNIQUE(network, target, client)
+);
+CREATE TABLE IF NOT EXISTS "ReadReceipt" (
+ id SERIAL PRIMARY KEY,
+ network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
+ target VARCHAR(255) NOT NULL,
+ timestamp TIMESTAMP WITH TIME ZONE NOT NULL,
+ UNIQUE(network, target)
+);
+
--- /dev/null
+CREATE TABLE IF NOT EXISTS User (
+ id INTEGER PRIMARY KEY,
+ username TEXT NOT NULL UNIQUE,
+ password TEXT,
+ admin INTEGER NOT NULL DEFAULT 0,
+ realname TEXT
+);
+CREATE TABLE IF NOT EXISTS Network (
+ id INTEGER PRIMARY KEY,
+ name TEXT,
+ user INTEGER NOT NULL,
+ addr TEXT NOT NULL,
+ nick TEXT,
+ username TEXT,
+ realname TEXT,
+ pass TEXT,
+ connect_commands TEXT,
+ sasl_mechanism TEXT,
+ sasl_plain_username TEXT,
+ sasl_plain_password TEXT,
+ sasl_external_cert BLOB,
+ sasl_external_key BLOB,
+ enabled INTEGER NOT NULL DEFAULT 1,
+ FOREIGN KEY(user) REFERENCES User(id),
+ UNIQUE(user, addr, nick),
+ UNIQUE(user, name)
+);
+CREATE TABLE IF NOT EXISTS Channel (
+ id INTEGER PRIMARY KEY,
+ network INTEGER NOT NULL,
+ name TEXT NOT NULL,
+ key TEXT,
+ detached INTEGER NOT NULL DEFAULT 0,
+ detached_internal_msgid TEXT,
+ relay_detached INTEGER NOT NULL DEFAULT 0,
+ reattach_on INTEGER NOT NULL DEFAULT 0,
+ detach_after INTEGER NOT NULL DEFAULT 0,
+ detach_on INTEGER NOT NULL DEFAULT 0,
+ FOREIGN KEY(network) REFERENCES Network(id),
+ UNIQUE(network, name)
+);
+
+CREATE TABLE IF NOT EXISTS DeliveryReceipt (
+ id INTEGER PRIMARY KEY,
+ network INTEGER NOT NULL,
+ target TEXT NOT NULL,
+ client TEXT,
+ internal_msgid TEXT NOT NULL,
+ FOREIGN KEY(network) REFERENCES Network(id),
+ UNIQUE(network, target, client)
+);
+
+CREATE TABLE IF NOT EXISTS ReadReceipt (
+ id INTEGER PRIMARY KEY,
+ network INTEGER NOT NULL,
+ target TEXT NOT NULL,
+ timestamp TEXT NOT NULL,
+ FOREIGN KEY(network) REFERENCES Network(id),
+ UNIQUE(network, target)
+);
+
--- /dev/null
+package suika
+
+import (
+ "context"
+ "crypto"
+ "crypto/sha256"
+ "crypto/tls"
+ "crypto/x509"
+ "encoding/base64"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/emersion/go-sasl"
+ "gopkg.in/irc.v3"
+)
+
+// permanentUpstreamCaps is the static list of upstream capabilities always
+// requested when supported.
+var permanentUpstreamCaps = map[string]bool{
+ "account-notify": true,
+ "account-tag": true,
+ "away-notify": true,
+ "batch": true,
+ "extended-join": true,
+ "invite-notify": true,
+ "labeled-response": true,
+ "message-tags": true,
+ "multi-prefix": true,
+ "sasl": true,
+ "server-time": true,
+ "setname": true,
+
+ "draft/account-registration": true,
+ "draft/extended-monitor": true,
+}
+
+type registrationError struct {
+ *irc.Message
+}
+
+func (err registrationError) Error() string {
+ return fmt.Sprintf("registration error (%v): %v", err.Command, err.Reason())
+}
+
+func (err registrationError) Reason() string {
+ if len(err.Params) > 0 {
+ return err.Params[len(err.Params)-1]
+ }
+ return err.Command
+}
+
+func (err registrationError) Temporary() bool {
+ // Only return false if we're 100% sure that fixing the error requires a
+ // network configuration change
+ switch err.Command {
+ case irc.ERR_PASSWDMISMATCH, irc.ERR_ERRONEUSNICKNAME:
+ return false
+ case "FAIL":
+ return err.Params[1] != "ACCOUNT_REQUIRED"
+ default:
+ return true
+ }
+}
+
+type upstreamChannel struct {
+ Name string
+ conn *upstreamConn
+ Topic string
+ TopicWho *irc.Prefix
+ TopicTime time.Time
+ Status channelStatus
+ modes channelModes
+ creationTime string
+ Members membershipsCasemapMap
+ complete bool
+ detachTimer *time.Timer
+}
+
+func (uc *upstreamChannel) updateAutoDetach(dur time.Duration) {
+ if uc.detachTimer != nil {
+ uc.detachTimer.Stop()
+ uc.detachTimer = nil
+ }
+
+ if dur == 0 {
+ return
+ }
+
+ uc.detachTimer = time.AfterFunc(dur, func() {
+ uc.conn.network.user.events <- eventChannelDetach{
+ uc: uc.conn,
+ name: uc.Name,
+ }
+ })
+}
+
+type pendingUpstreamCommand struct {
+ downstreamID uint64
+ msg *irc.Message
+}
+
+type upstreamConn struct {
+ conn
+
+ network *network
+ user *user
+
+ serverName string
+ availableUserModes string
+ availableChannelModes map[byte]channelModeType
+ availableChannelTypes string
+ availableMemberships []membership
+ isupport map[string]*string
+
+ registered bool
+ nick string
+ nickCM string
+ username string
+ realname string
+ modes userModes
+ channels upstreamChannelCasemapMap
+ supportedCaps map[string]string
+ caps map[string]bool
+ batches map[string]batch
+ away bool
+ account string
+ nextLabelID uint64
+ monitored monitorCasemapMap
+
+ saslClient sasl.Client
+ saslStarted bool
+
+ casemapIsSet bool
+
+ // Queue of commands in progress, indexed by type. The first entry has been
+ // sent to the server and is awaiting reply. The following entries have not
+ // been sent yet.
+ pendingCmds map[string][]pendingUpstreamCommand
+
+ gotMotd bool
+}
+
+func connectToUpstream(ctx context.Context, network *network) (*upstreamConn, error) {
+ logger := &prefixLogger{network.user.logger, fmt.Sprintf("upstream %q: ", network.GetName())}
+
+ dialer := net.Dialer{Timeout: connectTimeout}
+
+ u, err := network.URL()
+ if err != nil {
+ return nil, err
+ }
+
+ var netConn net.Conn
+ switch u.Scheme {
+ case "ircs":
+ addr := u.Host
+ host, _, err := net.SplitHostPort(u.Host)
+ if err != nil {
+ host = u.Host
+ addr = u.Host + ":6697"
+ }
+
+ dialer.LocalAddr, err = network.user.localTCPAddrForHost(ctx, host)
+ if err != nil {
+ return nil, fmt.Errorf("failed to pick local IP for remote host %q: %v", host, err)
+ }
+
+ logger.Printf("connecting to TLS server at address %q", addr)
+
+ tlsConfig := &tls.Config{ServerName: host, NextProtos: []string{"irc"}}
+ if network.SASL.Mechanism == "EXTERNAL" {
+ if network.SASL.External.CertBlob == nil {
+ return nil, fmt.Errorf("missing certificate for authentication")
+ }
+ if network.SASL.External.PrivKeyBlob == nil {
+ return nil, fmt.Errorf("missing private key for authentication")
+ }
+ key, err := x509.ParsePKCS8PrivateKey(network.SASL.External.PrivKeyBlob)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse private key: %v", err)
+ }
+ tlsConfig.Certificates = []tls.Certificate{
+ {
+ Certificate: [][]byte{network.SASL.External.CertBlob},
+ PrivateKey: key.(crypto.PrivateKey),
+ },
+ }
+ logger.Printf("using TLS client certificate %x", sha256.Sum256(network.SASL.External.CertBlob))
+ }
+
+ netConn, err = dialer.DialContext(ctx, "tcp", addr)
+ if err != nil {
+ return nil, fmt.Errorf("failed to dial %q: %v", addr, err)
+ }
+
+ // Don't do the TLS handshake immediately, because we need to register
+ // the new connection with identd ASAP.
+ netConn = tls.Client(netConn, tlsConfig)
+ case "irc":
+ addr := u.Host
+ host, _, err := net.SplitHostPort(addr)
+ if err != nil {
+ host = u.Host
+ addr = u.Host + ":6667"
+ }
+
+ dialer.LocalAddr, err = network.user.localTCPAddrForHost(ctx, host)
+ if err != nil {
+ return nil, fmt.Errorf("failed to pick local IP for remote host %q: %v", host, err)
+ }
+
+ logger.Printf("connecting to plain-text server at address %q", addr)
+ netConn, err = dialer.DialContext(ctx, "tcp", addr)
+ if err != nil {
+ return nil, fmt.Errorf("failed to dial %q: %v", addr, err)
+ }
+ case "irc+unix", "unix":
+ logger.Printf("connecting to Unix socket at path %q", u.Path)
+ netConn, err = dialer.DialContext(ctx, "unix", u.Path)
+ if err != nil {
+ return nil, fmt.Errorf("failed to connect to Unix socket %q: %v", u.Path, err)
+ }
+ default:
+ return nil, fmt.Errorf("failed to dial %q: unknown scheme: %v", network.Addr, u.Scheme)
+ }
+
+ options := connOptions{
+ Logger: logger,
+ RateLimitDelay: upstreamMessageDelay,
+ RateLimitBurst: upstreamMessageBurst,
+ }
+
+ uc := &upstreamConn{
+ conn: *newConn(network.user.srv, newNetIRCConn(netConn), &options),
+ network: network,
+ user: network.user,
+ channels: upstreamChannelCasemapMap{newCasemapMap(0)},
+ supportedCaps: make(map[string]string),
+ caps: make(map[string]bool),
+ batches: make(map[string]batch),
+ availableChannelTypes: stdChannelTypes,
+ availableChannelModes: stdChannelModes,
+ availableMemberships: stdMemberships,
+ isupport: make(map[string]*string),
+ pendingCmds: make(map[string][]pendingUpstreamCommand),
+ monitored: monitorCasemapMap{newCasemapMap(0)},
+ }
+ return uc, nil
+}
+
+func (uc *upstreamConn) forEachDownstream(f func(*downstreamConn)) {
+ uc.network.forEachDownstream(f)
+}
+
+func (uc *upstreamConn) forEachDownstreamByID(id uint64, f func(*downstreamConn)) {
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ if id != 0 && id != dc.id {
+ return
+ }
+ f(dc)
+ })
+}
+
+func (uc *upstreamConn) downstreamByID(id uint64) *downstreamConn {
+ for _, dc := range uc.user.downstreamConns {
+ if dc.id == id {
+ return dc
+ }
+ }
+ return nil
+}
+
+func (uc *upstreamConn) getChannel(name string) (*upstreamChannel, error) {
+ ch := uc.channels.Value(name)
+ if ch == nil {
+ return nil, fmt.Errorf("unknown channel %q", name)
+ }
+ return ch, nil
+}
+
+func (uc *upstreamConn) isChannel(entity string) bool {
+ return strings.ContainsRune(uc.availableChannelTypes, rune(entity[0]))
+}
+
+func (uc *upstreamConn) isOurNick(nick string) bool {
+ return uc.nickCM == uc.network.casemap(nick)
+}
+
+func (uc *upstreamConn) abortPendingCommands() {
+ for _, l := range uc.pendingCmds {
+ for _, pendingCmd := range l {
+ dc := uc.downstreamByID(pendingCmd.downstreamID)
+ if dc == nil {
+ continue
+ }
+
+ switch pendingCmd.msg.Command {
+ case "LIST":
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_LISTEND,
+ Params: []string{dc.nick, "Command aborted"},
+ })
+ case "WHO":
+ mask := "*"
+ if len(pendingCmd.msg.Params) > 0 {
+ mask = pendingCmd.msg.Params[0]
+ }
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_ENDOFWHO,
+ Params: []string{dc.nick, mask, "Command aborted"},
+ })
+ case "AUTHENTICATE":
+ dc.endSASL(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_SASLABORTED,
+ Params: []string{dc.nick, "SASL authentication aborted"},
+ })
+ case "REGISTER", "VERIFY":
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "FAIL",
+ Params: []string{pendingCmd.msg.Command, "TEMPORARILY_UNAVAILABLE", pendingCmd.msg.Params[0], "Command aborted"},
+ })
+ default:
+ panic(fmt.Errorf("Unsupported pending command %q", pendingCmd.msg.Command))
+ }
+ }
+ }
+
+ uc.pendingCmds = make(map[string][]pendingUpstreamCommand)
+}
+
+func (uc *upstreamConn) sendNextPendingCommand(cmd string) {
+ if len(uc.pendingCmds[cmd]) == 0 {
+ return
+ }
+ uc.SendMessage(context.TODO(), uc.pendingCmds[cmd][0].msg)
+}
+
+func (uc *upstreamConn) enqueueCommand(dc *downstreamConn, msg *irc.Message) {
+ switch msg.Command {
+ case "LIST", "WHO", "AUTHENTICATE", "REGISTER", "VERIFY":
+ // Supported
+ default:
+ panic(fmt.Errorf("Unsupported pending command %q", msg.Command))
+ }
+
+ uc.pendingCmds[msg.Command] = append(uc.pendingCmds[msg.Command], pendingUpstreamCommand{
+ downstreamID: dc.id,
+ msg: msg,
+ })
+
+ if len(uc.pendingCmds[msg.Command]) == 1 {
+ uc.sendNextPendingCommand(msg.Command)
+ }
+}
+
+func (uc *upstreamConn) currentPendingCommand(cmd string) (*downstreamConn, *irc.Message) {
+ if len(uc.pendingCmds[cmd]) == 0 {
+ return nil, nil
+ }
+
+ pendingCmd := uc.pendingCmds[cmd][0]
+ return uc.downstreamByID(pendingCmd.downstreamID), pendingCmd.msg
+}
+
+func (uc *upstreamConn) dequeueCommand(cmd string) (*downstreamConn, *irc.Message) {
+ dc, msg := uc.currentPendingCommand(cmd)
+
+ if len(uc.pendingCmds[cmd]) > 0 {
+ copy(uc.pendingCmds[cmd], uc.pendingCmds[cmd][1:])
+ uc.pendingCmds[cmd] = uc.pendingCmds[cmd][:len(uc.pendingCmds[cmd])-1]
+ }
+
+ uc.sendNextPendingCommand(cmd)
+
+ return dc, msg
+}
+
+func (uc *upstreamConn) cancelPendingCommandsByDownstreamID(downstreamID uint64) {
+ for cmd := range uc.pendingCmds {
+ // We can't cancel the currently running command stored in
+ // uc.pendingCmds[cmd][0]
+ for i := len(uc.pendingCmds[cmd]) - 1; i >= 1; i-- {
+ if uc.pendingCmds[cmd][i].downstreamID == downstreamID {
+ uc.pendingCmds[cmd] = append(uc.pendingCmds[cmd][:i], uc.pendingCmds[cmd][i+1:]...)
+ }
+ }
+ }
+}
+
+func (uc *upstreamConn) parseMembershipPrefix(s string) (ms *memberships, nick string) {
+ memberships := make(memberships, 0, 4)
+ i := 0
+ for _, m := range uc.availableMemberships {
+ if i >= len(s) {
+ break
+ }
+ if s[i] == m.Prefix {
+ memberships = append(memberships, m)
+ i++
+ }
+ }
+ return &memberships, s[i:]
+}
+
+func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) error {
+ var label string
+ if l, ok := msg.GetTag("label"); ok {
+ label = l
+ delete(msg.Tags, "label")
+ }
+
+ var msgBatch *batch
+ if batchName, ok := msg.GetTag("batch"); ok {
+ b, ok := uc.batches[batchName]
+ if !ok {
+ return fmt.Errorf("unexpected batch reference: batch was not defined: %q", batchName)
+ }
+ msgBatch = &b
+ if label == "" {
+ label = msgBatch.Label
+ }
+ delete(msg.Tags, "batch")
+ }
+
+ var downstreamID uint64 = 0
+ if label != "" {
+ var labelOffset uint64
+ n, err := fmt.Sscanf(label, "sd-%d-%d", &downstreamID, &labelOffset)
+ if err == nil && n < 2 {
+ err = errors.New("not enough arguments")
+ }
+ if err != nil {
+ return fmt.Errorf("unexpected message label: invalid downstream reference for label %q: %v", label, err)
+ }
+ }
+
+ if _, ok := msg.Tags["time"]; !ok {
+ msg.Tags["time"] = irc.TagValue(formatServerTime(time.Now()))
+ }
+
+ switch msg.Command {
+ case "PING":
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "PONG",
+ Params: msg.Params,
+ })
+ return nil
+ case "NOTICE", "PRIVMSG", "TAGMSG":
+ if msg.Prefix == nil {
+ return fmt.Errorf("expected a prefix")
+ }
+
+ var entity, text string
+ if msg.Command != "TAGMSG" {
+ if err := parseMessageParams(msg, &entity, &text); err != nil {
+ return err
+ }
+ } else {
+ if err := parseMessageParams(msg, &entity); err != nil {
+ return err
+ }
+ }
+
+ if msg.Prefix.Name == serviceNick {
+ uc.logger.Printf("skipping %v from suika's service: %v", msg.Command, msg)
+ break
+ }
+ if entity == serviceNick {
+ uc.logger.Printf("skipping %v to suika's service: %v", msg.Command, msg)
+ break
+ }
+
+ if msg.Prefix.User == "" && msg.Prefix.Host == "" { // server message
+ uc.produce("", msg, nil)
+ } else { // regular user message
+ target := entity
+ if uc.isOurNick(target) {
+ target = msg.Prefix.Name
+ }
+
+ ch := uc.network.channels.Value(target)
+ if ch != nil && msg.Command != "TAGMSG" {
+ if ch.Detached {
+ uc.handleDetachedMessage(ctx, ch, msg)
+ }
+
+ highlight := uc.network.isHighlight(msg)
+ if ch.DetachOn == FilterMessage || ch.DetachOn == FilterDefault || (ch.DetachOn == FilterHighlight && highlight) {
+ uc.updateChannelAutoDetach(target)
+ }
+ }
+
+ uc.produce(target, msg, nil)
+ }
+ case "CAP":
+ var subCmd string
+ if err := parseMessageParams(msg, nil, &subCmd); err != nil {
+ return err
+ }
+ subCmd = strings.ToUpper(subCmd)
+ subParams := msg.Params[2:]
+ switch subCmd {
+ case "LS":
+ if len(subParams) < 1 {
+ return newNeedMoreParamsError(msg.Command)
+ }
+ caps := subParams[len(subParams)-1]
+ more := len(subParams) >= 2 && msg.Params[len(subParams)-2] == "*"
+
+ uc.handleSupportedCaps(caps)
+
+ if more {
+ break // wait to receive all capabilities
+ }
+
+ uc.requestCaps()
+
+ if uc.requestSASL() {
+ break // we'll send CAP END after authentication is completed
+ }
+
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "CAP",
+ Params: []string{"END"},
+ })
+ case "ACK", "NAK":
+ if len(subParams) < 1 {
+ return newNeedMoreParamsError(msg.Command)
+ }
+ caps := strings.Fields(subParams[0])
+
+ for _, name := range caps {
+ if err := uc.handleCapAck(ctx, strings.ToLower(name), subCmd == "ACK"); err != nil {
+ return err
+ }
+ }
+
+ if uc.registered {
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.updateSupportedCaps()
+ })
+ }
+ case "NEW":
+ if len(subParams) < 1 {
+ return newNeedMoreParamsError(msg.Command)
+ }
+ uc.handleSupportedCaps(subParams[0])
+ uc.requestCaps()
+ case "DEL":
+ if len(subParams) < 1 {
+ return newNeedMoreParamsError(msg.Command)
+ }
+ caps := strings.Fields(subParams[0])
+
+ for _, c := range caps {
+ delete(uc.supportedCaps, c)
+ delete(uc.caps, c)
+ }
+
+ if uc.registered {
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.updateSupportedCaps()
+ })
+ }
+ default:
+ uc.logger.Printf("unhandled message: %v", msg)
+ }
+ case "AUTHENTICATE":
+ if uc.saslClient == nil {
+ return fmt.Errorf("received unexpected AUTHENTICATE message")
+ }
+
+ // TODO: if a challenge is 400 bytes long, buffer it
+ var challengeStr string
+ if err := parseMessageParams(msg, &challengeStr); err != nil {
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "AUTHENTICATE",
+ Params: []string{"*"},
+ })
+ return err
+ }
+
+ var challenge []byte
+ if challengeStr != "+" {
+ var err error
+ challenge, err = base64.StdEncoding.DecodeString(challengeStr)
+ if err != nil {
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "AUTHENTICATE",
+ Params: []string{"*"},
+ })
+ return err
+ }
+ }
+
+ var resp []byte
+ var err error
+ if !uc.saslStarted {
+ _, resp, err = uc.saslClient.Start()
+ uc.saslStarted = true
+ } else {
+ resp, err = uc.saslClient.Next(challenge)
+ }
+ if err != nil {
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "AUTHENTICATE",
+ Params: []string{"*"},
+ })
+ return err
+ }
+
+ // <= instead of < because we need to send a final empty response if
+ // the last chunk is exactly 400 bytes long
+ for i := 0; i <= len(resp); i += maxSASLLength {
+ j := i + maxSASLLength
+ if j > len(resp) {
+ j = len(resp)
+ }
+
+ chunk := resp[i:j]
+
+ var respStr = "+"
+ if len(chunk) != 0 {
+ respStr = base64.StdEncoding.EncodeToString(chunk)
+ }
+
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "AUTHENTICATE",
+ Params: []string{respStr},
+ })
+ }
+ case irc.RPL_LOGGEDIN:
+ if err := parseMessageParams(msg, nil, nil, &uc.account); err != nil {
+ return err
+ }
+ uc.logger.Printf("logged in with account %q", uc.account)
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.updateAccount()
+ })
+ case irc.RPL_LOGGEDOUT:
+ uc.account = ""
+ uc.logger.Printf("logged out")
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.updateAccount()
+ })
+ case irc.ERR_NICKLOCKED, irc.RPL_SASLSUCCESS, irc.ERR_SASLFAIL, irc.ERR_SASLTOOLONG, irc.ERR_SASLABORTED:
+ var info string
+ if err := parseMessageParams(msg, nil, &info); err != nil {
+ return err
+ }
+ switch msg.Command {
+ case irc.ERR_NICKLOCKED:
+ uc.logger.Printf("invalid nick used with SASL authentication: %v", info)
+ case irc.ERR_SASLFAIL:
+ uc.logger.Printf("SASL authentication failed: %v", info)
+ case irc.ERR_SASLTOOLONG:
+ uc.logger.Printf("SASL message too long: %v", info)
+ }
+
+ uc.saslClient = nil
+ uc.saslStarted = false
+
+ if dc, _ := uc.dequeueCommand("AUTHENTICATE"); dc != nil && dc.sasl != nil {
+ if msg.Command == irc.RPL_SASLSUCCESS {
+ uc.network.autoSaveSASLPlain(ctx, dc.sasl.plainUsername, dc.sasl.plainPassword)
+ }
+
+ dc.endSASL(msg)
+ }
+
+ if !uc.registered {
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "CAP",
+ Params: []string{"END"},
+ })
+ }
+ case "REGISTER", "VERIFY":
+ if dc, cmd := uc.dequeueCommand(msg.Command); dc != nil {
+ if msg.Command == "REGISTER" {
+ var account, password string
+ if err := parseMessageParams(msg, nil, &account); err != nil {
+ return err
+ }
+ if err := parseMessageParams(cmd, nil, nil, &password); err != nil {
+ return err
+ }
+ uc.network.autoSaveSASLPlain(ctx, account, password)
+ }
+
+ dc.SendMessage(msg)
+ }
+ case irc.RPL_WELCOME:
+ if err := parseMessageParams(msg, &uc.nick); err != nil {
+ return err
+ }
+
+ uc.registered = true
+ uc.nickCM = uc.network.casemap(uc.nick)
+ uc.logger.Printf("connection registered with nick %q", uc.nick)
+
+ if uc.network.channels.Len() > 0 {
+ var channels, keys []string
+ for _, entry := range uc.network.channels.innerMap {
+ ch := entry.value.(*Channel)
+ channels = append(channels, ch.Name)
+ keys = append(keys, ch.Key)
+ }
+
+ for _, msg := range join(channels, keys) {
+ uc.SendMessage(ctx, msg)
+ }
+ }
+ case irc.RPL_MYINFO:
+ if err := parseMessageParams(msg, nil, &uc.serverName, nil, &uc.availableUserModes, nil); err != nil {
+ return err
+ }
+ case irc.RPL_ISUPPORT:
+ if err := parseMessageParams(msg, nil, nil); err != nil {
+ return err
+ }
+
+ var downstreamIsupport []string
+ for _, token := range msg.Params[1 : len(msg.Params)-1] {
+ parameter := token
+ var negate, hasValue bool
+ var value string
+ if strings.HasPrefix(token, "-") {
+ negate = true
+ token = token[1:]
+ } else if i := strings.IndexByte(token, '='); i >= 0 {
+ parameter = token[:i]
+ value = token[i+1:]
+ hasValue = true
+ }
+
+ if hasValue {
+ uc.isupport[parameter] = &value
+ } else if !negate {
+ uc.isupport[parameter] = nil
+ } else {
+ delete(uc.isupport, parameter)
+ }
+
+ var err error
+ switch parameter {
+ case "CASEMAPPING":
+ casemap, ok := parseCasemappingToken(value)
+ if !ok {
+ casemap = casemapRFC1459
+ }
+ uc.network.updateCasemapping(casemap)
+ uc.nickCM = uc.network.casemap(uc.nick)
+ uc.casemapIsSet = true
+ case "CHANMODES":
+ if !negate {
+ err = uc.handleChanModes(value)
+ } else {
+ uc.availableChannelModes = stdChannelModes
+ }
+ case "CHANTYPES":
+ if !negate {
+ uc.availableChannelTypes = value
+ } else {
+ uc.availableChannelTypes = stdChannelTypes
+ }
+ case "PREFIX":
+ if !negate {
+ err = uc.handleMemberships(value)
+ } else {
+ uc.availableMemberships = stdMemberships
+ }
+ }
+ if err != nil {
+ return err
+ }
+
+ if passthroughIsupport[parameter] {
+ downstreamIsupport = append(downstreamIsupport, token)
+ }
+ }
+
+ uc.updateMonitor()
+
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ if dc.network == nil {
+ return
+ }
+ msgs := generateIsupport(dc.srv.prefix(), dc.nick, downstreamIsupport)
+ for _, msg := range msgs {
+ dc.SendMessage(msg)
+ }
+ })
+ case irc.ERR_NOMOTD, irc.RPL_ENDOFMOTD:
+ if !uc.casemapIsSet {
+ // upstream did not send any CASEMAPPING token, thus
+ // we assume it implements the old RFCs with rfc1459.
+ uc.casemapIsSet = true
+ uc.network.updateCasemapping(casemapRFC1459)
+ uc.nickCM = uc.network.casemap(uc.nick)
+ }
+
+ if !uc.gotMotd {
+ // Ignore the initial MOTD upon connection, but forward
+ // subsequent MOTD messages downstream
+ uc.gotMotd = true
+ return nil
+ }
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ dc.SendMessage(&irc.Message{
+ Prefix: uc.srv.prefix(),
+ Command: msg.Command,
+ Params: msg.Params,
+ })
+ })
+ case "BATCH":
+ var tag string
+ if err := parseMessageParams(msg, &tag); err != nil {
+ return err
+ }
+
+ if strings.HasPrefix(tag, "+") {
+ tag = tag[1:]
+ if _, ok := uc.batches[tag]; ok {
+ return fmt.Errorf("unexpected BATCH reference tag: batch was already defined: %q", tag)
+ }
+ var batchType string
+ if err := parseMessageParams(msg, nil, &batchType); err != nil {
+ return err
+ }
+ label := label
+ if label == "" && msgBatch != nil {
+ label = msgBatch.Label
+ }
+ uc.batches[tag] = batch{
+ Type: batchType,
+ Params: msg.Params[2:],
+ Outer: msgBatch,
+ Label: label,
+ }
+ } else if strings.HasPrefix(tag, "-") {
+ tag = tag[1:]
+ if _, ok := uc.batches[tag]; !ok {
+ return fmt.Errorf("unknown BATCH reference tag: %q", tag)
+ }
+ delete(uc.batches, tag)
+ } else {
+ return fmt.Errorf("unexpected BATCH reference tag: missing +/- prefix: %q", tag)
+ }
+ case "NICK":
+ if msg.Prefix == nil {
+ return fmt.Errorf("expected a prefix")
+ }
+
+ var newNick string
+ if err := parseMessageParams(msg, &newNick); err != nil {
+ return err
+ }
+
+ me := false
+ if uc.isOurNick(msg.Prefix.Name) {
+ uc.logger.Printf("changed nick from %q to %q", uc.nick, newNick)
+ me = true
+ uc.nick = newNick
+ uc.nickCM = uc.network.casemap(uc.nick)
+ }
+
+ for _, entry := range uc.channels.innerMap {
+ ch := entry.value.(*upstreamChannel)
+ memberships := ch.Members.Value(msg.Prefix.Name)
+ if memberships != nil {
+ ch.Members.Delete(msg.Prefix.Name)
+ ch.Members.SetValue(newNick, memberships)
+ uc.appendLog(ch.Name, msg)
+ }
+ }
+
+ if !me {
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.SendMessage(dc.marshalMessage(msg, uc.network))
+ })
+ } else {
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.updateNick()
+ })
+ uc.updateMonitor()
+ }
+ case "SETNAME":
+ if msg.Prefix == nil {
+ return fmt.Errorf("expected a prefix")
+ }
+
+ var newRealname string
+ if err := parseMessageParams(msg, &newRealname); err != nil {
+ return err
+ }
+
+ // TODO: consider appending this message to logs
+
+ if uc.isOurNick(msg.Prefix.Name) {
+ uc.logger.Printf("changed realname from %q to %q", uc.realname, newRealname)
+ uc.realname = newRealname
+
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.updateRealname()
+ })
+ } else {
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.SendMessage(dc.marshalMessage(msg, uc.network))
+ })
+ }
+ case "JOIN":
+ if msg.Prefix == nil {
+ return fmt.Errorf("expected a prefix")
+ }
+
+ var channels string
+ if err := parseMessageParams(msg, &channels); err != nil {
+ return err
+ }
+
+ for _, ch := range strings.Split(channels, ",") {
+ if uc.isOurNick(msg.Prefix.Name) {
+ uc.logger.Printf("joined channel %q", ch)
+ members := membershipsCasemapMap{newCasemapMap(0)}
+ members.casemap = uc.network.casemap
+ uc.channels.SetValue(ch, &upstreamChannel{
+ Name: ch,
+ conn: uc,
+ Members: members,
+ })
+ uc.updateChannelAutoDetach(ch)
+
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "MODE",
+ Params: []string{ch},
+ })
+ } else {
+ ch, err := uc.getChannel(ch)
+ if err != nil {
+ return err
+ }
+ ch.Members.SetValue(msg.Prefix.Name, &memberships{})
+ }
+
+ chMsg := msg.Copy()
+ chMsg.Params[0] = ch
+ uc.produce(ch, chMsg, nil)
+ }
+ case "PART":
+ if msg.Prefix == nil {
+ return fmt.Errorf("expected a prefix")
+ }
+
+ var channels string
+ if err := parseMessageParams(msg, &channels); err != nil {
+ return err
+ }
+
+ for _, ch := range strings.Split(channels, ",") {
+ if uc.isOurNick(msg.Prefix.Name) {
+ uc.logger.Printf("parted channel %q", ch)
+ uch := uc.channels.Value(ch)
+ if uch != nil {
+ uc.channels.Delete(ch)
+ uch.updateAutoDetach(0)
+ }
+ } else {
+ ch, err := uc.getChannel(ch)
+ if err != nil {
+ return err
+ }
+ ch.Members.Delete(msg.Prefix.Name)
+ }
+
+ chMsg := msg.Copy()
+ chMsg.Params[0] = ch
+ uc.produce(ch, chMsg, nil)
+ }
+ case "KICK":
+ if msg.Prefix == nil {
+ return fmt.Errorf("expected a prefix")
+ }
+
+ var channel, user string
+ if err := parseMessageParams(msg, &channel, &user); err != nil {
+ return err
+ }
+
+ if uc.isOurNick(user) {
+ uc.logger.Printf("kicked from channel %q by %s", channel, msg.Prefix.Name)
+ uc.channels.Delete(channel)
+ } else {
+ ch, err := uc.getChannel(channel)
+ if err != nil {
+ return err
+ }
+ ch.Members.Delete(user)
+ }
+
+ uc.produce(channel, msg, nil)
+ case "QUIT":
+ if msg.Prefix == nil {
+ return fmt.Errorf("expected a prefix")
+ }
+
+ if uc.isOurNick(msg.Prefix.Name) {
+ uc.logger.Printf("quit")
+ }
+
+ for _, entry := range uc.channels.innerMap {
+ ch := entry.value.(*upstreamChannel)
+ if ch.Members.Has(msg.Prefix.Name) {
+ ch.Members.Delete(msg.Prefix.Name)
+
+ uc.appendLog(ch.Name, msg)
+ }
+ }
+
+ if msg.Prefix.Name != uc.nick {
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.SendMessage(dc.marshalMessage(msg, uc.network))
+ })
+ }
+ case irc.RPL_TOPIC, irc.RPL_NOTOPIC:
+ var name, topic string
+ if err := parseMessageParams(msg, nil, &name, &topic); err != nil {
+ return err
+ }
+ ch, err := uc.getChannel(name)
+ if err != nil {
+ return err
+ }
+ if msg.Command == irc.RPL_TOPIC {
+ ch.Topic = topic
+ } else {
+ ch.Topic = ""
+ }
+ case "TOPIC":
+ if msg.Prefix == nil {
+ return fmt.Errorf("expected a prefix")
+ }
+
+ var name string
+ if err := parseMessageParams(msg, &name); err != nil {
+ return err
+ }
+ ch, err := uc.getChannel(name)
+ if err != nil {
+ return err
+ }
+ if len(msg.Params) > 1 {
+ ch.Topic = msg.Params[1]
+ ch.TopicWho = msg.Prefix.Copy()
+ ch.TopicTime = time.Now() // TODO use msg.Tags["time"]
+ } else {
+ ch.Topic = ""
+ }
+ uc.produce(ch.Name, msg, nil)
+ case "MODE":
+ var name, modeStr string
+ if err := parseMessageParams(msg, &name, &modeStr); err != nil {
+ return err
+ }
+
+ if !uc.isChannel(name) { // user mode change
+ if name != uc.nick {
+ return fmt.Errorf("received MODE message for unknown nick %q", name)
+ }
+
+ if err := uc.modes.Apply(modeStr); err != nil {
+ return err
+ }
+
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ if dc.upstream() == nil {
+ return
+ }
+
+ dc.SendMessage(msg)
+ })
+ } else { // channel mode change
+ ch, err := uc.getChannel(name)
+ if err != nil {
+ return err
+ }
+
+ needMarshaling, err := applyChannelModes(ch, modeStr, msg.Params[2:])
+ if err != nil {
+ return err
+ }
+
+ uc.appendLog(ch.Name, msg)
+
+ c := uc.network.channels.Value(name)
+ if c == nil || !c.Detached {
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ params := make([]string, len(msg.Params))
+ params[0] = dc.marshalEntity(uc.network, name)
+ params[1] = modeStr
+
+ copy(params[2:], msg.Params[2:])
+ for i, modeParam := range params[2:] {
+ if _, ok := needMarshaling[i]; ok {
+ params[2+i] = dc.marshalEntity(uc.network, modeParam)
+ }
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.marshalUserPrefix(uc.network, msg.Prefix),
+ Command: "MODE",
+ Params: params,
+ })
+ })
+ }
+ }
+ case irc.RPL_UMODEIS:
+ if err := parseMessageParams(msg, nil); err != nil {
+ return err
+ }
+ modeStr := ""
+ if len(msg.Params) > 1 {
+ modeStr = msg.Params[1]
+ }
+
+ uc.modes = ""
+ if err := uc.modes.Apply(modeStr); err != nil {
+ return err
+ }
+
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ if dc.upstream() == nil {
+ return
+ }
+
+ dc.SendMessage(msg)
+ })
+ case irc.RPL_CHANNELMODEIS:
+ var channel string
+ if err := parseMessageParams(msg, nil, &channel); err != nil {
+ return err
+ }
+ modeStr := ""
+ if len(msg.Params) > 2 {
+ modeStr = msg.Params[2]
+ }
+
+ ch, err := uc.getChannel(channel)
+ if err != nil {
+ return err
+ }
+
+ firstMode := ch.modes == nil
+ ch.modes = make(map[byte]string)
+ if _, err := applyChannelModes(ch, modeStr, msg.Params[3:]); err != nil {
+ return err
+ }
+
+ c := uc.network.channels.Value(channel)
+ if firstMode && (c == nil || !c.Detached) {
+ modeStr, modeParams := ch.modes.Format()
+
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ params := []string{dc.nick, dc.marshalEntity(uc.network, channel), modeStr}
+ params = append(params, modeParams...)
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_CHANNELMODEIS,
+ Params: params,
+ })
+ })
+ }
+ case rpl_creationtime:
+ var channel, creationTime string
+ if err := parseMessageParams(msg, nil, &channel, &creationTime); err != nil {
+ return err
+ }
+
+ ch, err := uc.getChannel(channel)
+ if err != nil {
+ return err
+ }
+
+ firstCreationTime := ch.creationTime == ""
+ ch.creationTime = creationTime
+
+ c := uc.network.channels.Value(channel)
+ if firstCreationTime && (c == nil || !c.Detached) {
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: rpl_creationtime,
+ Params: []string{dc.nick, dc.marshalEntity(uc.network, ch.Name), creationTime},
+ })
+ })
+ }
+ case rpl_topicwhotime:
+ var channel, who, timeStr string
+ if err := parseMessageParams(msg, nil, &channel, &who, &timeStr); err != nil {
+ return err
+ }
+
+ ch, err := uc.getChannel(channel)
+ if err != nil {
+ return err
+ }
+
+ firstTopicWhoTime := ch.TopicWho == nil
+ ch.TopicWho = irc.ParsePrefix(who)
+ sec, err := strconv.ParseInt(timeStr, 10, 64)
+ if err != nil {
+ return fmt.Errorf("failed to parse topic time: %v", err)
+ }
+ ch.TopicTime = time.Unix(sec, 0)
+
+ c := uc.network.channels.Value(channel)
+ if firstTopicWhoTime && (c == nil || !c.Detached) {
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ topicWho := dc.marshalUserPrefix(uc.network, ch.TopicWho)
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: rpl_topicwhotime,
+ Params: []string{
+ dc.nick,
+ dc.marshalEntity(uc.network, ch.Name),
+ topicWho.String(),
+ timeStr,
+ },
+ })
+ })
+ }
+ case irc.RPL_LIST:
+ var channel, clients, topic string
+ if err := parseMessageParams(msg, nil, &channel, &clients, &topic); err != nil {
+ return err
+ }
+
+ dc, cmd := uc.currentPendingCommand("LIST")
+ if cmd == nil {
+ return fmt.Errorf("unexpected RPL_LIST: no matching pending LIST")
+ } else if dc == nil {
+ return nil
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_LIST,
+ Params: []string{dc.nick, dc.marshalEntity(uc.network, channel), clients, topic},
+ })
+ case irc.RPL_LISTEND:
+ dc, cmd := uc.dequeueCommand("LIST")
+ if cmd == nil {
+ return fmt.Errorf("unexpected RPL_LISTEND: no matching pending LIST")
+ } else if dc == nil {
+ return nil
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_LISTEND,
+ Params: []string{dc.nick, "End of /LIST"},
+ })
+ case irc.RPL_NAMREPLY:
+ var name, statusStr, members string
+ if err := parseMessageParams(msg, nil, &statusStr, &name, &members); err != nil {
+ return err
+ }
+
+ ch := uc.channels.Value(name)
+ if ch == nil {
+ // NAMES on a channel we have not joined, forward to downstream
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ channel := dc.marshalEntity(uc.network, name)
+ members := splitSpace(members)
+ for i, member := range members {
+ memberships, nick := uc.parseMembershipPrefix(member)
+ members[i] = memberships.Format(dc) + dc.marshalEntity(uc.network, nick)
+ }
+ memberStr := strings.Join(members, " ")
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_NAMREPLY,
+ Params: []string{dc.nick, statusStr, channel, memberStr},
+ })
+ })
+ return nil
+ }
+
+ status, err := parseChannelStatus(statusStr)
+ if err != nil {
+ return err
+ }
+ ch.Status = status
+
+ for _, s := range splitSpace(members) {
+ memberships, nick := uc.parseMembershipPrefix(s)
+ ch.Members.SetValue(nick, memberships)
+ }
+ case irc.RPL_ENDOFNAMES:
+ var name string
+ if err := parseMessageParams(msg, nil, &name); err != nil {
+ return err
+ }
+
+ ch := uc.channels.Value(name)
+ if ch == nil {
+ // NAMES on a channel we have not joined, forward to downstream
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ channel := dc.marshalEntity(uc.network, name)
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_ENDOFNAMES,
+ Params: []string{dc.nick, channel, "End of /NAMES list"},
+ })
+ })
+ return nil
+ }
+
+ if ch.complete {
+ return fmt.Errorf("received unexpected RPL_ENDOFNAMES")
+ }
+ ch.complete = true
+
+ c := uc.network.channels.Value(name)
+ if c == nil || !c.Detached {
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ forwardChannel(ctx, dc, ch)
+ })
+ }
+ case irc.RPL_WHOREPLY:
+ var channel, username, host, server, nick, flags, trailing string
+ if err := parseMessageParams(msg, nil, &channel, &username, &host, &server, &nick, &flags, &trailing); err != nil {
+ return err
+ }
+
+ dc, cmd := uc.currentPendingCommand("WHO")
+ if cmd == nil {
+ return fmt.Errorf("unexpected RPL_WHOREPLY: no matching pending WHO")
+ } else if dc == nil {
+ return nil
+ }
+
+ if channel != "*" {
+ channel = dc.marshalEntity(uc.network, channel)
+ }
+ nick = dc.marshalEntity(uc.network, nick)
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_WHOREPLY,
+ Params: []string{dc.nick, channel, username, host, server, nick, flags, trailing},
+ })
+ case rpl_whospcrpl:
+ dc, cmd := uc.currentPendingCommand("WHO")
+ if cmd == nil {
+ return fmt.Errorf("unexpected RPL_WHOSPCRPL: no matching pending WHO")
+ } else if dc == nil {
+ return nil
+ }
+
+ // Only supported in single-upstream mode, so forward as-is
+ dc.SendMessage(msg)
+ case irc.RPL_ENDOFWHO:
+ var name string
+ if err := parseMessageParams(msg, nil, &name); err != nil {
+ return err
+ }
+
+ dc, cmd := uc.dequeueCommand("WHO")
+ if cmd == nil {
+ return fmt.Errorf("unexpected RPL_ENDOFWHO: no matching pending WHO")
+ } else if dc == nil {
+ return nil
+ }
+
+ mask := "*"
+ if len(cmd.Params) > 0 {
+ mask = cmd.Params[0]
+ }
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_ENDOFWHO,
+ Params: []string{dc.nick, mask, "End of /WHO list"},
+ })
+ case irc.RPL_WHOISUSER:
+ var nick, username, host, realname string
+ if err := parseMessageParams(msg, nil, &nick, &username, &host, nil, &realname); err != nil {
+ return err
+ }
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ nick := dc.marshalEntity(uc.network, nick)
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_WHOISUSER,
+ Params: []string{dc.nick, nick, username, host, "*", realname},
+ })
+ })
+ case irc.RPL_WHOISSERVER:
+ var nick, server, serverInfo string
+ if err := parseMessageParams(msg, nil, &nick, &server, &serverInfo); err != nil {
+ return err
+ }
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ nick := dc.marshalEntity(uc.network, nick)
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_WHOISSERVER,
+ Params: []string{dc.nick, nick, server, serverInfo},
+ })
+ })
+ case irc.RPL_WHOISOPERATOR:
+ var nick string
+ if err := parseMessageParams(msg, nil, &nick); err != nil {
+ return err
+ }
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ nick := dc.marshalEntity(uc.network, nick)
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_WHOISOPERATOR,
+ Params: []string{dc.nick, nick, "is an IRC operator"},
+ })
+ })
+ case irc.RPL_WHOISIDLE:
+ var nick string
+ if err := parseMessageParams(msg, nil, &nick, nil); err != nil {
+ return err
+ }
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ nick := dc.marshalEntity(uc.network, nick)
+ params := []string{dc.nick, nick}
+ params = append(params, msg.Params[2:]...)
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_WHOISIDLE,
+ Params: params,
+ })
+ })
+ case irc.RPL_WHOISCHANNELS:
+ var nick, channelList string
+ if err := parseMessageParams(msg, nil, &nick, &channelList); err != nil {
+ return err
+ }
+ channels := splitSpace(channelList)
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ nick := dc.marshalEntity(uc.network, nick)
+ channelList := make([]string, len(channels))
+ for i, channel := range channels {
+ prefix, channel := uc.parseMembershipPrefix(channel)
+ channel = dc.marshalEntity(uc.network, channel)
+ channelList[i] = prefix.Format(dc) + channel
+ }
+ channels := strings.Join(channelList, " ")
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_WHOISCHANNELS,
+ Params: []string{dc.nick, nick, channels},
+ })
+ })
+ case irc.RPL_ENDOFWHOIS:
+ var nick string
+ if err := parseMessageParams(msg, nil, &nick); err != nil {
+ return err
+ }
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ nick := dc.marshalEntity(uc.network, nick)
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_ENDOFWHOIS,
+ Params: []string{dc.nick, nick, "End of /WHOIS list"},
+ })
+ })
+ case "INVITE":
+ var nick, channel string
+ if err := parseMessageParams(msg, &nick, &channel); err != nil {
+ return err
+ }
+
+ weAreInvited := uc.isOurNick(nick)
+
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ if !weAreInvited && !dc.caps["invite-notify"] {
+ return
+ }
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.marshalUserPrefix(uc.network, msg.Prefix),
+ Command: "INVITE",
+ Params: []string{dc.marshalEntity(uc.network, nick), dc.marshalEntity(uc.network, channel)},
+ })
+ })
+ case irc.RPL_INVITING:
+ var nick, channel string
+ if err := parseMessageParams(msg, nil, &nick, &channel); err != nil {
+ return err
+ }
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_INVITING,
+ Params: []string{dc.nick, dc.marshalEntity(uc.network, nick), dc.marshalEntity(uc.network, channel)},
+ })
+ })
+ case irc.RPL_MONONLINE, irc.RPL_MONOFFLINE:
+ var targetsStr string
+ if err := parseMessageParams(msg, nil, &targetsStr); err != nil {
+ return err
+ }
+ targets := strings.Split(targetsStr, ",")
+
+ online := msg.Command == irc.RPL_MONONLINE
+ for _, target := range targets {
+ prefix := irc.ParsePrefix(target)
+ uc.monitored.SetValue(prefix.Name, online)
+ }
+
+ // Check if the nick we want is now free
+ wantNick := GetNick(&uc.user.User, &uc.network.Network)
+ wantNickCM := uc.network.casemap(wantNick)
+ if !online && uc.nickCM != wantNickCM {
+ found := false
+ for _, target := range targets {
+ prefix := irc.ParsePrefix(target)
+ if uc.network.casemap(prefix.Name) == wantNickCM {
+ found = true
+ break
+ }
+ }
+ if found {
+ uc.logger.Printf("desired nick %q is now available", wantNick)
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "NICK",
+ Params: []string{wantNick},
+ })
+ }
+ }
+
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ for _, target := range targets {
+ prefix := irc.ParsePrefix(target)
+ if dc.monitored.Has(prefix.Name) {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: msg.Command,
+ Params: []string{dc.nick, target},
+ })
+ }
+ }
+ })
+ case irc.ERR_MONLISTFULL:
+ var limit, targetsStr string
+ if err := parseMessageParams(msg, nil, &limit, &targetsStr); err != nil {
+ return err
+ }
+
+ targets := strings.Split(targetsStr, ",")
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ for _, target := range targets {
+ if dc.monitored.Has(target) {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: msg.Command,
+ Params: []string{dc.nick, limit, target},
+ })
+ }
+ }
+ })
+ case irc.RPL_AWAY:
+ var nick, reason string
+ if err := parseMessageParams(msg, nil, &nick, &reason); err != nil {
+ return err
+ }
+
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_AWAY,
+ Params: []string{dc.nick, dc.marshalEntity(uc.network, nick), reason},
+ })
+ })
+ case "AWAY", "ACCOUNT":
+ if msg.Prefix == nil {
+ return fmt.Errorf("expected a prefix")
+ }
+
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.marshalUserPrefix(uc.network, msg.Prefix),
+ Command: msg.Command,
+ Params: msg.Params,
+ })
+ })
+ case irc.RPL_BANLIST, irc.RPL_INVITELIST, irc.RPL_EXCEPTLIST:
+ var channel, mask string
+ if err := parseMessageParams(msg, nil, &channel, &mask); err != nil {
+ return err
+ }
+ var addNick, addTime string
+ if len(msg.Params) >= 5 {
+ addNick = msg.Params[3]
+ addTime = msg.Params[4]
+ }
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ channel := dc.marshalEntity(uc.network, channel)
+
+ var params []string
+ if addNick != "" && addTime != "" {
+ addNick := dc.marshalEntity(uc.network, addNick)
+ params = []string{dc.nick, channel, mask, addNick, addTime}
+ } else {
+ params = []string{dc.nick, channel, mask}
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: msg.Command,
+ Params: params,
+ })
+ })
+ case irc.RPL_ENDOFBANLIST, irc.RPL_ENDOFINVITELIST, irc.RPL_ENDOFEXCEPTLIST:
+ var channel, trailing string
+ if err := parseMessageParams(msg, nil, &channel, &trailing); err != nil {
+ return err
+ }
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ upstreamChannel := dc.marshalEntity(uc.network, channel)
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: msg.Command,
+ Params: []string{dc.nick, upstreamChannel, trailing},
+ })
+ })
+ case irc.ERR_UNKNOWNCOMMAND, irc.RPL_TRYAGAIN:
+ var command, reason string
+ if err := parseMessageParams(msg, nil, &command, &reason); err != nil {
+ return err
+ }
+
+ if dc, _ := uc.dequeueCommand(command); dc != nil && downstreamID == 0 {
+ downstreamID = dc.id
+ }
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ dc.SendMessage(&irc.Message{
+ Prefix: uc.srv.prefix(),
+ Command: msg.Command,
+ Params: []string{dc.nick, command, reason},
+ })
+ })
+ case "FAIL":
+ var command, code string
+ if err := parseMessageParams(msg, &command, &code); err != nil {
+ return err
+ }
+
+ if !uc.registered && command == "*" && code == "ACCOUNT_REQUIRED" {
+ return registrationError{msg}
+ }
+
+ if dc, _ := uc.dequeueCommand(command); dc != nil && downstreamID == 0 {
+ downstreamID = dc.id
+ }
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ dc.SendMessage(msg)
+ })
+ case "ACK":
+ // Ignore
+ case irc.RPL_NOWAWAY, irc.RPL_UNAWAY:
+ // Ignore
+ case irc.RPL_YOURHOST, irc.RPL_CREATED:
+ // Ignore
+ case irc.RPL_LUSERCLIENT, irc.RPL_LUSEROP, irc.RPL_LUSERUNKNOWN, irc.RPL_LUSERCHANNELS, irc.RPL_LUSERME:
+ fallthrough
+ case irc.RPL_STATSVLINE, rpl_statsping, irc.RPL_STATSBLINE, irc.RPL_STATSDLINE:
+ fallthrough
+ case rpl_localusers, rpl_globalusers:
+ fallthrough
+ case irc.RPL_MOTDSTART, irc.RPL_MOTD:
+ // Ignore these messages if they're part of the initial registration
+ // message burst. Forward them if the user explicitly asked for them.
+ if !uc.gotMotd {
+ return nil
+ }
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ dc.SendMessage(&irc.Message{
+ Prefix: uc.srv.prefix(),
+ Command: msg.Command,
+ Params: msg.Params,
+ })
+ })
+ case irc.RPL_LISTSTART:
+ // Ignore
+ case "ERROR":
+ var text string
+ if err := parseMessageParams(msg, &text); err != nil {
+ return err
+ }
+ return fmt.Errorf("fatal server error: %v", text)
+ case irc.ERR_NICKNAMEINUSE:
+ // At this point, we haven't received ISUPPORT so we don't know the
+ // maximum nickname length or whether the server supports MONITOR. Many
+ // servers have NICKLEN=30 so let's just use that.
+ if !uc.registered && len(uc.nick)+1 < 30 {
+ uc.nick = uc.nick + "_"
+ uc.nickCM = uc.network.casemap(uc.nick)
+ uc.logger.Printf("desired nick is not available, falling back to %q", uc.nick)
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "NICK",
+ Params: []string{uc.nick},
+ })
+ return nil
+ }
+ fallthrough
+ case irc.ERR_PASSWDMISMATCH, irc.ERR_ERRONEUSNICKNAME, irc.ERR_NICKCOLLISION, irc.ERR_UNAVAILRESOURCE, irc.ERR_NOPERMFORHOST, irc.ERR_YOUREBANNEDCREEP:
+ if !uc.registered {
+ return registrationError{msg}
+ }
+ fallthrough
+ default:
+ uc.logger.Printf("unhandled message: %v", msg)
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ // best effort marshaling for unknown messages, replies and errors:
+ // most numerics start with the user nick, marshal it if that's the case
+ // otherwise, conservately keep the params without marshaling
+ params := msg.Params
+ if _, err := strconv.Atoi(msg.Command); err == nil { // numeric
+ if len(msg.Params) > 0 && isOurNick(uc.network, msg.Params[0]) {
+ params[0] = dc.nick
+ }
+ }
+ dc.SendMessage(&irc.Message{
+ Prefix: uc.srv.prefix(),
+ Command: msg.Command,
+ Params: params,
+ })
+ })
+ }
+ return nil
+}
+
+func (uc *upstreamConn) handleDetachedMessage(ctx context.Context, ch *Channel, msg *irc.Message) {
+ if uc.network.detachedMessageNeedsRelay(ch, msg) {
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.relayDetachedMessage(uc.network, msg)
+ })
+ }
+ if ch.ReattachOn == FilterMessage || (ch.ReattachOn == FilterHighlight && uc.network.isHighlight(msg)) {
+ uc.network.attach(ctx, ch)
+ if err := uc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil {
+ uc.logger.Printf("failed to update channel %q: %v", ch.Name, err)
+ }
+ }
+}
+
+func (uc *upstreamConn) handleChanModes(s string) error {
+ parts := strings.SplitN(s, ",", 5)
+ if len(parts) < 4 {
+ return fmt.Errorf("malformed ISUPPORT CHANMODES value: %v", s)
+ }
+ modes := make(map[byte]channelModeType)
+ for i, mt := range []channelModeType{modeTypeA, modeTypeB, modeTypeC, modeTypeD} {
+ for j := 0; j < len(parts[i]); j++ {
+ mode := parts[i][j]
+ modes[mode] = mt
+ }
+ }
+ uc.availableChannelModes = modes
+ return nil
+}
+
+func (uc *upstreamConn) handleMemberships(s string) error {
+ if s == "" {
+ uc.availableMemberships = nil
+ return nil
+ }
+
+ if s[0] != '(' {
+ return fmt.Errorf("malformed ISUPPORT PREFIX value: %v", s)
+ }
+ sep := strings.IndexByte(s, ')')
+ if sep < 0 || len(s) != sep*2 {
+ return fmt.Errorf("malformed ISUPPORT PREFIX value: %v", s)
+ }
+ memberships := make([]membership, len(s)/2-1)
+ for i := range memberships {
+ memberships[i] = membership{
+ Mode: s[i+1],
+ Prefix: s[sep+i+1],
+ }
+ }
+ uc.availableMemberships = memberships
+ return nil
+}
+
+func (uc *upstreamConn) handleSupportedCaps(capsStr string) {
+ caps := strings.Fields(capsStr)
+ for _, s := range caps {
+ kv := strings.SplitN(s, "=", 2)
+ k := strings.ToLower(kv[0])
+ var v string
+ if len(kv) == 2 {
+ v = kv[1]
+ }
+ uc.supportedCaps[k] = v
+ }
+}
+
+func (uc *upstreamConn) requestCaps() {
+ var requestCaps []string
+ for c := range permanentUpstreamCaps {
+ if _, ok := uc.supportedCaps[c]; ok && !uc.caps[c] {
+ requestCaps = append(requestCaps, c)
+ }
+ }
+
+ if len(requestCaps) == 0 {
+ return
+ }
+
+ uc.SendMessage(context.TODO(), &irc.Message{
+ Command: "CAP",
+ Params: []string{"REQ", strings.Join(requestCaps, " ")},
+ })
+}
+
+func (uc *upstreamConn) supportsSASL(mech string) bool {
+ v, ok := uc.supportedCaps["sasl"]
+ if !ok {
+ return false
+ }
+
+ if v == "" {
+ return true
+ }
+
+ mechanisms := strings.Split(v, ",")
+ for _, mech := range mechanisms {
+ if strings.EqualFold(mech, mech) {
+ return true
+ }
+ }
+ return false
+}
+
+func (uc *upstreamConn) requestSASL() bool {
+ if uc.network.SASL.Mechanism == "" {
+ return false
+ }
+ return uc.supportsSASL(uc.network.SASL.Mechanism)
+}
+
+func (uc *upstreamConn) handleCapAck(ctx context.Context, name string, ok bool) error {
+ uc.caps[name] = ok
+
+ switch name {
+ case "sasl":
+ if !uc.requestSASL() {
+ return nil
+ }
+ if !ok {
+ uc.logger.Printf("server refused to acknowledge the SASL capability")
+ return nil
+ }
+
+ auth := &uc.network.SASL
+ switch auth.Mechanism {
+ case "PLAIN":
+ uc.logger.Printf("starting SASL PLAIN authentication with username %q", auth.Plain.Username)
+ uc.saslClient = sasl.NewPlainClient("", auth.Plain.Username, auth.Plain.Password)
+ case "EXTERNAL":
+ uc.logger.Printf("starting SASL EXTERNAL authentication")
+ uc.saslClient = sasl.NewExternalClient("")
+ default:
+ return fmt.Errorf("unsupported SASL mechanism %q", name)
+ }
+
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "AUTHENTICATE",
+ Params: []string{auth.Mechanism},
+ })
+ default:
+ if permanentUpstreamCaps[name] {
+ break
+ }
+ uc.logger.Printf("received CAP ACK/NAK for a cap we don't support: %v", name)
+ }
+ return nil
+}
+
+func splitSpace(s string) []string {
+ return strings.FieldsFunc(s, func(r rune) bool {
+ return r == ' '
+ })
+}
+
+func (uc *upstreamConn) register(ctx context.Context) {
+ uc.nick = GetNick(&uc.user.User, &uc.network.Network)
+ uc.nickCM = uc.network.casemap(uc.nick)
+ uc.username = GetUsername(&uc.user.User, &uc.network.Network)
+ uc.realname = GetRealname(&uc.user.User, &uc.network.Network)
+
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "CAP",
+ Params: []string{"LS", "302"},
+ })
+
+ if uc.network.Pass != "" {
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "PASS",
+ Params: []string{uc.network.Pass},
+ })
+ }
+
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "NICK",
+ Params: []string{uc.nick},
+ })
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "USER",
+ Params: []string{uc.username, "0", "*", uc.realname},
+ })
+}
+
+func (uc *upstreamConn) ReadMessage() (*irc.Message, error) {
+ msg, err := uc.conn.ReadMessage()
+ if err != nil {
+ return nil, err
+ }
+ return msg, nil
+}
+
+func (uc *upstreamConn) runUntilRegistered(ctx context.Context) error {
+ for !uc.registered {
+ msg, err := uc.ReadMessage()
+ if err != nil {
+ return fmt.Errorf("failed to read message: %v", err)
+ }
+
+ if err := uc.handleMessage(ctx, msg); err != nil {
+ if _, ok := err.(registrationError); ok {
+ return err
+ } else {
+ msg.Tags = nil // prevent message tags from cluttering logs
+ return fmt.Errorf("failed to handle message %q: %v", msg, err)
+ }
+ }
+ }
+
+ for _, command := range uc.network.ConnectCommands {
+ m, err := irc.ParseMessage(command)
+ if err != nil {
+ uc.logger.Printf("failed to parse connect command %q: %v", command, err)
+ } else {
+ uc.SendMessage(ctx, m)
+ }
+ }
+
+ return nil
+}
+
+func (uc *upstreamConn) readMessages(ch chan<- event) error {
+ for {
+ msg, err := uc.ReadMessage()
+ if errors.Is(err, io.EOF) {
+ break
+ } else if err != nil {
+ return fmt.Errorf("failed to read IRC command: %v", err)
+ }
+
+ ch <- eventUpstreamMessage{msg, uc}
+ }
+
+ return nil
+}
+
+func (uc *upstreamConn) SendMessage(ctx context.Context, msg *irc.Message) {
+ if !uc.caps["message-tags"] {
+ msg = msg.Copy()
+ msg.Tags = nil
+ }
+
+ uc.conn.SendMessage(ctx, msg)
+}
+
+func (uc *upstreamConn) SendMessageLabeled(ctx context.Context, downstreamID uint64, msg *irc.Message) {
+ if uc.caps["labeled-response"] {
+ if msg.Tags == nil {
+ msg.Tags = make(map[string]irc.TagValue)
+ }
+ msg.Tags["label"] = irc.TagValue(fmt.Sprintf("sd-%d-%d", downstreamID, uc.nextLabelID))
+ uc.nextLabelID++
+ }
+ uc.SendMessage(ctx, msg)
+}
+
+// appendLog appends a message to the log file.
+//
+// The internal message ID is returned. If the message isn't recorded in the
+// log file, an empty string is returned.
+func (uc *upstreamConn) appendLog(entity string, msg *irc.Message) (msgID string) {
+ if uc.user.msgStore == nil {
+ return ""
+ }
+
+ // Don't store messages with a server mask target
+ if strings.HasPrefix(entity, "$") {
+ return ""
+ }
+
+ entityCM := uc.network.casemap(entity)
+ if entityCM == "nickserv" {
+ // The messages sent/received from NickServ may contain
+ // security-related information (like passwords). Don't store these.
+ return ""
+ }
+
+ if !uc.network.delivered.HasTarget(entity) {
+ // This is the first message we receive from this target. Save the last
+ // message ID in delivery receipts, so that we can send the new message
+ // in the backlog if an offline client reconnects.
+ lastID, err := uc.user.msgStore.LastMsgID(&uc.network.Network, entityCM, time.Now())
+ if err != nil {
+ uc.logger.Printf("failed to log message: failed to get last message ID: %v", err)
+ return ""
+ }
+
+ uc.network.delivered.ForEachClient(func(clientName string) {
+ uc.network.delivered.StoreID(entity, clientName, lastID)
+ })
+ }
+
+ msgID, err := uc.user.msgStore.Append(&uc.network.Network, entityCM, msg)
+ if err != nil {
+ uc.logger.Printf("failed to append message to store: %v", err)
+ return ""
+ }
+
+ return msgID
+}
+
+// produce appends a message to the logs and forwards it to connected downstream
+// connections.
+//
+// If origin is not nil and origin doesn't support echo-message, the message is
+// forwarded to all connections except origin.
+func (uc *upstreamConn) produce(target string, msg *irc.Message, origin *downstreamConn) {
+ var msgID string
+ if target != "" {
+ msgID = uc.appendLog(target, msg)
+ }
+
+ // Don't forward messages if it's a detached channel
+ ch := uc.network.channels.Value(target)
+ detached := ch != nil && ch.Detached
+
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ if !detached && (dc != origin || dc.caps["echo-message"]) {
+ dc.sendMessageWithID(dc.marshalMessage(msg, uc.network), msgID)
+ } else {
+ dc.advanceMessageWithID(msg, msgID)
+ }
+ })
+}
+
+func (uc *upstreamConn) updateAway() {
+ ctx := context.TODO()
+
+ away := true
+ uc.forEachDownstream(func(*downstreamConn) {
+ away = false
+ })
+ if away == uc.away {
+ return
+ }
+ if away {
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "AWAY",
+ Params: []string{"Auto away"},
+ })
+ } else {
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "AWAY",
+ })
+ }
+ uc.away = away
+}
+
+func (uc *upstreamConn) updateChannelAutoDetach(name string) {
+ uch := uc.channels.Value(name)
+ if uch == nil {
+ return
+ }
+ ch := uc.network.channels.Value(name)
+ if ch == nil || ch.Detached {
+ return
+ }
+ uch.updateAutoDetach(ch.DetachAfter)
+}
+
+func (uc *upstreamConn) updateMonitor() {
+ if _, ok := uc.isupport["MONITOR"]; !ok {
+ return
+ }
+
+ ctx := context.TODO()
+
+ add := make(map[string]struct{})
+ var addList []string
+ seen := make(map[string]struct{})
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ for targetCM := range dc.monitored.innerMap {
+ if !uc.monitored.Has(targetCM) {
+ if _, ok := add[targetCM]; !ok {
+ addList = append(addList, targetCM)
+ add[targetCM] = struct{}{}
+ }
+ } else {
+ seen[targetCM] = struct{}{}
+ }
+ }
+ })
+
+ wantNick := GetNick(&uc.user.User, &uc.network.Network)
+ wantNickCM := uc.network.casemap(wantNick)
+ if _, ok := add[wantNickCM]; !ok && !uc.monitored.Has(wantNick) && !uc.isOurNick(wantNick) {
+ addList = append(addList, wantNickCM)
+ add[wantNickCM] = struct{}{}
+ }
+
+ removeAll := true
+ var removeList []string
+ for targetCM, entry := range uc.monitored.innerMap {
+ if _, ok := seen[targetCM]; ok {
+ removeAll = false
+ } else {
+ removeList = append(removeList, entry.originalKey)
+ }
+ }
+
+ // TODO: better handle the case where len(uc.monitored) + len(addList)
+ // exceeds the limit, probably by immediately sending ERR_MONLISTFULL?
+
+ if removeAll && len(addList) == 0 && len(removeList) > 0 {
+ // Optimization when the last MONITOR-aware downstream disconnects
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "MONITOR",
+ Params: []string{"C"},
+ })
+ } else {
+ msgs := generateMonitor("-", removeList)
+ msgs = append(msgs, generateMonitor("+", addList)...)
+ for _, msg := range msgs {
+ uc.SendMessage(ctx, msg)
+ }
+ }
+
+ for _, target := range removeList {
+ uc.monitored.Delete(target)
+ }
+}
--- /dev/null
+package suika
+
+import (
+ "context"
+ "crypto/sha256"
+ "encoding/binary"
+ "encoding/hex"
+ "fmt"
+ "math/big"
+ "net"
+ "sort"
+ "strings"
+ "time"
+
+ "gopkg.in/irc.v3"
+)
+
+type event interface{}
+
+type eventUpstreamMessage struct {
+ msg *irc.Message
+ uc *upstreamConn
+}
+
+type eventUpstreamConnectionError struct {
+ net *network
+ err error
+}
+
+type eventUpstreamConnected struct {
+ uc *upstreamConn
+}
+
+type eventUpstreamDisconnected struct {
+ uc *upstreamConn
+}
+
+type eventUpstreamError struct {
+ uc *upstreamConn
+ err error
+}
+
+type eventDownstreamMessage struct {
+ msg *irc.Message
+ dc *downstreamConn
+}
+
+type eventDownstreamConnected struct {
+ dc *downstreamConn
+}
+
+type eventDownstreamDisconnected struct {
+ dc *downstreamConn
+}
+
+type eventChannelDetach struct {
+ uc *upstreamConn
+ name string
+}
+
+type eventBroadcast struct {
+ msg *irc.Message
+}
+
+type eventStop struct{}
+
+type eventUserUpdate struct {
+ password *string
+ admin *bool
+ done chan error
+}
+
+type deliveredClientMap map[string]string // client name -> msg ID
+
+type deliveredStore struct {
+ m deliveredCasemapMap
+}
+
+func newDeliveredStore() deliveredStore {
+ return deliveredStore{deliveredCasemapMap{newCasemapMap(0)}}
+}
+
+func (ds deliveredStore) HasTarget(target string) bool {
+ return ds.m.Value(target) != nil
+}
+
+func (ds deliveredStore) LoadID(target, clientName string) string {
+ clients := ds.m.Value(target)
+ if clients == nil {
+ return ""
+ }
+ return clients[clientName]
+}
+
+func (ds deliveredStore) StoreID(target, clientName, msgID string) {
+ clients := ds.m.Value(target)
+ if clients == nil {
+ clients = make(deliveredClientMap)
+ ds.m.SetValue(target, clients)
+ }
+ clients[clientName] = msgID
+}
+
+func (ds deliveredStore) ForEachTarget(f func(target string)) {
+ for _, entry := range ds.m.innerMap {
+ f(entry.originalKey)
+ }
+}
+
+func (ds deliveredStore) ForEachClient(f func(clientName string)) {
+ clients := make(map[string]struct{})
+ for _, entry := range ds.m.innerMap {
+ delivered := entry.value.(deliveredClientMap)
+ for clientName := range delivered {
+ clients[clientName] = struct{}{}
+ }
+ }
+
+ for clientName := range clients {
+ f(clientName)
+ }
+}
+
+type network struct {
+ Network
+ user *user
+ logger Logger
+ stopped chan struct{}
+
+ conn *upstreamConn
+ channels channelCasemapMap
+ delivered deliveredStore
+ lastError error
+ casemap casemapping
+}
+
+func newNetwork(user *user, record *Network, channels []Channel) *network {
+ logger := &prefixLogger{user.logger, fmt.Sprintf("network %q: ", record.GetName())}
+
+ m := channelCasemapMap{newCasemapMap(0)}
+ for _, ch := range channels {
+ ch := ch
+ m.SetValue(ch.Name, &ch)
+ }
+
+ return &network{
+ Network: *record,
+ user: user,
+ logger: logger,
+ stopped: make(chan struct{}),
+ channels: m,
+ delivered: newDeliveredStore(),
+ casemap: casemapRFC1459,
+ }
+}
+
+func (net *network) forEachDownstream(f func(*downstreamConn)) {
+ net.user.forEachDownstream(func(dc *downstreamConn) {
+ if dc.network == nil && !dc.isMultiUpstream {
+ return
+ }
+ if dc.network != nil && dc.network != net {
+ return
+ }
+ f(dc)
+ })
+}
+
+func (net *network) isStopped() bool {
+ select {
+ case <-net.stopped:
+ return true
+ default:
+ return false
+ }
+}
+
+func userIdent(u *User) string {
+ // The ident is a string we will send to upstream servers in clear-text.
+ // For privacy reasons, make sure it doesn't expose any meaningful user
+ // metadata. We just use the base64-encoded hashed ID, so that people don't
+ // start relying on the string being an integer or following a pattern.
+ var b [64]byte
+ binary.LittleEndian.PutUint64(b[:], uint64(u.ID))
+ h := sha256.Sum256(b[:])
+ return hex.EncodeToString(h[:16])
+}
+
+func (net *network) run() {
+ if !net.Enabled {
+ return
+ }
+
+ var lastTry time.Time
+ backoff := newBackoffer(retryConnectMinDelay, retryConnectMaxDelay, retryConnectJitter)
+ for {
+ if net.isStopped() {
+ return
+ }
+
+ delay := backoff.Next() - time.Now().Sub(lastTry)
+ if delay > 0 {
+ net.logger.Printf("waiting %v before trying to reconnect to %q", delay.Truncate(time.Second), net.Addr)
+ time.Sleep(delay)
+ }
+ lastTry = time.Now()
+
+
+ uc, err := connectToUpstream(context.TODO(), net)
+ if err != nil {
+ net.logger.Printf("failed to connect to upstream server %q: %v", net.Addr, err)
+ net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("failed to connect: %v", err)}
+ continue
+ }
+
+ uc.register(context.TODO())
+ if err := uc.runUntilRegistered(context.TODO()); err != nil {
+ text := err.Error()
+ temp := true
+ if regErr, ok := err.(registrationError); ok {
+ text = regErr.Reason()
+ temp = regErr.Temporary()
+ }
+ uc.logger.Printf("failed to register: %v", text)
+ net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("failed to register: %v", text)}
+ uc.Close()
+ if !temp {
+ return
+ }
+ continue
+ }
+
+ // TODO: this is racy with net.stopped. If the network is stopped
+ // before the user goroutine receives eventUpstreamConnected, the
+ // connection won't be closed.
+ net.user.events <- eventUpstreamConnected{uc}
+ if err := uc.readMessages(net.user.events); err != nil {
+ uc.logger.Printf("failed to handle messages: %v", err)
+ net.user.events <- eventUpstreamError{uc, fmt.Errorf("failed to handle messages: %v", err)}
+ }
+ uc.Close()
+ net.user.events <- eventUpstreamDisconnected{uc}
+
+ backoff.Reset()
+ }
+}
+
+func (net *network) stop() {
+ if !net.isStopped() {
+ close(net.stopped)
+ }
+
+ if net.conn != nil {
+ net.conn.Close()
+ }
+}
+
+func (net *network) detach(ch *Channel) {
+ if ch.Detached {
+ return
+ }
+
+ net.logger.Printf("detaching channel %q", ch.Name)
+
+ ch.Detached = true
+
+ if net.user.msgStore != nil {
+ nameCM := net.casemap(ch.Name)
+ lastID, err := net.user.msgStore.LastMsgID(&net.Network, nameCM, time.Now())
+ if err != nil {
+ net.logger.Printf("failed to get last message ID for channel %q: %v", ch.Name, err)
+ }
+ ch.DetachedInternalMsgID = lastID
+ }
+
+ if net.conn != nil {
+ uch := net.conn.channels.Value(ch.Name)
+ if uch != nil {
+ uch.updateAutoDetach(0)
+ }
+ }
+
+ net.forEachDownstream(func(dc *downstreamConn) {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.prefix(),
+ Command: "PART",
+ Params: []string{dc.marshalEntity(net, ch.Name), "Detach"},
+ })
+ })
+}
+
+func (net *network) attach(ctx context.Context, ch *Channel) {
+ if !ch.Detached {
+ return
+ }
+
+ net.logger.Printf("attaching channel %q", ch.Name)
+
+ detachedMsgID := ch.DetachedInternalMsgID
+ ch.Detached = false
+ ch.DetachedInternalMsgID = ""
+
+ var uch *upstreamChannel
+ if net.conn != nil {
+ uch = net.conn.channels.Value(ch.Name)
+
+ net.conn.updateChannelAutoDetach(ch.Name)
+ }
+
+ net.forEachDownstream(func(dc *downstreamConn) {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.prefix(),
+ Command: "JOIN",
+ Params: []string{dc.marshalEntity(net, ch.Name)},
+ })
+
+ if uch != nil {
+ forwardChannel(ctx, dc, uch)
+ }
+
+ if detachedMsgID != "" {
+ dc.sendTargetBacklog(ctx, net, ch.Name, detachedMsgID)
+ }
+ })
+}
+
+func (net *network) deleteChannel(ctx context.Context, name string) error {
+ ch := net.channels.Value(name)
+ if ch == nil {
+ return fmt.Errorf("unknown channel %q", name)
+ }
+ if net.conn != nil {
+ uch := net.conn.channels.Value(ch.Name)
+ if uch != nil {
+ uch.updateAutoDetach(0)
+ }
+ }
+
+ if err := net.user.srv.db.DeleteChannel(ctx, ch.ID); err != nil {
+ return err
+ }
+ net.channels.Delete(name)
+ return nil
+}
+
+func (net *network) updateCasemapping(newCasemap casemapping) {
+ net.casemap = newCasemap
+ net.channels.SetCasemapping(newCasemap)
+ net.delivered.m.SetCasemapping(newCasemap)
+ if uc := net.conn; uc != nil {
+ uc.channels.SetCasemapping(newCasemap)
+ for _, entry := range uc.channels.innerMap {
+ uch := entry.value.(*upstreamChannel)
+ uch.Members.SetCasemapping(newCasemap)
+ }
+ uc.monitored.SetCasemapping(newCasemap)
+ }
+ net.forEachDownstream(func(dc *downstreamConn) {
+ dc.monitored.SetCasemapping(newCasemap)
+ })
+}
+
+func (net *network) storeClientDeliveryReceipts(ctx context.Context, clientName string) {
+ if !net.user.hasPersistentMsgStore() {
+ return
+ }
+
+ var receipts []DeliveryReceipt
+ net.delivered.ForEachTarget(func(target string) {
+ msgID := net.delivered.LoadID(target, clientName)
+ if msgID == "" {
+ return
+ }
+ receipts = append(receipts, DeliveryReceipt{
+ Target: target,
+ InternalMsgID: msgID,
+ })
+ })
+
+ if err := net.user.srv.db.StoreClientDeliveryReceipts(ctx, net.ID, clientName, receipts); err != nil {
+ net.logger.Printf("failed to store delivery receipts for client %q: %v", clientName, err)
+ }
+}
+
+func (net *network) isHighlight(msg *irc.Message) bool {
+ if msg.Command != "PRIVMSG" && msg.Command != "NOTICE" {
+ return false
+ }
+
+ text := msg.Params[1]
+
+ nick := net.Nick
+ if net.conn != nil {
+ nick = net.conn.nick
+ }
+
+ // TODO: use case-mapping aware comparison here
+ return msg.Prefix.Name != nick && isHighlight(text, nick)
+}
+
+func (net *network) detachedMessageNeedsRelay(ch *Channel, msg *irc.Message) bool {
+ highlight := net.isHighlight(msg)
+ return ch.RelayDetached == FilterMessage || ((ch.RelayDetached == FilterHighlight || ch.RelayDetached == FilterDefault) && highlight)
+}
+
+func (net *network) autoSaveSASLPlain(ctx context.Context, username, password string) {
+ // User may have e.g. EXTERNAL mechanism configured. We do not want to
+ // automatically erase the key pair or any other credentials.
+ if net.SASL.Mechanism != "" && net.SASL.Mechanism != "PLAIN" {
+ return
+ }
+
+ net.logger.Printf("auto-saving SASL PLAIN credentials with username %q", username)
+ net.SASL.Mechanism = "PLAIN"
+ net.SASL.Plain.Username = username
+ net.SASL.Plain.Password = password
+ if err := net.user.srv.db.StoreNetwork(ctx, net.user.ID, &net.Network); err != nil {
+ net.logger.Printf("failed to save SASL PLAIN credentials: %v", err)
+ }
+}
+
+type user struct {
+ User
+ srv *Server
+ logger Logger
+
+ events chan event
+ done chan struct{}
+
+ networks []*network
+ downstreamConns []*downstreamConn
+ msgStore messageStore
+}
+
+func newUser(srv *Server, record *User) *user {
+ logger := &prefixLogger{srv.Logger, fmt.Sprintf("user %q: ", record.Username)}
+
+ var msgStore messageStore
+ if logPath := srv.Config().LogPath; logPath != "" {
+ msgStore = newFSMessageStore(logPath, record)
+ } else {
+ msgStore = newMemoryMessageStore()
+ }
+
+ return &user{
+ User: *record,
+ srv: srv,
+ logger: logger,
+ events: make(chan event, 64),
+ done: make(chan struct{}),
+ msgStore: msgStore,
+ }
+}
+
+func (u *user) forEachUpstream(f func(uc *upstreamConn)) {
+ for _, network := range u.networks {
+ if network.conn == nil {
+ continue
+ }
+ f(network.conn)
+ }
+}
+
+func (u *user) forEachDownstream(f func(dc *downstreamConn)) {
+ for _, dc := range u.downstreamConns {
+ f(dc)
+ }
+}
+
+func (u *user) getNetwork(name string) *network {
+ for _, network := range u.networks {
+ if network.Addr == name {
+ return network
+ }
+ if network.Name != "" && network.Name == name {
+ return network
+ }
+ }
+ return nil
+}
+
+func (u *user) getNetworkByID(id int64) *network {
+ for _, net := range u.networks {
+ if net.ID == id {
+ return net
+ }
+ }
+ return nil
+}
+
+func (u *user) run() {
+ defer func() {
+ if u.msgStore != nil {
+ if err := u.msgStore.Close(); err != nil {
+ u.logger.Printf("failed to close message store for user %q: %v", u.Username, err)
+ }
+ }
+ close(u.done)
+ }()
+
+ networks, err := u.srv.db.ListNetworks(context.TODO(), u.ID)
+ if err != nil {
+ u.logger.Printf("failed to list networks for user %q: %v", u.Username, err)
+ return
+ }
+
+ sort.Slice(networks, func(i, j int) bool {
+ return networks[i].ID < networks[j].ID
+ })
+
+ for _, record := range networks {
+ record := record
+ channels, err := u.srv.db.ListChannels(context.TODO(), record.ID)
+ if err != nil {
+ u.logger.Printf("failed to list channels for user %q, network %q: %v", u.Username, record.GetName(), err)
+ continue
+ }
+
+ network := newNetwork(u, &record, channels)
+ u.networks = append(u.networks, network)
+
+ if u.hasPersistentMsgStore() {
+ receipts, err := u.srv.db.ListDeliveryReceipts(context.TODO(), record.ID)
+ if err != nil {
+ u.logger.Printf("failed to load delivery receipts for user %q, network %q: %v", u.Username, network.GetName(), err)
+ return
+ }
+
+ for _, rcpt := range receipts {
+ network.delivered.StoreID(rcpt.Target, rcpt.Client, rcpt.InternalMsgID)
+ }
+ }
+
+ go network.run()
+ }
+
+ for e := range u.events {
+ switch e := e.(type) {
+ case eventUpstreamConnected:
+ uc := e.uc
+
+ uc.network.conn = uc
+
+ uc.updateAway()
+ uc.updateMonitor()
+
+ netIDStr := fmt.Sprintf("%v", uc.network.ID)
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.updateSupportedCaps()
+
+ if !dc.caps["soju.im/bouncer-networks"] {
+ sendServiceNOTICE(dc, fmt.Sprintf("connected to %s", uc.network.GetName()))
+ }
+
+ dc.updateNick()
+ dc.updateRealname()
+ dc.updateAccount()
+ })
+ u.forEachDownstream(func(dc *downstreamConn) {
+ if dc.caps["soju.im/bouncer-networks-notify"] {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "BOUNCER",
+ Params: []string{"NETWORK", netIDStr, "state=connected"},
+ })
+ }
+ })
+ uc.network.lastError = nil
+ case eventUpstreamDisconnected:
+ u.handleUpstreamDisconnected(e.uc)
+ case eventUpstreamConnectionError:
+ net := e.net
+
+ stopped := false
+ select {
+ case <-net.stopped:
+ stopped = true
+ default:
+ }
+
+ if !stopped && (net.lastError == nil || net.lastError.Error() != e.err.Error()) {
+ net.forEachDownstream(func(dc *downstreamConn) {
+ sendServiceNOTICE(dc, fmt.Sprintf("failed connecting/registering to %s: %v", net.GetName(), e.err))
+ })
+ }
+ net.lastError = e.err
+ case eventUpstreamError:
+ uc := e.uc
+
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s: %v", uc.network.GetName(), e.err))
+ })
+ uc.network.lastError = e.err
+ case eventUpstreamMessage:
+ msg, uc := e.msg, e.uc
+ if uc.isClosed() {
+ uc.logger.Printf("ignoring message on closed connection: %v", msg)
+ break
+ }
+ if err := uc.handleMessage(context.TODO(), msg); err != nil {
+ uc.logger.Printf("failed to handle message %q: %v", msg, err)
+ }
+ case eventChannelDetach:
+ uc, name := e.uc, e.name
+ c := uc.network.channels.Value(name)
+ if c == nil || c.Detached {
+ continue
+ }
+ uc.network.detach(c)
+ if err := uc.srv.db.StoreChannel(context.TODO(), uc.network.ID, c); err != nil {
+ u.logger.Printf("failed to store updated detached channel %q: %v", c.Name, err)
+ }
+ case eventDownstreamConnected:
+ dc := e.dc
+
+ if dc.network != nil {
+ dc.monitored.SetCasemapping(dc.network.casemap)
+ }
+
+ if err := dc.welcome(context.TODO()); err != nil {
+ dc.logger.Printf("failed to handle new registered connection: %v", err)
+ break
+ }
+
+ u.downstreamConns = append(u.downstreamConns, dc)
+
+ dc.forEachNetwork(func(network *network) {
+ if network.lastError != nil {
+ sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s: %v", network.GetName(), network.lastError))
+ }
+ })
+
+ u.forEachUpstream(func(uc *upstreamConn) {
+ uc.updateAway()
+ })
+ case eventDownstreamDisconnected:
+ dc := e.dc
+
+ for i := range u.downstreamConns {
+ if u.downstreamConns[i] == dc {
+ u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...)
+ break
+ }
+ }
+
+ dc.forEachNetwork(func(net *network) {
+ net.storeClientDeliveryReceipts(context.TODO(), dc.clientName)
+ })
+
+ u.forEachUpstream(func(uc *upstreamConn) {
+ uc.cancelPendingCommandsByDownstreamID(dc.id)
+ uc.updateAway()
+ uc.updateMonitor()
+ })
+ case eventDownstreamMessage:
+ msg, dc := e.msg, e.dc
+ if dc.isClosed() {
+ dc.logger.Printf("ignoring message on closed connection: %v", msg)
+ break
+ }
+ err := dc.handleMessage(context.TODO(), msg)
+ if ircErr, ok := err.(ircError); ok {
+ ircErr.Message.Prefix = dc.srv.prefix()
+ dc.SendMessage(ircErr.Message)
+ } else if err != nil {
+ dc.logger.Printf("failed to handle message %q: %v", msg, err)
+ dc.Close()
+ }
+ case eventBroadcast:
+ msg := e.msg
+ u.forEachDownstream(func(dc *downstreamConn) {
+ dc.SendMessage(msg)
+ })
+ case eventUserUpdate:
+ // copy the user record because we'll mutate it
+ record := u.User
+
+ if e.password != nil {
+ record.Password = *e.password
+ }
+ if e.admin != nil {
+ record.Admin = *e.admin
+ }
+
+ e.done <- u.updateUser(context.TODO(), &record)
+
+ // If the password was updated, kill all downstream connections to
+ // force them to re-authenticate with the new credentials.
+ if e.password != nil {
+ u.forEachDownstream(func(dc *downstreamConn) {
+ dc.Close()
+ })
+ }
+ case eventStop:
+ u.forEachDownstream(func(dc *downstreamConn) {
+ dc.Close()
+ })
+ for _, n := range u.networks {
+ n.stop()
+
+ n.delivered.ForEachClient(func(clientName string) {
+ n.storeClientDeliveryReceipts(context.TODO(), clientName)
+ })
+ }
+ return
+ default:
+ panic(fmt.Sprintf("received unknown event type: %T", e))
+ }
+ }
+}
+
+func (u *user) handleUpstreamDisconnected(uc *upstreamConn) {
+ uc.network.conn = nil
+
+ uc.abortPendingCommands()
+
+ for _, entry := range uc.channels.innerMap {
+ uch := entry.value.(*upstreamChannel)
+ uch.updateAutoDetach(0)
+ }
+
+ netIDStr := fmt.Sprintf("%v", uc.network.ID)
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.updateSupportedCaps()
+ })
+
+ // If the network has been removed, don't send a state change notification
+ found := false
+ for _, net := range u.networks {
+ if net == uc.network {
+ found = true
+ break
+ }
+ }
+ if !found {
+ return
+ }
+
+ u.forEachDownstream(func(dc *downstreamConn) {
+ if dc.caps["soju.im/bouncer-networks-notify"] {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "BOUNCER",
+ Params: []string{"NETWORK", netIDStr, "state=disconnected"},
+ })
+ }
+ })
+
+ if uc.network.lastError == nil {
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ if !dc.caps["soju.im/bouncer-networks"] {
+ sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s", uc.network.GetName()))
+ }
+ })
+ }
+}
+
+func (u *user) addNetwork(network *network) {
+ u.networks = append(u.networks, network)
+
+ sort.Slice(u.networks, func(i, j int) bool {
+ return u.networks[i].ID < u.networks[j].ID
+ })
+
+ go network.run()
+}
+
+func (u *user) removeNetwork(network *network) {
+ network.stop()
+
+ u.forEachDownstream(func(dc *downstreamConn) {
+ if dc.network != nil && dc.network == network {
+ dc.Close()
+ }
+ })
+
+ for i, net := range u.networks {
+ if net == network {
+ u.networks = append(u.networks[:i], u.networks[i+1:]...)
+ return
+ }
+ }
+
+ panic("tried to remove a non-existing network")
+}
+
+func (u *user) checkNetwork(record *Network) error {
+ url, err := record.URL()
+ if err != nil {
+ return err
+ }
+ if url.User != nil {
+ return fmt.Errorf("%v:// URL must not have username and password information", url.Scheme)
+ }
+ if url.RawQuery != "" {
+ return fmt.Errorf("%v:// URL must not have query values", url.Scheme)
+ }
+ if url.Fragment != "" {
+ return fmt.Errorf("%v:// URL must not have a fragment", url.Scheme)
+ }
+ switch url.Scheme {
+ case "ircs", "irc":
+ if url.Host == "" {
+ return fmt.Errorf("%v:// URL must have a host", url.Scheme)
+ }
+ if url.Path != "" {
+ return fmt.Errorf("%v:// URL must not have a path", url.Scheme)
+ }
+ case "irc+unix", "unix":
+ if url.Host != "" {
+ return fmt.Errorf("%v:// URL must not have a host", url.Scheme)
+ }
+ if url.Path == "" {
+ return fmt.Errorf("%v:// URL must have a path", url.Scheme)
+ }
+ default:
+ return fmt.Errorf("unknown URL scheme %q", url.Scheme)
+ }
+
+ if record.GetName() == "" {
+ return fmt.Errorf("network name cannot be empty")
+ }
+ if strings.HasPrefix(record.GetName(), "-") {
+ // Can be mixed up with flags when sending commands to the service
+ return fmt.Errorf("network name cannot start with a dash character")
+ }
+
+ for _, net := range u.networks {
+ if net.GetName() == record.GetName() && net.ID != record.ID {
+ return fmt.Errorf("a network with the name %q already exists", record.GetName())
+ }
+ }
+
+ return nil
+}
+
+func (u *user) createNetwork(ctx context.Context, record *Network) (*network, error) {
+ if record.ID != 0 {
+ panic("tried creating an already-existing network")
+ }
+
+ if err := u.checkNetwork(record); err != nil {
+ return nil, err
+ }
+
+ if max := u.srv.Config().MaxUserNetworks; max >= 0 && len(u.networks) >= max {
+ return nil, fmt.Errorf("maximum number of networks reached")
+ }
+
+ network := newNetwork(u, record, nil)
+ err := u.srv.db.StoreNetwork(ctx, u.ID, &network.Network)
+ if err != nil {
+ return nil, err
+ }
+
+ u.addNetwork(network)
+
+ idStr := fmt.Sprintf("%v", network.ID)
+ attrs := getNetworkAttrs(network)
+ u.forEachDownstream(func(dc *downstreamConn) {
+ if dc.caps["soju.im/bouncer-networks-notify"] {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "BOUNCER",
+ Params: []string{"NETWORK", idStr, attrs.String()},
+ })
+ }
+ })
+
+ return network, nil
+}
+
+func (u *user) updateNetwork(ctx context.Context, record *Network) (*network, error) {
+ if record.ID == 0 {
+ panic("tried updating a new network")
+ }
+
+ // If the realname is reset to the default, just wipe the per-network
+ // setting
+ if record.Realname == u.Realname {
+ record.Realname = ""
+ }
+
+ if err := u.checkNetwork(record); err != nil {
+ return nil, err
+ }
+
+ network := u.getNetworkByID(record.ID)
+ if network == nil {
+ panic("tried updating a non-existing network")
+ }
+
+ if err := u.srv.db.StoreNetwork(ctx, u.ID, record); err != nil {
+ return nil, err
+ }
+
+ // Most network changes require us to re-connect to the upstream server
+
+ channels := make([]Channel, 0, network.channels.Len())
+ for _, entry := range network.channels.innerMap {
+ ch := entry.value.(*Channel)
+ channels = append(channels, *ch)
+ }
+
+ updatedNetwork := newNetwork(u, record, channels)
+
+ // If we're currently connected, disconnect and perform the necessary
+ // bookkeeping
+ if network.conn != nil {
+ network.stop()
+ // Note: this will set network.conn to nil
+ u.handleUpstreamDisconnected(network.conn)
+ }
+
+ // Patch downstream connections to use our fresh updated network
+ u.forEachDownstream(func(dc *downstreamConn) {
+ if dc.network != nil && dc.network == network {
+ dc.network = updatedNetwork
+ }
+ })
+
+ // We need to remove the network after patching downstream connections,
+ // otherwise they'll get closed
+ u.removeNetwork(network)
+
+ // The filesystem message store needs to be notified whenever the network
+ // is renamed
+ fsMsgStore, isFS := u.msgStore.(*fsMessageStore)
+ if isFS && updatedNetwork.GetName() != network.GetName() {
+ if err := fsMsgStore.RenameNetwork(&network.Network, &updatedNetwork.Network); err != nil {
+ network.logger.Printf("failed to update FS message store network name to %q: %v", updatedNetwork.GetName(), err)
+ }
+ }
+
+ // This will re-connect to the upstream server
+ u.addNetwork(updatedNetwork)
+
+ // TODO: only broadcast attributes that have changed
+ idStr := fmt.Sprintf("%v", updatedNetwork.ID)
+ attrs := getNetworkAttrs(updatedNetwork)
+ u.forEachDownstream(func(dc *downstreamConn) {
+ if dc.caps["soju.im/bouncer-networks-notify"] {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "BOUNCER",
+ Params: []string{"NETWORK", idStr, attrs.String()},
+ })
+ }
+ })
+
+ return updatedNetwork, nil
+}
+
+func (u *user) deleteNetwork(ctx context.Context, id int64) error {
+ network := u.getNetworkByID(id)
+ if network == nil {
+ panic("tried deleting a non-existing network")
+ }
+
+ if err := u.srv.db.DeleteNetwork(ctx, network.ID); err != nil {
+ return err
+ }
+
+ u.removeNetwork(network)
+
+ idStr := fmt.Sprintf("%v", network.ID)
+ u.forEachDownstream(func(dc *downstreamConn) {
+ if dc.caps["soju.im/bouncer-networks-notify"] {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "BOUNCER",
+ Params: []string{"NETWORK", idStr, "*"},
+ })
+ }
+ })
+
+ return nil
+}
+
+func (u *user) updateUser(ctx context.Context, record *User) error {
+ if u.ID != record.ID {
+ panic("ID mismatch when updating user")
+ }
+
+ realnameUpdated := u.Realname != record.Realname
+ if err := u.srv.db.StoreUser(ctx, record); err != nil {
+ return fmt.Errorf("failed to update user %q: %v", u.Username, err)
+ }
+ u.User = *record
+
+ if realnameUpdated {
+ // Re-connect to networks which use the default realname
+ var needUpdate []Network
+ for _, net := range u.networks {
+ if net.Realname == "" {
+ needUpdate = append(needUpdate, net.Network)
+ }
+ }
+
+ var netErr error
+ for _, net := range needUpdate {
+ if _, err := u.updateNetwork(ctx, &net); err != nil {
+ netErr = err
+ }
+ }
+ if netErr != nil {
+ return netErr
+ }
+ }
+
+ return nil
+}
+
+func (u *user) stop() {
+ u.events <- eventStop{}
+ <-u.done
+}
+
+func (u *user) hasPersistentMsgStore() bool {
+ if u.msgStore == nil {
+ return false
+ }
+ _, isMem := u.msgStore.(*memoryMessageStore)
+ return !isMem
+}
+
+// localAddrForHost returns the local address to use when connecting to host.
+// A nil address is returned when the OS should automatically pick one.
+func (u *user) localTCPAddrForHost(ctx context.Context, host string) (*net.TCPAddr, error) {
+ upstreamUserIPs := u.srv.Config().UpstreamUserIPs
+ if len(upstreamUserIPs) == 0 {
+ return nil, nil
+ }
+
+ ips, err := net.DefaultResolver.LookupIP(ctx, "ip", host)
+ if err != nil {
+ return nil, err
+ }
+
+ wantIPv6 := false
+ for _, ip := range ips {
+ if ip.To4() == nil {
+ wantIPv6 = true
+ break
+ }
+ }
+
+ var ipNet *net.IPNet
+ for _, in := range upstreamUserIPs {
+ if wantIPv6 == (in.IP.To4() == nil) {
+ ipNet = in
+ break
+ }
+ }
+ if ipNet == nil {
+ return nil, nil
+ }
+
+ var ipInt big.Int
+ ipInt.SetBytes(ipNet.IP)
+ ipInt.Add(&ipInt, big.NewInt(u.ID+1))
+ ip := net.IP(ipInt.Bytes())
+ if !ipNet.Contains(ip) {
+ return nil, fmt.Errorf("IP network %v too small", ipNet)
+ }
+
+ return &net.TCPAddr{IP: ip}, nil
+}
--- /dev/null
+package suika
+
+import (
+ "fmt"
+ "runtime/debug"
+ "strings"
+)
+
+const (
+ defaultVersion = "0.0.0"
+ defaultCommit = "HEAD"
+ defaultBuild = "0000-01-01:00:00+00:00"
+)
+
+var (
+ // Version is the tagged release version in the form <major>.<minor>.<patch>
+ // following semantic versioning and is overwritten by the build system.
+ Version = defaultVersion
+
+ // Commit is the commit sha of the build (normally from Git) and is overwritten
+ // by the build system.
+ Commit = defaultCommit
+
+ // Build is the date and time of the build as an RFC3339 formatted string
+ // and is overwritten by the build system.
+ Build = defaultBuild
+)
+
+// FullVersion display the full version and build
+func FullVersion() string {
+ var sb strings.Builder
+
+ isDefault := Version == defaultVersion && Commit == defaultCommit && Build == defaultBuild
+
+ if !isDefault {
+ sb.WriteString(fmt.Sprintf("%s@%s %s", Version, Commit, Build))
+ }
+
+ if info, ok := debug.ReadBuildInfo(); ok {
+ if isDefault {
+ sb.WriteString(fmt.Sprintf(" %s", info.Main.Version))
+ }
+ sb.WriteString(fmt.Sprintf(" %s", info.GoVersion))
+ if info.Main.Sum != "" {
+ sb.WriteString(fmt.Sprintf(" %s", info.Main.Sum))
+ }
+ }
+
+ return sb.String()
+}
--- /dev/null
+vendor
+/suika
+/suikadb
+/suika-znc-import
+/suika.db
--- /dev/null
+ GNU AFFERO GENERAL PUBLIC LICENSE
+ Version 3, 19 November 2007
+
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+ Preamble
+
+ The GNU Affero General Public License is a free, copyleft license for
+software and other kinds of works, specifically designed to ensure
+cooperation with the community in the case of network server software.
+
+ The licenses for most software and other practical works are designed
+to take away your freedom to share and change the works. By contrast,
+our General Public Licenses are intended to guarantee your freedom to
+share and change all versions of a program--to make sure it remains free
+software for all its users.
+
+ When we speak of free software, we are referring to freedom, not
+price. Our General Public Licenses are designed to make sure that you
+have the freedom to distribute copies of free software (and charge for
+them if you wish), that you receive source code or can get it if you
+want it, that you can change the software or use pieces of it in new
+free programs, and that you know you can do these things.
+
+ Developers that use our General Public Licenses protect your rights
+with two steps: (1) assert copyright on the software, and (2) offer
+you this License which gives you legal permission to copy, distribute
+and/or modify the software.
+
+ A secondary benefit of defending all users' freedom is that
+improvements made in alternate versions of the program, if they
+receive widespread use, become available for other developers to
+incorporate. Many developers of free software are heartened and
+encouraged by the resulting cooperation. However, in the case of
+software used on network servers, this result may fail to come about.
+The GNU General Public License permits making a modified version and
+letting the public access it on a server without ever releasing its
+source code to the public.
+
+ The GNU Affero General Public License is designed specifically to
+ensure that, in such cases, the modified source code becomes available
+to the community. It requires the operator of a network server to
+provide the source code of the modified version running there to the
+users of that server. Therefore, public use of a modified version, on
+a publicly accessible server, gives the public access to the source
+code of the modified version.
+
+ An older license, called the Affero General Public License and
+published by Affero, was designed to accomplish similar goals. This is
+a different license, not a version of the Affero GPL, but Affero has
+released a new version of the Affero GPL which permits relicensing under
+this license.
+
+ The precise terms and conditions for copying, distribution and
+modification follow.
+
+ TERMS AND CONDITIONS
+
+ 0. Definitions.
+
+ "This License" refers to version 3 of the GNU Affero General Public License.
+
+ "Copyright" also means copyright-like laws that apply to other kinds of
+works, such as semiconductor masks.
+
+ "The Program" refers to any copyrightable work licensed under this
+License. Each licensee is addressed as "you". "Licensees" and
+"recipients" may be individuals or organizations.
+
+ To "modify" a work means to copy from or adapt all or part of the work
+in a fashion requiring copyright permission, other than the making of an
+exact copy. The resulting work is called a "modified version" of the
+earlier work or a work "based on" the earlier work.
+
+ A "covered work" means either the unmodified Program or a work based
+on the Program.
+
+ To "propagate" a work means to do anything with it that, without
+permission, would make you directly or secondarily liable for
+infringement under applicable copyright law, except executing it on a
+computer or modifying a private copy. Propagation includes copying,
+distribution (with or without modification), making available to the
+public, and in some countries other activities as well.
+
+ To "convey" a work means any kind of propagation that enables other
+parties to make or receive copies. Mere interaction with a user through
+a computer network, with no transfer of a copy, is not conveying.
+
+ An interactive user interface displays "Appropriate Legal Notices"
+to the extent that it includes a convenient and prominently visible
+feature that (1) displays an appropriate copyright notice, and (2)
+tells the user that there is no warranty for the work (except to the
+extent that warranties are provided), that licensees may convey the
+work under this License, and how to view a copy of this License. If
+the interface presents a list of user commands or options, such as a
+menu, a prominent item in the list meets this criterion.
+
+ 1. Source Code.
+
+ The "source code" for a work means the preferred form of the work
+for making modifications to it. "Object code" means any non-source
+form of a work.
+
+ A "Standard Interface" means an interface that either is an official
+standard defined by a recognized standards body, or, in the case of
+interfaces specified for a particular programming language, one that
+is widely used among developers working in that language.
+
+ The "System Libraries" of an executable work include anything, other
+than the work as a whole, that (a) is included in the normal form of
+packaging a Major Component, but which is not part of that Major
+Component, and (b) serves only to enable use of the work with that
+Major Component, or to implement a Standard Interface for which an
+implementation is available to the public in source code form. A
+"Major Component", in this context, means a major essential component
+(kernel, window system, and so on) of the specific operating system
+(if any) on which the executable work runs, or a compiler used to
+produce the work, or an object code interpreter used to run it.
+
+ The "Corresponding Source" for a work in object code form means all
+the source code needed to generate, install, and (for an executable
+work) run the object code and to modify the work, including scripts to
+control those activities. However, it does not include the work's
+System Libraries, or general-purpose tools or generally available free
+programs which are used unmodified in performing those activities but
+which are not part of the work. For example, Corresponding Source
+includes interface definition files associated with source files for
+the work, and the source code for shared libraries and dynamically
+linked subprograms that the work is specifically designed to require,
+such as by intimate data communication or control flow between those
+subprograms and other parts of the work.
+
+ The Corresponding Source need not include anything that users
+can regenerate automatically from other parts of the Corresponding
+Source.
+
+ The Corresponding Source for a work in source code form is that
+same work.
+
+ 2. Basic Permissions.
+
+ All rights granted under this License are granted for the term of
+copyright on the Program, and are irrevocable provided the stated
+conditions are met. This License explicitly affirms your unlimited
+permission to run the unmodified Program. The output from running a
+covered work is covered by this License only if the output, given its
+content, constitutes a covered work. This License acknowledges your
+rights of fair use or other equivalent, as provided by copyright law.
+
+ You may make, run and propagate covered works that you do not
+convey, without conditions so long as your license otherwise remains
+in force. You may convey covered works to others for the sole purpose
+of having them make modifications exclusively for you, or provide you
+with facilities for running those works, provided that you comply with
+the terms of this License in conveying all material for which you do
+not control copyright. Those thus making or running the covered works
+for you must do so exclusively on your behalf, under your direction
+and control, on terms that prohibit them from making any copies of
+your copyrighted material outside their relationship with you.
+
+ Conveying under any other circumstances is permitted solely under
+the conditions stated below. Sublicensing is not allowed; section 10
+makes it unnecessary.
+
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
+
+ No covered work shall be deemed part of an effective technological
+measure under any applicable law fulfilling obligations under article
+11 of the WIPO copyright treaty adopted on 20 December 1996, or
+similar laws prohibiting or restricting circumvention of such
+measures.
+
+ When you convey a covered work, you waive any legal power to forbid
+circumvention of technological measures to the extent such circumvention
+is effected by exercising rights under this License with respect to
+the covered work, and you disclaim any intention to limit operation or
+modification of the work as a means of enforcing, against the work's
+users, your or third parties' legal rights to forbid circumvention of
+technological measures.
+
+ 4. Conveying Verbatim Copies.
+
+ You may convey verbatim copies of the Program's source code as you
+receive it, in any medium, provided that you conspicuously and
+appropriately publish on each copy an appropriate copyright notice;
+keep intact all notices stating that this License and any
+non-permissive terms added in accord with section 7 apply to the code;
+keep intact all notices of the absence of any warranty; and give all
+recipients a copy of this License along with the Program.
+
+ You may charge any price or no price for each copy that you convey,
+and you may offer support or warranty protection for a fee.
+
+ 5. Conveying Modified Source Versions.
+
+ You may convey a work based on the Program, or the modifications to
+produce it from the Program, in the form of source code under the
+terms of section 4, provided that you also meet all of these conditions:
+
+ a) The work must carry prominent notices stating that you modified
+ it, and giving a relevant date.
+
+ b) The work must carry prominent notices stating that it is
+ released under this License and any conditions added under section
+ 7. This requirement modifies the requirement in section 4 to
+ "keep intact all notices".
+
+ c) You must license the entire work, as a whole, under this
+ License to anyone who comes into possession of a copy. This
+ License will therefore apply, along with any applicable section 7
+ additional terms, to the whole of the work, and all its parts,
+ regardless of how they are packaged. This License gives no
+ permission to license the work in any other way, but it does not
+ invalidate such permission if you have separately received it.
+
+ d) If the work has interactive user interfaces, each must display
+ Appropriate Legal Notices; however, if the Program has interactive
+ interfaces that do not display Appropriate Legal Notices, your
+ work need not make them do so.
+
+ A compilation of a covered work with other separate and independent
+works, which are not by their nature extensions of the covered work,
+and which are not combined with it such as to form a larger program,
+in or on a volume of a storage or distribution medium, is called an
+"aggregate" if the compilation and its resulting copyright are not
+used to limit the access or legal rights of the compilation's users
+beyond what the individual works permit. Inclusion of a covered work
+in an aggregate does not cause this License to apply to the other
+parts of the aggregate.
+
+ 6. Conveying Non-Source Forms.
+
+ You may convey a covered work in object code form under the terms
+of sections 4 and 5, provided that you also convey the
+machine-readable Corresponding Source under the terms of this License,
+in one of these ways:
+
+ a) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by the
+ Corresponding Source fixed on a durable physical medium
+ customarily used for software interchange.
+
+ b) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by a
+ written offer, valid for at least three years and valid for as
+ long as you offer spare parts or customer support for that product
+ model, to give anyone who possesses the object code either (1) a
+ copy of the Corresponding Source for all the software in the
+ product that is covered by this License, on a durable physical
+ medium customarily used for software interchange, for a price no
+ more than your reasonable cost of physically performing this
+ conveying of source, or (2) access to copy the
+ Corresponding Source from a network server at no charge.
+
+ c) Convey individual copies of the object code with a copy of the
+ written offer to provide the Corresponding Source. This
+ alternative is allowed only occasionally and noncommercially, and
+ only if you received the object code with such an offer, in accord
+ with subsection 6b.
+
+ d) Convey the object code by offering access from a designated
+ place (gratis or for a charge), and offer equivalent access to the
+ Corresponding Source in the same way through the same place at no
+ further charge. You need not require recipients to copy the
+ Corresponding Source along with the object code. If the place to
+ copy the object code is a network server, the Corresponding Source
+ may be on a different server (operated by you or a third party)
+ that supports equivalent copying facilities, provided you maintain
+ clear directions next to the object code saying where to find the
+ Corresponding Source. Regardless of what server hosts the
+ Corresponding Source, you remain obligated to ensure that it is
+ available for as long as needed to satisfy these requirements.
+
+ e) Convey the object code using peer-to-peer transmission, provided
+ you inform other peers where the object code and Corresponding
+ Source of the work are being offered to the general public at no
+ charge under subsection 6d.
+
+ A separable portion of the object code, whose source code is excluded
+from the Corresponding Source as a System Library, need not be
+included in conveying the object code work.
+
+ A "User Product" is either (1) a "consumer product", which means any
+tangible personal property which is normally used for personal, family,
+or household purposes, or (2) anything designed or sold for incorporation
+into a dwelling. In determining whether a product is a consumer product,
+doubtful cases shall be resolved in favor of coverage. For a particular
+product received by a particular user, "normally used" refers to a
+typical or common use of that class of product, regardless of the status
+of the particular user or of the way in which the particular user
+actually uses, or expects or is expected to use, the product. A product
+is a consumer product regardless of whether the product has substantial
+commercial, industrial or non-consumer uses, unless such uses represent
+the only significant mode of use of the product.
+
+ "Installation Information" for a User Product means any methods,
+procedures, authorization keys, or other information required to install
+and execute modified versions of a covered work in that User Product from
+a modified version of its Corresponding Source. The information must
+suffice to ensure that the continued functioning of the modified object
+code is in no case prevented or interfered with solely because
+modification has been made.
+
+ If you convey an object code work under this section in, or with, or
+specifically for use in, a User Product, and the conveying occurs as
+part of a transaction in which the right of possession and use of the
+User Product is transferred to the recipient in perpetuity or for a
+fixed term (regardless of how the transaction is characterized), the
+Corresponding Source conveyed under this section must be accompanied
+by the Installation Information. But this requirement does not apply
+if neither you nor any third party retains the ability to install
+modified object code on the User Product (for example, the work has
+been installed in ROM).
+
+ The requirement to provide Installation Information does not include a
+requirement to continue to provide support service, warranty, or updates
+for a work that has been modified or installed by the recipient, or for
+the User Product in which it has been modified or installed. Access to a
+network may be denied when the modification itself materially and
+adversely affects the operation of the network or violates the rules and
+protocols for communication across the network.
+
+ Corresponding Source conveyed, and Installation Information provided,
+in accord with this section must be in a format that is publicly
+documented (and with an implementation available to the public in
+source code form), and must require no special password or key for
+unpacking, reading or copying.
+
+ 7. Additional Terms.
+
+ "Additional permissions" are terms that supplement the terms of this
+License by making exceptions from one or more of its conditions.
+Additional permissions that are applicable to the entire Program shall
+be treated as though they were included in this License, to the extent
+that they are valid under applicable law. If additional permissions
+apply only to part of the Program, that part may be used separately
+under those permissions, but the entire Program remains governed by
+this License without regard to the additional permissions.
+
+ When you convey a copy of a covered work, you may at your option
+remove any additional permissions from that copy, or from any part of
+it. (Additional permissions may be written to require their own
+removal in certain cases when you modify the work.) You may place
+additional permissions on material, added by you to a covered work,
+for which you have or can give appropriate copyright permission.
+
+ Notwithstanding any other provision of this License, for material you
+add to a covered work, you may (if authorized by the copyright holders of
+that material) supplement the terms of this License with terms:
+
+ a) Disclaiming warranty or limiting liability differently from the
+ terms of sections 15 and 16 of this License; or
+
+ b) Requiring preservation of specified reasonable legal notices or
+ author attributions in that material or in the Appropriate Legal
+ Notices displayed by works containing it; or
+
+ c) Prohibiting misrepresentation of the origin of that material, or
+ requiring that modified versions of such material be marked in
+ reasonable ways as different from the original version; or
+
+ d) Limiting the use for publicity purposes of names of licensors or
+ authors of the material; or
+
+ e) Declining to grant rights under trademark law for use of some
+ trade names, trademarks, or service marks; or
+
+ f) Requiring indemnification of licensors and authors of that
+ material by anyone who conveys the material (or modified versions of
+ it) with contractual assumptions of liability to the recipient, for
+ any liability that these contractual assumptions directly impose on
+ those licensors and authors.
+
+ All other non-permissive additional terms are considered "further
+restrictions" within the meaning of section 10. If the Program as you
+received it, or any part of it, contains a notice stating that it is
+governed by this License along with a term that is a further
+restriction, you may remove that term. If a license document contains
+a further restriction but permits relicensing or conveying under this
+License, you may add to a covered work material governed by the terms
+of that license document, provided that the further restriction does
+not survive such relicensing or conveying.
+
+ If you add terms to a covered work in accord with this section, you
+must place, in the relevant source files, a statement of the
+additional terms that apply to those files, or a notice indicating
+where to find the applicable terms.
+
+ Additional terms, permissive or non-permissive, may be stated in the
+form of a separately written license, or stated as exceptions;
+the above requirements apply either way.
+
+ 8. Termination.
+
+ You may not propagate or modify a covered work except as expressly
+provided under this License. Any attempt otherwise to propagate or
+modify it is void, and will automatically terminate your rights under
+this License (including any patent licenses granted under the third
+paragraph of section 11).
+
+ However, if you cease all violation of this License, then your
+license from a particular copyright holder is reinstated (a)
+provisionally, unless and until the copyright holder explicitly and
+finally terminates your license, and (b) permanently, if the copyright
+holder fails to notify you of the violation by some reasonable means
+prior to 60 days after the cessation.
+
+ Moreover, your license from a particular copyright holder is
+reinstated permanently if the copyright holder notifies you of the
+violation by some reasonable means, this is the first time you have
+received notice of violation of this License (for any work) from that
+copyright holder, and you cure the violation prior to 30 days after
+your receipt of the notice.
+
+ Termination of your rights under this section does not terminate the
+licenses of parties who have received copies or rights from you under
+this License. If your rights have been terminated and not permanently
+reinstated, you do not qualify to receive new licenses for the same
+material under section 10.
+
+ 9. Acceptance Not Required for Having Copies.
+
+ You are not required to accept this License in order to receive or
+run a copy of the Program. Ancillary propagation of a covered work
+occurring solely as a consequence of using peer-to-peer transmission
+to receive a copy likewise does not require acceptance. However,
+nothing other than this License grants you permission to propagate or
+modify any covered work. These actions infringe copyright if you do
+not accept this License. Therefore, by modifying or propagating a
+covered work, you indicate your acceptance of this License to do so.
+
+ 10. Automatic Licensing of Downstream Recipients.
+
+ Each time you convey a covered work, the recipient automatically
+receives a license from the original licensors, to run, modify and
+propagate that work, subject to this License. You are not responsible
+for enforcing compliance by third parties with this License.
+
+ An "entity transaction" is a transaction transferring control of an
+organization, or substantially all assets of one, or subdividing an
+organization, or merging organizations. If propagation of a covered
+work results from an entity transaction, each party to that
+transaction who receives a copy of the work also receives whatever
+licenses to the work the party's predecessor in interest had or could
+give under the previous paragraph, plus a right to possession of the
+Corresponding Source of the work from the predecessor in interest, if
+the predecessor has it or can get it with reasonable efforts.
+
+ You may not impose any further restrictions on the exercise of the
+rights granted or affirmed under this License. For example, you may
+not impose a license fee, royalty, or other charge for exercise of
+rights granted under this License, and you may not initiate litigation
+(including a cross-claim or counterclaim in a lawsuit) alleging that
+any patent claim is infringed by making, using, selling, offering for
+sale, or importing the Program or any portion of it.
+
+ 11. Patents.
+
+ A "contributor" is a copyright holder who authorizes use under this
+License of the Program or a work on which the Program is based. The
+work thus licensed is called the contributor's "contributor version".
+
+ A contributor's "essential patent claims" are all patent claims
+owned or controlled by the contributor, whether already acquired or
+hereafter acquired, that would be infringed by some manner, permitted
+by this License, of making, using, or selling its contributor version,
+but do not include claims that would be infringed only as a
+consequence of further modification of the contributor version. For
+purposes of this definition, "control" includes the right to grant
+patent sublicenses in a manner consistent with the requirements of
+this License.
+
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
+patent license under the contributor's essential patent claims, to
+make, use, sell, offer for sale, import and otherwise run, modify and
+propagate the contents of its contributor version.
+
+ In the following three paragraphs, a "patent license" is any express
+agreement or commitment, however denominated, not to enforce a patent
+(such as an express permission to practice a patent or covenant not to
+sue for patent infringement). To "grant" such a patent license to a
+party means to make such an agreement or commitment not to enforce a
+patent against the party.
+
+ If you convey a covered work, knowingly relying on a patent license,
+and the Corresponding Source of the work is not available for anyone
+to copy, free of charge and under the terms of this License, through a
+publicly available network server or other readily accessible means,
+then you must either (1) cause the Corresponding Source to be so
+available, or (2) arrange to deprive yourself of the benefit of the
+patent license for this particular work, or (3) arrange, in a manner
+consistent with the requirements of this License, to extend the patent
+license to downstream recipients. "Knowingly relying" means you have
+actual knowledge that, but for the patent license, your conveying the
+covered work in a country, or your recipient's use of the covered work
+in a country, would infringe one or more identifiable patents in that
+country that you have reason to believe are valid.
+
+ If, pursuant to or in connection with a single transaction or
+arrangement, you convey, or propagate by procuring conveyance of, a
+covered work, and grant a patent license to some of the parties
+receiving the covered work authorizing them to use, propagate, modify
+or convey a specific copy of the covered work, then the patent license
+you grant is automatically extended to all recipients of the covered
+work and works based on it.
+
+ A patent license is "discriminatory" if it does not include within
+the scope of its coverage, prohibits the exercise of, or is
+conditioned on the non-exercise of one or more of the rights that are
+specifically granted under this License. You may not convey a covered
+work if you are a party to an arrangement with a third party that is
+in the business of distributing software, under which you make payment
+to the third party based on the extent of your activity of conveying
+the work, and under which the third party grants, to any of the
+parties who would receive the covered work from you, a discriminatory
+patent license (a) in connection with copies of the covered work
+conveyed by you (or copies made from those copies), or (b) primarily
+for and in connection with specific products or compilations that
+contain the covered work, unless you entered into that arrangement,
+or that patent license was granted, prior to 28 March 2007.
+
+ Nothing in this License shall be construed as excluding or limiting
+any implied license or other defenses to infringement that may
+otherwise be available to you under applicable patent law.
+
+ 12. No Surrender of Others' Freedom.
+
+ If conditions are imposed on you (whether by court order, agreement or
+otherwise) that contradict the conditions of this License, they do not
+excuse you from the conditions of this License. If you cannot convey a
+covered work so as to satisfy simultaneously your obligations under this
+License and any other pertinent obligations, then as a consequence you may
+not convey it at all. For example, if you agree to terms that obligate you
+to collect a royalty for further conveying from those to whom you convey
+the Program, the only way you could satisfy both those terms and this
+License would be to refrain entirely from conveying the Program.
+
+ 13. Remote Network Interaction; Use with the GNU General Public License.
+
+ Notwithstanding any other provision of this License, if you modify the
+Program, your modified version must prominently offer all users
+interacting with it remotely through a computer network (if your version
+supports such interaction) an opportunity to receive the Corresponding
+Source of your version by providing access to the Corresponding Source
+from a network server at no charge, through some standard or customary
+means of facilitating copying of software. This Corresponding Source
+shall include the Corresponding Source for any work covered by version 3
+of the GNU General Public License that is incorporated pursuant to the
+following paragraph.
+
+ Notwithstanding any other provision of this License, you have
+permission to link or combine any covered work with a work licensed
+under version 3 of the GNU General Public License into a single
+combined work, and to convey the resulting work. The terms of this
+License will continue to apply to the part which is the covered work,
+but the work with which it is combined will remain governed by version
+3 of the GNU General Public License.
+
+ 14. Revised Versions of this License.
+
+ The Free Software Foundation may publish revised and/or new versions of
+the GNU Affero General Public License from time to time. Such new versions
+will be similar in spirit to the present version, but may differ in detail to
+address new problems or concerns.
+
+ Each version is given a distinguishing version number. If the
+Program specifies that a certain numbered version of the GNU Affero General
+Public License "or any later version" applies to it, you have the
+option of following the terms and conditions either of that numbered
+version or of any later version published by the Free Software
+Foundation. If the Program does not specify a version number of the
+GNU Affero General Public License, you may choose any version ever published
+by the Free Software Foundation.
+
+ If the Program specifies that a proxy can decide which future
+versions of the GNU Affero General Public License can be used, that proxy's
+public statement of acceptance of a version permanently authorizes you
+to choose that version for the Program.
+
+ Later license versions may give you additional or different
+permissions. However, no additional obligations are imposed on any
+author or copyright holder as a result of your choosing to follow a
+later version.
+
+ 15. Disclaimer of Warranty.
+
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
+APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
+HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
+OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
+THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
+IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
+ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
+
+ 16. Limitation of Liability.
+
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
+WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
+THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
+GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
+USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
+DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
+PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
+EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
+SUCH DAMAGES.
+
+ 17. Interpretation of Sections 15 and 16.
+
+ If the disclaimer of warranty and limitation of liability provided
+above cannot be given local legal effect according to their terms,
+reviewing courts shall apply local law that most closely approximates
+an absolute waiver of all civil liability in connection with the
+Program, unless a warranty or assumption of liability accompanies a
+copy of the Program in return for a fee.
+
+ END OF TERMS AND CONDITIONS
+
+ How to Apply These Terms to Your New Programs
+
+ If you develop a new program, and you want it to be of the greatest
+possible use to the public, the best way to achieve this is to make it
+free software which everyone can redistribute and change under these terms.
+
+ To do so, attach the following notices to the program. It is safest
+to attach them to the start of each source file to most effectively
+state the exclusion of warranty; and each file should have at least
+the "copyright" line and a pointer to where the full notice is found.
+
+ <one line to give the program's name and a brief idea of what it does.>
+ Copyright (C) <year> <name of author>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published
+ by the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+Also add information on how to contact you by electronic and paper mail.
+
+ If your software can interact with users remotely through a computer
+network, you should also make sure that it provides a way for users to
+get its source. For example, if your program is a web application, its
+interface could display a "Source" link that leads users to an archive
+of the code. There are many ways you could offer source, and different
+solutions will be better for different programs; see section 13 for the
+specific requirements.
+
+ You should also get your employer (if you work as a programmer) or school,
+if any, to sign a "copyright disclaimer" for the program, if necessary.
+For more information on this, and how to apply and follow the GNU AGPL, see
+<https://www.gnu.org/licenses/>.
--- /dev/null
+GO ?= go
+RM ?= rm
+GOFLAGS ?= -v -ldflags "-w -X `go list`.Version=${VERSION} -X `go list`.Commit=${COMMIT} -X `go list`.Build=${BUILD}" -mod=vendor
+PREFIX ?= /usr/local
+BINDIR ?= bin
+MANDIR ?= share/man
+MKDIR ?= mkdir
+CP ?= cp
+SYSCONFDIR ?= /etc
+ASCIIDOCTOR ?= asciidoctor
+
+VERSION = `git describe --abbrev=0 --tags 2>/dev/null || echo "$VERSION"`
+COMMIT = `git rev-parse --short HEAD || echo "$COMMIT"`
+BRANCH = `git rev-parse --abbrev-ref HEAD`
+BUILD = `git show -s --pretty=format:%cI`
+
+GOARCH ?= amd64
+GOOS ?= linux
+
+all: build
+
+build: vendor
+ ${GO} build ${GOFLAGS} ./cmd/suika
+ ${GO} build ${GOFLAGS} ./cmd/suikadb
+ ${GO} build ${GOFLAGS} ./cmd/suika-znc-import
+clean:
+ ${RM} -f suika suikadb suika-znc-import
+install:
+ ${MKDIR} -p ${DESTDIR}${PREFIX}/${BINDIR}
+ ${MKDIR} -p ${DESTDIR}${PREFIX}/${MANDIR}/man1
+ ${MKDIR} -p ${DESTDIR}${PREFIX}/${MANDIR}/man5
+ ${MKDIR} -p ${DESTDIR}${PREFIX}/${MANDIR}/man7
+ ${MKDIR} -p ${DESTDIR}${SYSCONFDIR}/suika
+ ${MKDIR} -p ${DESTDIR}/var/lib/suika
+ ${CP} -f suika suikadb suika-znc-import ${DESTDIR}${PREFIX}/${BINDIR}
+ ${CP} -f doc/suika.1 ${DESTDIR}${PREFIX}/${MANDIR}/man1
+ ${CP} -f doc/suikadb.1 ${DESTDIR}${PREFIX}/${MANDIR}/man1
+ ${CP} -f doc/suika-znc-import.1 ${DESTDIR}/${MANDIR}/man1
+ ${CP} -f doc/suika-config.5 ${DESTDIR}${PREFIX}/${MANDIR}/man5
+ [ -f ${DESTDIR}${SYSCONFDIR}/suika/config ] || ${CP} -f config.in ${DESTDIR}${SYSCONFDIR}/suika/config
+test:
+ go test
+vendor:
+ go mod vendor
+.PHONY: build clean install
--- /dev/null
+# suika
+
+[![Go Documentation](https://godocs.io/marisa.chaotic.ninja/suika?status.svg)](https://godocs.io/marisa.chaotic.ninja/suika)
+
+A user-friendly IRC bouncer. Hard-fork of the 0.3 series of [soju](https://soju.im), named after [Suika Ibuki](https://en.touhouwiki.net/wiki/Suika_Ibuki) from [Touhou 7.5: Immaterial and Missing Power](https://en.touhouwiki.net/wiki/Immaterial_and_Missing_Power)
+
+- Multi-user
+- Support multiple clients for a single user, with proper backlog
+ synchronization
+- Support connecting to multiple upstream servers via a single IRC connection
+ to the bouncer
+
+## Building and installing
+
+Dependencies:
+
+- Go
+- BSD or GNU make
+
+For end users, a `Makefile` is provided:
+
+ make
+ doas make install
+
+For development, you can use `go run ./cmd/suika` as usual.
+
+## License
+AGPLv3, see [LICENSE](LICENSE).
+
+* Copyright (C) 2020 The soju Contributors
+* Copyright (C) 2023-present Izuru Yakumo
+
+The code for `version.go` is stolen verbatim from one of [@prologic](https://git.mills.io/prologic)'s projects. It's probably under MIT
--- /dev/null
+package suika
+
+import (
+ "context"
+ "fmt"
+ "strconv"
+ "strings"
+
+ "gopkg.in/irc.v3"
+)
+
+func forwardChannel(ctx context.Context, dc *downstreamConn, ch *upstreamChannel) {
+ if !ch.complete {
+ panic("Tried to forward a partial channel")
+ }
+
+ // RPL_NOTOPIC shouldn't be sent on JOIN
+ if ch.Topic != "" {
+ sendTopic(dc, ch)
+ }
+
+ if dc.caps["soju.im/read"] {
+ channelCM := ch.conn.network.casemap(ch.Name)
+ r, err := dc.srv.db.GetReadReceipt(ctx, ch.conn.network.ID, channelCM)
+ if err != nil {
+ dc.logger.Printf("failed to get the read receipt for %q: %v", ch.Name, err)
+ } else {
+ timestampStr := "*"
+ if r != nil {
+ timestampStr = fmt.Sprintf("timestamp=%s", formatServerTime(r.Timestamp))
+ }
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.prefix(),
+ Command: "READ",
+ Params: []string{dc.marshalEntity(ch.conn.network, ch.Name), timestampStr},
+ })
+ }
+ }
+
+ sendNames(dc, ch)
+}
+
+func sendTopic(dc *downstreamConn, ch *upstreamChannel) {
+ downstreamName := dc.marshalEntity(ch.conn.network, ch.Name)
+
+ if ch.Topic != "" {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_TOPIC,
+ Params: []string{dc.nick, downstreamName, ch.Topic},
+ })
+ if ch.TopicWho != nil {
+ topicWho := dc.marshalUserPrefix(ch.conn.network, ch.TopicWho)
+ topicTime := strconv.FormatInt(ch.TopicTime.Unix(), 10)
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: rpl_topicwhotime,
+ Params: []string{dc.nick, downstreamName, topicWho.String(), topicTime},
+ })
+ }
+ } else {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_NOTOPIC,
+ Params: []string{dc.nick, downstreamName, "No topic is set"},
+ })
+ }
+}
+
+func sendNames(dc *downstreamConn, ch *upstreamChannel) {
+ downstreamName := dc.marshalEntity(ch.conn.network, ch.Name)
+
+ emptyNameReply := &irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_NAMREPLY,
+ Params: []string{dc.nick, string(ch.Status), downstreamName, ""},
+ }
+ maxLength := maxMessageLength - len(emptyNameReply.String())
+
+ var buf strings.Builder
+ for _, entry := range ch.Members.innerMap {
+ nick := entry.originalKey
+ memberships := entry.value.(*memberships)
+ s := memberships.Format(dc) + dc.marshalEntity(ch.conn.network, nick)
+
+ n := buf.Len() + 1 + len(s)
+ if buf.Len() != 0 && n > maxLength {
+ // There's not enough space for the next space + nick.
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_NAMREPLY,
+ Params: []string{dc.nick, string(ch.Status), downstreamName, buf.String()},
+ })
+ buf.Reset()
+ }
+
+ if buf.Len() != 0 {
+ buf.WriteByte(' ')
+ }
+ buf.WriteString(s)
+ }
+
+ if buf.Len() != 0 {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_NAMREPLY,
+ Params: []string{dc.nick, string(ch.Status), downstreamName, buf.String()},
+ })
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_ENDOFNAMES,
+ Params: []string{dc.nick, downstreamName, "End of /NAMES list"},
+ })
+}
--- /dev/null
+package suika
+
+import (
+ "crypto"
+ "crypto/ecdsa"
+ "crypto/ed25519"
+ "crypto/elliptic"
+ "crypto/rand"
+ "crypto/rsa"
+ "crypto/x509"
+ "crypto/x509/pkix"
+ "math/big"
+ "time"
+)
+
+func generateCertFP(keyType string, bits int) (privKeyBytes, certBytes []byte, err error) {
+ var (
+ privKey crypto.PrivateKey
+ pubKey crypto.PublicKey
+ )
+ switch keyType {
+ case "rsa":
+ key, err := rsa.GenerateKey(rand.Reader, bits)
+ if err != nil {
+ return nil, nil, err
+ }
+ privKey = key
+ pubKey = key.Public()
+ case "ecdsa":
+ key, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
+ if err != nil {
+ return nil, nil, err
+ }
+ privKey = key
+ pubKey = key.Public()
+ case "ed25519":
+ var err error
+ pubKey, privKey, err = ed25519.GenerateKey(rand.Reader)
+ if err != nil {
+ return nil, nil, err
+ }
+ }
+
+ // Using PKCS#8 allows easier extension for new key types.
+ privKeyBytes, err = x509.MarshalPKCS8PrivateKey(privKey)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ notBefore := time.Now()
+ // Lets make a fair assumption nobody will use the same cert for more than 20 years...
+ notAfter := notBefore.Add(24 * time.Hour * 365 * 20)
+ serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
+ serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
+ if err != nil {
+ return nil, nil, err
+ }
+ cert := &x509.Certificate{
+ SerialNumber: serialNumber,
+ Subject: pkix.Name{CommonName: "suika auto-generated certificate"},
+ NotBefore: notBefore,
+ NotAfter: notAfter,
+ KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
+ ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
+ }
+ certBytes, err = x509.CreateCertificate(rand.Reader, cert, cert, pubKey, privKey)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ return privKeyBytes, certBytes, nil
+}
--- /dev/null
+package main
+
+import (
+ "bufio"
+ "context"
+ "flag"
+ "fmt"
+ "io"
+ "log"
+ "net/url"
+ "os"
+ "strings"
+ "unicode"
+
+ "marisa.chaotic.ninja/suika"
+ "marisa.chaotic.ninja/suika/config"
+)
+
+const usage = `usage: suika-znc-import [options...] <znc config path>
+
+Imports configuration from a ZNC file. Users and networks are merged if they
+already exist in the suika database. ZNC settings overwrite existing suika
+settings.
+
+Options:
+
+ -help Show this help message
+ -config <path> Path to suika config file
+ -user <username> Limit import to username (may be specified multiple times)
+ -network <name> Limit import to network (may be specified multiple times)
+`
+
+func init() {
+ flag.Usage = func() {
+ fmt.Fprintf(flag.CommandLine.Output(), usage)
+ }
+}
+
+func main() {
+ var configPath string
+ users := make(map[string]bool)
+ networks := make(map[string]bool)
+ flag.StringVar(&configPath, "config", "", "path to configuration file")
+ flag.Var((*stringSetFlag)(&users), "user", "")
+ flag.Var((*stringSetFlag)(&networks), "network", "")
+ flag.Parse()
+
+ zncPath := flag.Arg(0)
+ if zncPath == "" {
+ flag.Usage()
+ os.Exit(1)
+ }
+
+ var cfg *config.Server
+ if configPath != "" {
+ var err error
+ cfg, err = config.Load(configPath)
+ if err != nil {
+ log.Fatalf("failed to load config file: %v", err)
+ }
+ } else {
+ cfg = config.Defaults()
+ }
+
+ ctx := context.Background()
+
+ db, err := suika.OpenDB(cfg.SQLDriver, cfg.SQLSource)
+ if err != nil {
+ log.Fatalf("failed to open database: %v", err)
+ }
+ defer db.Close()
+
+ f, err := os.Open(zncPath)
+ if err != nil {
+ log.Fatalf("failed to open ZNC configuration file: %v", err)
+ }
+ defer f.Close()
+
+ zp := zncParser{bufio.NewReader(f), 1}
+ root, err := zp.sectionBody("", "")
+ if err != nil {
+ log.Fatalf("failed to parse %q: line %v: %v", zncPath, zp.line, err)
+ }
+
+ l, err := db.ListUsers(ctx)
+ if err != nil {
+ log.Fatalf("failed to list users in DB: %v", err)
+ }
+ existingUsers := make(map[string]*suika.User, len(l))
+ for i, u := range l {
+ existingUsers[u.Username] = &l[i]
+ }
+
+ usersCreated := 0
+ usersImported := 0
+ networksImported := 0
+ channelsImported := 0
+ root.ForEach("User", func(section *zncSection) {
+ username := section.Name
+ if len(users) > 0 && !users[username] {
+ return
+ }
+ usersImported++
+
+ u, ok := existingUsers[username]
+ if ok {
+ log.Printf("user %q: updating existing user", username)
+ } else {
+ // "!!" is an invalid crypt format, thus disables password auth
+ u = &suika.User{Username: username, Password: "!!"}
+ usersCreated++
+ log.Printf("user %q: creating new user", username)
+ }
+
+ u.Admin = section.Values.Get("Admin") == "true"
+
+ if err := db.StoreUser(ctx, u); err != nil {
+ log.Fatalf("failed to store user %q: %v", username, err)
+ }
+ userID := u.ID
+
+ l, err := db.ListNetworks(ctx, userID)
+ if err != nil {
+ log.Fatalf("failed to list networks for user %q: %v", username, err)
+ }
+ existingNetworks := make(map[string]*suika.Network, len(l))
+ for i, n := range l {
+ existingNetworks[n.GetName()] = &l[i]
+ }
+
+ nick := section.Values.Get("Nick")
+ realname := section.Values.Get("RealName")
+ ident := section.Values.Get("Ident")
+
+ section.ForEach("Network", func(section *zncSection) {
+ netName := section.Name
+ if len(networks) > 0 && !networks[netName] {
+ return
+ }
+ networksImported++
+
+ logPrefix := fmt.Sprintf("user %q: network %q: ", username, netName)
+ logger := log.New(os.Stderr, logPrefix, log.LstdFlags|log.Lmsgprefix)
+
+ netNick := section.Values.Get("Nick")
+ if netNick == "" {
+ netNick = nick
+ }
+ netRealname := section.Values.Get("RealName")
+ if netRealname == "" {
+ netRealname = realname
+ }
+ netIdent := section.Values.Get("Ident")
+ if netIdent == "" {
+ netIdent = ident
+ }
+
+ for _, name := range section.Values["LoadModule"] {
+ switch name {
+ case "sasl":
+ logger.Printf("warning: SASL credentials not imported")
+ case "nickserv":
+ logger.Printf("warning: NickServ credentials not imported")
+ case "perform":
+ logger.Printf("warning: \"perform\" plugin commands not imported")
+ }
+ }
+
+ u, pass, err := importNetworkServer(section.Values.Get("Server"))
+ if err != nil {
+ logger.Fatalf("failed to import server %q: %v", section.Values.Get("Server"), err)
+ }
+
+ n, ok := existingNetworks[netName]
+ if ok {
+ logger.Printf("updating existing network")
+ } else {
+ n = &suika.Network{Name: netName}
+ logger.Printf("creating new network")
+ }
+
+ n.Addr = u.String()
+ n.Nick = netNick
+ n.Username = netIdent
+ n.Realname = netRealname
+ n.Pass = pass
+ n.Enabled = section.Values.Get("IRCConnectEnabled") != "false"
+
+ if err := db.StoreNetwork(ctx, userID, n); err != nil {
+ logger.Fatalf("failed to store network: %v", err)
+ }
+
+ l, err := db.ListChannels(ctx, n.ID)
+ if err != nil {
+ logger.Fatalf("failed to list channels: %v", err)
+ }
+ existingChannels := make(map[string]*suika.Channel, len(l))
+ for i, ch := range l {
+ existingChannels[ch.Name] = &l[i]
+ }
+
+ section.ForEach("Chan", func(section *zncSection) {
+ chName := section.Name
+
+ if section.Values.Get("Disabled") == "true" {
+ logger.Printf("skipping import of disabled channel %q", chName)
+ return
+ }
+
+ channelsImported++
+
+ ch, ok := existingChannels[chName]
+ if ok {
+ logger.Printf("channel %q: updating existing channel", chName)
+ } else {
+ ch = &suika.Channel{Name: chName}
+ logger.Printf("channel %q: creating new channel", chName)
+ }
+
+ ch.Key = section.Values.Get("Key")
+ ch.Detached = section.Values.Get("Detached") == "true"
+
+ if err := db.StoreChannel(ctx, n.ID, ch); err != nil {
+ logger.Printf("channel %q: failed to store channel: %v", chName, err)
+ }
+ })
+ })
+ })
+
+ if err := db.Close(); err != nil {
+ log.Printf("failed to close database: %v", err)
+ }
+
+ if usersCreated > 0 {
+ log.Printf("warning: user passwords haven't been imported, please set them with `suikactl change-password <username>`")
+ }
+
+ log.Printf("imported %v users, %v networks and %v channels", usersImported, networksImported, channelsImported)
+}
+
+func importNetworkServer(s string) (u *url.URL, pass string, err error) {
+ parts := strings.Fields(s)
+ if len(parts) < 2 {
+ return nil, "", fmt.Errorf("expected space-separated host and port")
+ }
+
+ scheme := "irc"
+ host := parts[0]
+ port := parts[1]
+ if strings.HasPrefix(port, "+") {
+ port = port[1:]
+ scheme = "ircs"
+ }
+
+ if len(parts) > 2 {
+ pass = parts[2]
+ }
+
+ u = &url.URL{
+ Scheme: scheme,
+ Host: host + ":" + port,
+ }
+ return u, pass, nil
+}
+
+type zncSection struct {
+ Type string
+ Name string
+ Values zncValues
+ Children []zncSection
+}
+
+func (s *zncSection) ForEach(typ string, f func(*zncSection)) {
+ for _, section := range s.Children {
+ if section.Type == typ {
+ f(§ion)
+ }
+ }
+}
+
+type zncValues map[string][]string
+
+func (zv zncValues) Get(k string) string {
+ if len(zv[k]) == 0 {
+ return ""
+ }
+ return zv[k][0]
+}
+
+type zncParser struct {
+ br *bufio.Reader
+ line int
+}
+
+func (zp *zncParser) readByte() (byte, error) {
+ b, err := zp.br.ReadByte()
+ if b == '\n' {
+ zp.line++
+ }
+ return b, err
+}
+
+func (zp *zncParser) readRune() (rune, int, error) {
+ r, n, err := zp.br.ReadRune()
+ if r == '\n' {
+ zp.line++
+ }
+ return r, n, err
+}
+
+func (zp *zncParser) sectionBody(typ, name string) (*zncSection, error) {
+ section := &zncSection{Type: typ, Name: name, Values: make(zncValues)}
+
+Loop:
+ for {
+ if err := zp.skipSpace(); err != nil {
+ return nil, err
+ }
+
+ b, err := zp.br.Peek(2)
+ if err == io.EOF {
+ break
+ } else if err != nil {
+ return nil, err
+ }
+
+ switch b[0] {
+ case '<':
+ if b[1] == '/' {
+ break Loop
+ } else {
+ childType, childName, err := zp.sectionHeader()
+ if err != nil {
+ return nil, err
+ }
+ child, err := zp.sectionBody(childType, childName)
+ if err != nil {
+ return nil, err
+ }
+ if footerType, err := zp.sectionFooter(); err != nil {
+ return nil, err
+ } else if footerType != childType {
+ return nil, fmt.Errorf("invalid section footer: expected type %q, got %q", childType, footerType)
+ }
+ section.Children = append(section.Children, *child)
+ }
+ case '/':
+ if b[1] == '/' {
+ if err := zp.skipComment(); err != nil {
+ return nil, err
+ }
+ break
+ }
+ fallthrough
+ default:
+ k, v, err := zp.keyValuePair()
+ if err != nil {
+ return nil, err
+ }
+ section.Values[k] = append(section.Values[k], v)
+ }
+ }
+
+ return section, nil
+}
+
+func (zp *zncParser) skipSpace() error {
+ for {
+ r, _, err := zp.readRune()
+ if err == io.EOF {
+ return nil
+ } else if err != nil {
+ return err
+ }
+
+ if !unicode.IsSpace(r) {
+ zp.br.UnreadRune()
+ return nil
+ }
+ }
+}
+
+func (zp *zncParser) skipComment() error {
+ if err := zp.expectRune('/'); err != nil {
+ return err
+ }
+ if err := zp.expectRune('/'); err != nil {
+ return err
+ }
+
+ for {
+ b, err := zp.readByte()
+ if err == io.EOF {
+ return nil
+ } else if err != nil {
+ return err
+ }
+
+ if b == '\n' {
+ return nil
+ }
+ }
+}
+
+func (zp *zncParser) sectionHeader() (string, string, error) {
+ if err := zp.expectRune('<'); err != nil {
+ return "", "", err
+ }
+ typ, err := zp.readWord(' ')
+ if err != nil {
+ return "", "", err
+ }
+ name, err := zp.readWord('>')
+ return typ, name, err
+}
+
+func (zp *zncParser) sectionFooter() (string, error) {
+ if err := zp.expectRune('<'); err != nil {
+ return "", err
+ }
+ if err := zp.expectRune('/'); err != nil {
+ return "", err
+ }
+ return zp.readWord('>')
+}
+
+func (zp *zncParser) keyValuePair() (string, string, error) {
+ k, err := zp.readWord('=')
+ if err != nil {
+ return "", "", err
+ }
+ v, err := zp.readWord('\n')
+ return strings.TrimSpace(k), strings.TrimSpace(v), err
+}
+
+func (zp *zncParser) expectRune(expected rune) error {
+ r, _, err := zp.readRune()
+ if err != nil {
+ return err
+ } else if r != expected {
+ return fmt.Errorf("expected %q, got %q", expected, r)
+ }
+ return nil
+}
+
+func (zp *zncParser) readWord(delim byte) (string, error) {
+ var sb strings.Builder
+ for {
+ b, err := zp.readByte()
+ if err != nil {
+ return "", err
+ }
+
+ if b == delim {
+ return sb.String(), nil
+ }
+ if b == '\n' {
+ return "", fmt.Errorf("expected %q before newline", delim)
+ }
+
+ sb.WriteByte(b)
+ }
+}
+
+type stringSetFlag map[string]bool
+
+func (v *stringSetFlag) String() string {
+ return fmt.Sprint(map[string]bool(*v))
+}
+
+func (v *stringSetFlag) Set(s string) error {
+ (*v)[s] = true
+ return nil
+}
--- /dev/null
+package main
+
+import (
+ "context"
+ "crypto/tls"
+ "flag"
+ "fmt"
+ "log"
+ "net"
+ "net/url"
+ "os"
+ "os/signal"
+ "strings"
+ "sync/atomic"
+ "syscall"
+ "time"
+
+ "marisa.chaotic.ninja/suika"
+ "marisa.chaotic.ninja/suika/config"
+)
+
+// TCP keep-alive interval for downstream TCP connections
+const downstreamKeepAlive = 1 * time.Hour
+
+type stringSliceFlag []string
+
+func (v *stringSliceFlag) String() string {
+ return fmt.Sprint([]string(*v))
+}
+
+func (v *stringSliceFlag) Set(s string) error {
+ *v = append(*v, s)
+ return nil
+}
+
+func bumpOpenedFileLimit() error {
+ var rlimit syscall.Rlimit
+ if err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rlimit); err != nil {
+ return fmt.Errorf("failed to get RLIMIT_NOFILE: %v", err)
+ }
+ rlimit.Cur = rlimit.Max
+ if err := syscall.Setrlimit(syscall.RLIMIT_NOFILE, &rlimit); err != nil {
+ return fmt.Errorf("failed to set RLIMIT_NOFILE: %v", err)
+ }
+ return nil
+}
+
+var (
+ configPath string
+ debug bool
+
+ tlsCert atomic.Value // *tls.Certificate
+)
+
+func loadConfig() (*config.Server, *suika.Config, error) {
+ var raw *config.Server
+ if configPath != "" {
+ var err error
+ raw, err = config.Load(configPath)
+ if err != nil {
+ return nil, nil, fmt.Errorf("failed to load config file: %v", err)
+ }
+ } else {
+ raw = config.Defaults()
+ }
+
+ var motd string
+ if raw.MOTDPath != "" {
+ b, err := os.ReadFile(raw.MOTDPath)
+ if err != nil {
+ return nil, nil, fmt.Errorf("failed to load MOTD: %v", err)
+ }
+ motd = strings.TrimSuffix(string(b), "\n")
+ }
+
+ if raw.TLS != nil {
+ cert, err := tls.LoadX509KeyPair(raw.TLS.CertPath, raw.TLS.KeyPath)
+ if err != nil {
+ return nil, nil, fmt.Errorf("failed to load TLS certificate and key: %v", err)
+ }
+ tlsCert.Store(&cert)
+ }
+
+ cfg := &suika.Config{
+ Hostname: raw.Hostname,
+ Title: raw.Title,
+ LogPath: raw.LogPath,
+ MaxUserNetworks: raw.MaxUserNetworks,
+ MultiUpstream: raw.MultiUpstream,
+ UpstreamUserIPs: raw.UpstreamUserIPs,
+ MOTD: motd,
+ }
+ return raw, cfg, nil
+}
+
+func main() {
+ var listen []string
+ flag.Var((*stringSliceFlag)(&listen), "listen", "listening address")
+ flag.StringVar(&configPath, "config", "", "path to configuration file")
+ flag.BoolVar(&debug, "debug", false, "enable debug logging")
+ flag.Parse()
+
+ cfg, serverCfg, err := loadConfig()
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ cfg.Listen = append(cfg.Listen, listen...)
+ if len(cfg.Listen) == 0 {
+ cfg.Listen = []string{":6667"}
+ }
+
+ if err := bumpOpenedFileLimit(); err != nil {
+ log.Printf("failed to bump max number of opened files: %v", err)
+ }
+
+ db, err := suika.OpenDB(cfg.SQLDriver, cfg.SQLSource)
+ if err != nil {
+ log.Fatalf("failed to open database: %v", err)
+ }
+
+ var tlsCfg *tls.Config
+ if cfg.TLS != nil {
+ tlsCfg = &tls.Config{
+ GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
+ return tlsCert.Load().(*tls.Certificate), nil
+ },
+ }
+ }
+
+ srv := suika.NewServer(db)
+ srv.SetConfig(serverCfg)
+ srv.Logger = suika.NewLogger(log.Writer(), debug)
+
+ for _, listen := range cfg.Listen {
+ listen := listen // copy
+ listenURI := listen
+ if !strings.Contains(listenURI, ":/") {
+ // This is a raw domain name, make it an URL with an empty scheme
+ listenURI = "//" + listenURI
+ }
+ u, err := url.Parse(listenURI)
+ if err != nil {
+ log.Fatalf("failed to parse listen URI %q: %v", listen, err)
+ }
+
+ switch u.Scheme {
+ case "ircs", "":
+ if tlsCfg == nil {
+ log.Fatalf("failed to listen on %q: missing TLS configuration", listen)
+ }
+ host := u.Host
+ if _, _, err := net.SplitHostPort(host); err != nil {
+ host = host + ":6697"
+ }
+ ircsTLSCfg := tlsCfg.Clone()
+ ircsTLSCfg.NextProtos = []string{"irc"}
+ lc := net.ListenConfig{
+ KeepAlive: downstreamKeepAlive,
+ }
+ l, err := lc.Listen(context.Background(), "tcp", host)
+ if err != nil {
+ log.Fatalf("failed to start TLS listener on %q: %v", listen, err)
+ }
+ ln := tls.NewListener(l, ircsTLSCfg)
+ go func() {
+ if err := srv.Serve(ln); err != nil {
+ log.Printf("serving %q: %v", listen, err)
+ }
+ }()
+ case "irc":
+ host := u.Host
+ if _, _, err := net.SplitHostPort(host); err != nil {
+ host = host + ":6667"
+ }
+ lc := net.ListenConfig{
+ KeepAlive: downstreamKeepAlive,
+ }
+ ln, err := lc.Listen(context.Background(), "tcp", host)
+ if err != nil {
+ log.Fatalf("failed to start listener on %q: %v", listen, err)
+ }
+ go func() {
+ if err := srv.Serve(ln); err != nil {
+ log.Printf("serving %q: %v", listen, err)
+ }
+ }()
+ case "unix":
+ ln, err := net.Listen("unix", u.Path)
+ if err != nil {
+ log.Fatalf("failed to start listener on %q: %v", listen, err)
+ }
+ go func() {
+ if err := srv.Serve(ln); err != nil {
+ log.Printf("serving %q: %v", listen, err)
+ }
+ }()
+ default:
+ log.Fatalf("failed to listen on %q: unsupported scheme", listen)
+ }
+
+ log.Printf("starting suika version %v\n", suika.FullVersion())
+ log.Printf("server listening on %q", listen)
+ }
+
+ sigCh := make(chan os.Signal, 1)
+ signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)
+
+ if err := srv.Start(); err != nil {
+ log.Fatal(err)
+ }
+
+ for sig := range sigCh {
+ switch sig {
+ case syscall.SIGHUP:
+ log.Print("reloading configuration")
+ _, serverCfg, err := loadConfig()
+ if err != nil {
+ log.Printf("failed to reloading configuration: %v", err)
+ } else {
+ srv.SetConfig(serverCfg)
+ }
+ case syscall.SIGINT, syscall.SIGTERM:
+ log.Print("shutting down server")
+ srv.Shutdown()
+ return
+ }
+ }
+}
--- /dev/null
+package main
+
+import (
+ "bufio"
+ "context"
+ "flag"
+ "fmt"
+ "io"
+ "log"
+ "os"
+
+ "marisa.chaotic.ninja/suika"
+ "marisa.chaotic.ninja/suika/config"
+ "golang.org/x/crypto/bcrypt"
+ "golang.org/x/term"
+)
+
+const usage = `usage: suikadb [-config path] <action> [options...]
+
+ create-user <username> [-admin] Create a new user
+ change-password <username> Change password for a user
+ help Show this help message
+`
+
+func init() {
+ flag.Usage = func() {
+ fmt.Fprintf(flag.CommandLine.Output(), usage)
+ }
+}
+
+func main() {
+ var configPath string
+ flag.StringVar(&configPath, "config", "", "path to configuration file")
+ flag.Parse()
+
+ var cfg *config.Server
+ if configPath != "" {
+ var err error
+ cfg, err = config.Load(configPath)
+ if err != nil {
+ log.Fatalf("failed to load config file: %v", err)
+ }
+ } else {
+ cfg = config.Defaults()
+ }
+
+ db, err := suika.OpenDB(cfg.SQLDriver, cfg.SQLSource)
+ if err != nil {
+ log.Fatalf("failed to open database: %v", err)
+ }
+
+ ctx := context.Background()
+
+ switch cmd := flag.Arg(0); cmd {
+ case "create-user":
+ username := flag.Arg(1)
+ if username == "" {
+ flag.Usage()
+ os.Exit(1)
+ }
+
+ fs := flag.NewFlagSet("", flag.ExitOnError)
+ admin := fs.Bool("admin", false, "make the new user admin")
+ fs.Parse(flag.Args()[2:])
+
+ password, err := readPassword()
+ if err != nil {
+ log.Fatalf("failed to read password: %v", err)
+ }
+
+ hashed, err := bcrypt.GenerateFromPassword(password, bcrypt.DefaultCost)
+ if err != nil {
+ log.Fatalf("failed to hash password: %v", err)
+ }
+
+ user := suika.User{
+ Username: username,
+ Password: string(hashed),
+ Admin: *admin,
+ }
+ if err := db.StoreUser(ctx, &user); err != nil {
+ log.Fatalf("failed to create user: %v", err)
+ }
+ case "change-password":
+ username := flag.Arg(1)
+ if username == "" {
+ flag.Usage()
+ os.Exit(1)
+ }
+
+ user, err := db.GetUser(ctx, username)
+ if err != nil {
+ log.Fatalf("failed to get user: %v", err)
+ }
+
+ password, err := readPassword()
+ if err != nil {
+ log.Fatalf("failed to read password: %v", err)
+ }
+
+ hashed, err := bcrypt.GenerateFromPassword(password, bcrypt.DefaultCost)
+ if err != nil {
+ log.Fatalf("failed to hash password: %v", err)
+ }
+
+ user.Password = string(hashed)
+ if err := db.StoreUser(ctx, user); err != nil {
+ log.Fatalf("failed to update password: %v", err)
+ }
+ case "version":
+ fmt.Printf("%v\n", suika.FullVersion())
+ default:
+ flag.Usage()
+ if cmd != "help" {
+ os.Exit(1)
+ }
+ }
+}
+
+func readPassword() ([]byte, error) {
+ var password []byte
+ var err error
+ fd := int(os.Stdin.Fd())
+
+ if term.IsTerminal(fd) {
+ fmt.Printf("Password: ")
+ password, err = term.ReadPassword(int(os.Stdin.Fd()))
+ if err != nil {
+ return nil, err
+ }
+ fmt.Printf("\n")
+ } else {
+ fmt.Fprintf(os.Stderr, "Warning: Reading password from stdin.\n")
+ // TODO: the buffering messes up repeated calls to readPassword
+ scanner := bufio.NewScanner(os.Stdin)
+ if !scanner.Scan() {
+ if err := scanner.Err(); err != nil {
+ return nil, err
+ }
+ return nil, io.ErrUnexpectedEOF
+ }
+ password = scanner.Bytes()
+
+ if len(password) == 0 {
+ return nil, fmt.Errorf("zero length password")
+ }
+ }
+
+ return password, nil
+}
--- /dev/null
+db sqlite3 /var/lib/suika/main.db
+log fs /var/lib/suika/logs/
--- /dev/null
+package config
+
+import (
+ "fmt"
+ "net"
+ "os"
+ "strconv"
+
+ "git.sr.ht/~emersion/go-scfg"
+)
+
+type TLS struct {
+ CertPath, KeyPath string
+}
+
+type Server struct {
+ Listen []string
+ TLS *TLS
+ Hostname string
+ Title string
+ MOTDPath string
+
+ SQLDriver string
+ SQLSource string
+ LogPath string
+
+ MaxUserNetworks int
+ MultiUpstream bool
+ UpstreamUserIPs []*net.IPNet
+}
+
+func Defaults() *Server {
+ hostname, err := os.Hostname()
+ if err != nil {
+ hostname = "localhost"
+ }
+ return &Server{
+ Hostname: hostname,
+ SQLDriver: "sqlite3",
+ SQLSource: "suika.db",
+ MaxUserNetworks: -1,
+ MultiUpstream: true,
+ }
+}
+
+func Load(path string) (*Server, error) {
+ cfg, err := scfg.Load(path)
+ if err != nil {
+ return nil, err
+ }
+ return parse(cfg)
+}
+
+func parse(cfg scfg.Block) (*Server, error) {
+ srv := Defaults()
+ for _, d := range cfg {
+ switch d.Name {
+ case "listen":
+ var uri string
+ if err := d.ParseParams(&uri); err != nil {
+ return nil, err
+ }
+ srv.Listen = append(srv.Listen, uri)
+ case "hostname":
+ if err := d.ParseParams(&srv.Hostname); err != nil {
+ return nil, err
+ }
+ case "title":
+ if err := d.ParseParams(&srv.Title); err != nil {
+ return nil, err
+ }
+ case "motd":
+ if err := d.ParseParams(&srv.MOTDPath); err != nil {
+ return nil, err
+ }
+ case "tls":
+ tls := &TLS{}
+ if err := d.ParseParams(&tls.CertPath, &tls.KeyPath); err != nil {
+ return nil, err
+ }
+ srv.TLS = tls
+ case "db":
+ if err := d.ParseParams(&srv.SQLDriver, &srv.SQLSource); err != nil {
+ return nil, err
+ }
+ case "log":
+ var driver string
+ if err := d.ParseParams(&driver, &srv.LogPath); err != nil {
+ return nil, err
+ }
+ if driver != "fs" {
+ return nil, fmt.Errorf("directive %q: unknown driver %q", d.Name, driver)
+ }
+ case "max-user-networks":
+ var max string
+ if err := d.ParseParams(&max); err != nil {
+ return nil, err
+ }
+ var err error
+ if srv.MaxUserNetworks, err = strconv.Atoi(max); err != nil {
+ return nil, fmt.Errorf("directive %q: %v", d.Name, err)
+ }
+ case "multi-upstream-mode":
+ var str string
+ if err := d.ParseParams(&str); err != nil {
+ return nil, err
+ }
+ v, err := strconv.ParseBool(str)
+ if err != nil {
+ return nil, fmt.Errorf("directive %q: %v", d.Name, err)
+ }
+ srv.MultiUpstream = v
+ case "upstream-user-ip":
+ if len(srv.UpstreamUserIPs) > 0 {
+ return nil, fmt.Errorf("directive %q: can only be specified once", d.Name)
+ }
+ var hasIPv4, hasIPv6 bool
+ for _, s := range d.Params {
+ _, n, err := net.ParseCIDR(s)
+ if err != nil {
+ return nil, fmt.Errorf("directive %q: failed to parse CIDR: %v", d.Name, err)
+ }
+ if n.IP.To4() == nil {
+ if hasIPv6 {
+ return nil, fmt.Errorf("directive %q: found two IPv6 CIDRs", d.Name)
+ }
+ hasIPv6 = true
+ } else {
+ if hasIPv4 {
+ return nil, fmt.Errorf("directive %q: found two IPv4 CIDRs", d.Name)
+ }
+ hasIPv4 = true
+ }
+ srv.UpstreamUserIPs = append(srv.UpstreamUserIPs, n)
+ }
+ default:
+ return nil, fmt.Errorf("unknown directive %q", d.Name)
+ }
+ }
+
+ return srv, nil
+}
--- /dev/null
+package suika
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "net"
+ "sync"
+ "time"
+
+ "golang.org/x/time/rate"
+ "gopkg.in/irc.v3"
+)
+
+// ircConn is a generic IRC connection. It's similar to net.Conn but focuses on
+// reading and writing IRC messages.
+type ircConn interface {
+ ReadMessage() (*irc.Message, error)
+ WriteMessage(*irc.Message) error
+ Close() error
+ SetReadDeadline(time.Time) error
+ SetWriteDeadline(time.Time) error
+ RemoteAddr() net.Addr
+ LocalAddr() net.Addr
+}
+
+func newNetIRCConn(c net.Conn) ircConn {
+ type netConn net.Conn
+ return struct {
+ *irc.Conn
+ netConn
+ }{irc.NewConn(c), c}
+}
+
+type connOptions struct {
+ Logger Logger
+ RateLimitDelay time.Duration
+ RateLimitBurst int
+}
+
+type conn struct {
+ conn ircConn
+ srv *Server
+ logger Logger
+
+ lock sync.Mutex
+ outgoing chan<- *irc.Message
+ closed bool
+ closedCh chan struct{}
+}
+
+func newConn(srv *Server, ic ircConn, options *connOptions) *conn {
+ outgoing := make(chan *irc.Message, 64)
+ c := &conn{
+ conn: ic,
+ srv: srv,
+ outgoing: outgoing,
+ logger: options.Logger,
+ closedCh: make(chan struct{}),
+ }
+
+ go func() {
+ ctx, cancel := c.NewContext(context.Background())
+ defer cancel()
+
+ rl := rate.NewLimiter(rate.Every(options.RateLimitDelay), options.RateLimitBurst)
+ for msg := range outgoing {
+ if err := rl.Wait(ctx); err != nil {
+ break
+ }
+
+ c.logger.Debugf("sent: %v", msg)
+ c.conn.SetWriteDeadline(time.Now().Add(writeTimeout))
+ if err := c.conn.WriteMessage(msg); err != nil {
+ c.logger.Printf("failed to write message: %v", err)
+ break
+ }
+ }
+ if err := c.conn.Close(); err != nil && !isErrClosed(err) {
+ c.logger.Printf("failed to close connection: %v", err)
+ } else {
+ c.logger.Debugf("connection closed")
+ }
+ // Drain the outgoing channel to prevent SendMessage from blocking
+ for range outgoing {
+ // This space is intentionally left blank
+ }
+ }()
+
+ c.logger.Debugf("new connection")
+ return c
+}
+
+func (c *conn) isClosed() bool {
+ c.lock.Lock()
+ defer c.lock.Unlock()
+ return c.closed
+}
+
+// Close closes the connection. It is safe to call from any goroutine.
+func (c *conn) Close() error {
+ c.lock.Lock()
+ defer c.lock.Unlock()
+
+ if c.closed {
+ return fmt.Errorf("connection already closed")
+ }
+
+ err := c.conn.Close()
+ c.closed = true
+ close(c.outgoing)
+ close(c.closedCh)
+ return err
+}
+
+func (c *conn) ReadMessage() (*irc.Message, error) {
+ msg, err := c.conn.ReadMessage()
+ if isErrClosed(err) {
+ return nil, io.EOF
+ } else if err != nil {
+ return nil, err
+ }
+
+ c.logger.Debugf("received: %v", msg)
+ return msg, nil
+}
+
+// SendMessage queues a new outgoing message. It is safe to call from any
+// goroutine.
+//
+// If the connection is closed before the message is sent, SendMessage silently
+// drops the message.
+func (c *conn) SendMessage(ctx context.Context, msg *irc.Message) {
+ c.lock.Lock()
+ defer c.lock.Unlock()
+
+ if c.closed {
+ return
+ }
+
+ select {
+ case c.outgoing <- msg:
+ // Success
+ case <-ctx.Done():
+ c.logger.Printf("failed to send message: %v", ctx.Err())
+ }
+}
+
+func (c *conn) RemoteAddr() net.Addr {
+ return c.conn.RemoteAddr()
+}
+
+func (c *conn) LocalAddr() net.Addr {
+ return c.conn.LocalAddr()
+}
+
+// NewContext returns a copy of the parent context with a new Done channel. The
+// returned context's Done channel is closed when the connection is closed,
+// when the returned cancel function is called, or when the parent context's
+// Done channel is closed, whichever happens first.
+//
+// Canceling this context releases resources associated with it, so code should
+// call cancel as soon as the operations running in this Context complete.
+func (c *conn) NewContext(parent context.Context) (context.Context, context.CancelFunc) {
+ ctx, cancel := context.WithCancel(parent)
+
+ go func() {
+ defer cancel()
+
+ select {
+ case <-ctx.Done():
+ // The parent context has been cancelled, or the caller has called
+ // cancel()
+ case <-c.closedCh:
+ // The connection has been closed
+ }
+ }()
+
+ return ctx, cancel
+}
--- /dev/null
+#!/bin/sh -eu
+
+# Converts a log dir to its case-mapped form.
+#
+# suika needs to be stopped for this script to work properly. The script may
+# re-order messages that happened within the same second interval if merging
+# two daily log files is necessary.
+#
+# usage: casemap-logs.sh <directory>
+
+root="$1"
+
+for net_dir in "$root"/*/*; do
+ for chan in $(ls "$net_dir"); do
+ cm_chan="$(echo $chan | tr '[:upper:]' '[:lower:]')"
+ if [ "$chan" = "$cm_chan" ]; then
+ continue
+ fi
+
+ if ! [ -d "$net_dir/$cm_chan" ]; then
+ echo >&2 "Moving case-mapped channel dir: '$net_dir/$chan' -> '$cm_chan'"
+ mv "$net_dir/$chan" "$net_dir/$cm_chan"
+ continue
+ fi
+
+ echo "Merging case-mapped channel dir: '$net_dir/$chan' -> '$cm_chan'"
+ for day in $(ls "$net_dir/$chan"); do
+ if ! [ -e "$net_dir/$cm_chan/$day" ]; then
+ echo >&2 " Moving log file: '$day'"
+ mv "$net_dir/$chan/$day" "$net_dir/$cm_chan/$day"
+ continue
+ fi
+
+ echo >&2 " Merging log file: '$day'"
+ sort "$net_dir/$chan/$day" "$net_dir/$cm_chan/$day" >"$net_dir/$cm_chan/$day.new"
+ mv "$net_dir/$cm_chan/$day.new" "$net_dir/$cm_chan/$day"
+ rm "$net_dir/$chan/$day"
+ done
+
+ rmdir "$net_dir/$chan"
+ done
+done
--- /dev/null
+# Clients
+
+This page describes how to configure IRC clients to better integrate with soju.
+
+Also see the [IRCv3 support tables] for a more general list of clients.
+
+# catgirl
+
+catgirl doesn't properly implement cap-3.2, so many capabilities will be
+disabled. catgirl developers have publicly stated that supporting bouncers such
+as soju is a non-goal.
+
+# [Emacs]
+
+There are two clients provided with Emacs. They require some setup to work
+properly.
+
+## Erc
+
+You need to explicitly set the username, which is the defcustom
+`erc-email-userid`.
+
+```elisp
+(setq erc-email-userid "<username>/irc.libera.chat") ;; Example with Libera.Chat
+(defun run-erc ()
+ (interactive)
+ (erc-tls :server "<server>"
+ :port 6697
+ :nick "<nick>"
+ :password "<password>"))
+```
+
+Then run `M-x run-erc`.
+
+## Rcirc
+
+The only thing needed here is the general config:
+
+```elisp
+(setq rcirc-server-alist
+ '(("<server>"
+ :port 6697
+ :encryption tls
+ :nick "<nick>"
+ :user-name "<username>/irc.libera.chat" ;; Example with Libera.Chat
+ :password "<password>")))
+```
+
+Then run `M-x irc`.
+
+# [gamja]
+
+gamja has been designed together with soju, so should have excellent
+integration. gamja supports many IRCv3 features including chat history.
+gamja also provides UI to manage soju networks via the
+`soju.im/bouncer-networks` extension.
+
+# [goguma]
+
+Much like gamja, goguma has been designed together with soju, so should have
+excellent integration. goguma supports many IRCv3 features including chat
+history. goguma should seamlessly connect to all networks configured in soju via
+the `soju.im/bouncer-networks` extension.
+
+# [Hexchat]
+
+Hexchat has support for a small set of IRCv3 capabilities. To prevent
+automatically reconnecting to channels parted from soju, and prevent buffering
+outgoing messages:
+
+ /set irc_reconnect_rejoin off
+ /set net_throttle off
+
+# [senpai]
+
+senpai is being developed with soju in mind, so should have excellent
+integration. senpai supports many IRCv3 features including chat history.
+
+# [Weechat]
+
+A [Weechat script] is available to provide better integration with soju.
+The script will automatically connect to all of your networks once a
+single connection to soju is set up in Weechat.
+
+On WeeChat 3.2-, no IRCv3 capabilities are enabled by default. To enable them:
+
+ /set irc.server_default.capabilities account-notify,away-notify,cap-notify,chghost,extended-join,invite-notify,multi-prefix,server-time,userhost-in-names
+ /save
+ /reconnect -all
+
+See `/help cap` for more information.
+
+[IRCv3 support tables]: https://ircv3.net/software/clients
+[gamja]: https://sr.ht/~emersion/gamja/
+[goguma]: https://sr.ht/~emersion/goguma/
+[senpai]: https://sr.ht/~taiite/senpai/
+[Weechat]: https://weechat.org/
+[Weechat script]: https://github.com/weechat/scripts/blob/master/python/soju.py
+[Hexchat]: https://hexchat.github.io/
+[Emacs]: https://www.gnu.org/software/emacs/
--- /dev/null
+package suika
+
+import (
+ "context"
+ "fmt"
+ "net/url"
+ "strings"
+ "time"
+)
+
+type Database interface {
+ Close() error
+ Stats(ctx context.Context) (*DatabaseStats, error)
+
+ ListUsers(ctx context.Context) ([]User, error)
+ GetUser(ctx context.Context, username string) (*User, error)
+ StoreUser(ctx context.Context, user *User) error
+ DeleteUser(ctx context.Context, id int64) error
+
+ ListNetworks(ctx context.Context, userID int64) ([]Network, error)
+ StoreNetwork(ctx context.Context, userID int64, network *Network) error
+ DeleteNetwork(ctx context.Context, id int64) error
+ ListChannels(ctx context.Context, networkID int64) ([]Channel, error)
+ StoreChannel(ctx context.Context, networKID int64, ch *Channel) error
+ DeleteChannel(ctx context.Context, id int64) error
+
+ ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error)
+ StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error
+
+ GetReadReceipt(ctx context.Context, networkID int64, name string) (*ReadReceipt, error)
+ StoreReadReceipt(ctx context.Context, networkID int64, receipt *ReadReceipt) error
+}
+
+func OpenDB(driver, source string) (Database, error) {
+ switch driver {
+ case "sqlite3":
+ return OpenSqliteDB(source)
+ case "postgres":
+ return OpenPostgresDB(source)
+ default:
+ return nil, fmt.Errorf("unsupported database driver: %q", driver)
+ }
+}
+
+type DatabaseStats struct {
+ Users int64
+ Networks int64
+ Channels int64
+}
+
+type User struct {
+ ID int64
+ Username string
+ Password string // hashed
+ Realname string
+ Admin bool
+}
+
+type SASL struct {
+ Mechanism string
+
+ Plain struct {
+ Username string
+ Password string
+ }
+
+ // TLS client certificate authentication.
+ External struct {
+ // X.509 certificate in DER form.
+ CertBlob []byte
+ // PKCS#8 private key in DER form.
+ PrivKeyBlob []byte
+ }
+}
+
+type Network struct {
+ ID int64
+ Name string
+ Addr string
+ Nick string
+ Username string
+ Realname string
+ Pass string
+ ConnectCommands []string
+ SASL SASL
+ Enabled bool
+}
+
+func (net *Network) GetName() string {
+ if net.Name != "" {
+ return net.Name
+ }
+ return net.Addr
+}
+
+func (net *Network) URL() (*url.URL, error) {
+ s := net.Addr
+ if !strings.Contains(s, "://") {
+ // This is a raw domain name, make it an URL with the default scheme
+ s = "ircs://" + s
+ }
+
+ u, err := url.Parse(s)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse upstream server URL: %v", err)
+ }
+
+ return u, nil
+}
+
+func GetNick(user *User, net *Network) string {
+ if net.Nick != "" {
+ return net.Nick
+ }
+ return user.Username
+}
+
+func GetUsername(user *User, net *Network) string {
+ if net.Username != "" {
+ return net.Username
+ }
+ return GetNick(user, net)
+}
+
+func GetRealname(user *User, net *Network) string {
+ if net.Realname != "" {
+ return net.Realname
+ }
+ if user.Realname != "" {
+ return user.Realname
+ }
+ return GetNick(user, net)
+}
+
+type MessageFilter int
+
+const (
+ // TODO: use customizable user defaults for FilterDefault
+ FilterDefault MessageFilter = iota
+ FilterNone
+ FilterHighlight
+ FilterMessage
+)
+
+func parseFilter(filter string) (MessageFilter, error) {
+ switch filter {
+ case "default":
+ return FilterDefault, nil
+ case "none":
+ return FilterNone, nil
+ case "highlight":
+ return FilterHighlight, nil
+ case "message":
+ return FilterMessage, nil
+ }
+ return 0, fmt.Errorf("unknown filter: %q", filter)
+}
+
+type Channel struct {
+ ID int64
+ Name string
+ Key string
+
+ Detached bool
+ DetachedInternalMsgID string
+
+ RelayDetached MessageFilter
+ ReattachOn MessageFilter
+ DetachAfter time.Duration
+ DetachOn MessageFilter
+}
+
+type DeliveryReceipt struct {
+ ID int64
+ Target string // channel or nick
+ Client string
+ InternalMsgID string
+}
+
+type ReadReceipt struct {
+ ID int64
+ Target string // channel or nick
+ Timestamp time.Time
+}
--- /dev/null
+package suika
+
+import (
+ "context"
+ "database/sql"
+ _ "embed"
+ "errors"
+ "fmt"
+ "math"
+ "strings"
+ "time"
+
+ _ "github.com/lib/pq"
+)
+
+const postgresQueryTimeout = 5 * time.Second
+
+const postgresConfigSchema = `
+CREATE TABLE IF NOT EXISTS "Config" (
+ id SMALLINT PRIMARY KEY,
+ version INTEGER NOT NULL,
+ CHECK(id = 1)
+);
+`
+//go:embed suika_psql_schema.sql
+var postgresSchema string
+
+var postgresMigrations = []string{
+ "", // migration #0 is reserved for schema initialization
+ `ALTER TABLE "Network" ALTER COLUMN nick DROP NOT NULL`,
+ `
+ CREATE TYPE sasl_mechanism AS ENUM ('PLAIN', 'EXTERNAL');
+ ALTER TABLE "Network"
+ ALTER COLUMN sasl_mechanism
+ TYPE sasl_mechanism
+ USING sasl_mechanism::sasl_mechanism;
+ `,
+ `
+ CREATE TABLE IF NOT EXISTS "ReadReceipt" (
+ id SERIAL PRIMARY KEY,
+ network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
+ target VARCHAR(255) NOT NULL,
+ timestamp TIMESTAMP WITH TIME ZONE NOT NULL,
+ UNIQUE(network, target)
+ );
+ `,
+}
+
+type PostgresDB struct {
+ db *sql.DB
+}
+
+func OpenPostgresDB(source string) (Database, error) {
+ sqlPostgresDB, err := sql.Open("postgres", source)
+ if err != nil {
+ return nil, err
+ }
+
+ db := &PostgresDB{db: sqlPostgresDB}
+ if err := db.upgrade(); err != nil {
+ sqlPostgresDB.Close()
+ return nil, err
+ }
+
+ return db, nil
+}
+
+func (db *PostgresDB) upgrade() error {
+ tx, err := db.db.Begin()
+ if err != nil {
+ return err
+ }
+ defer tx.Rollback()
+
+ if _, err := tx.Exec(postgresConfigSchema); err != nil {
+ return fmt.Errorf("failed to create Config table: %s", err)
+ }
+
+ var version int
+ err = tx.QueryRow(`SELECT version FROM "Config"`).Scan(&version)
+ if err != nil && !errors.Is(err, sql.ErrNoRows) {
+ return fmt.Errorf("failed to query schema version: %s", err)
+ }
+
+ if version == len(postgresMigrations) {
+ return nil
+ }
+ if version > len(postgresMigrations) {
+ return fmt.Errorf("suika (version %d) older than schema (version %d)", len(postgresMigrations), version)
+ }
+
+ if version == 0 {
+ if _, err := tx.Exec(postgresSchema); err != nil {
+ return fmt.Errorf("failed to initialize schema: %s", err)
+ }
+ } else {
+ for i := version; i < len(postgresMigrations); i++ {
+ if _, err := tx.Exec(postgresMigrations[i]); err != nil {
+ return fmt.Errorf("failed to execute migration #%v: %v", i, err)
+ }
+ }
+ }
+
+ _, err = tx.Exec(`INSERT INTO "Config" (id, version) VALUES (1, $1)
+ ON CONFLICT (id) DO UPDATE SET version = $1`, len(postgresMigrations))
+ if err != nil {
+ return fmt.Errorf("failed to bump schema version: %v", err)
+ }
+
+ return tx.Commit()
+}
+
+func (db *PostgresDB) Close() error {
+ return db.db.Close()
+}
+
+func (db *PostgresDB) Stats(ctx context.Context) (*DatabaseStats, error) {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ var stats DatabaseStats
+ row := db.db.QueryRowContext(ctx, `SELECT
+ (SELECT COUNT(*) FROM "User") AS users,
+ (SELECT COUNT(*) FROM "Network") AS networks,
+ (SELECT COUNT(*) FROM "Channel") AS channels`)
+ if err := row.Scan(&stats.Users, &stats.Networks, &stats.Channels); err != nil {
+ return nil, err
+ }
+
+ return &stats, nil
+}
+
+func (db *PostgresDB) ListUsers(ctx context.Context) ([]User, error) {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ rows, err := db.db.QueryContext(ctx,
+ `SELECT id, username, password, admin, realname FROM "User"`)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ var users []User
+ for rows.Next() {
+ var user User
+ var password, realname sql.NullString
+ if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &realname); err != nil {
+ return nil, err
+ }
+ user.Password = password.String
+ user.Realname = realname.String
+ users = append(users, user)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+
+ return users, nil
+}
+
+func (db *PostgresDB) GetUser(ctx context.Context, username string) (*User, error) {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ user := &User{Username: username}
+
+ var password, realname sql.NullString
+ row := db.db.QueryRowContext(ctx,
+ `SELECT id, password, admin, realname FROM "User" WHERE username = $1`,
+ username)
+ if err := row.Scan(&user.ID, &password, &user.Admin, &realname); err != nil {
+ return nil, err
+ }
+ user.Password = password.String
+ user.Realname = realname.String
+ return user, nil
+}
+
+func (db *PostgresDB) StoreUser(ctx context.Context, user *User) error {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ password := toNullString(user.Password)
+ realname := toNullString(user.Realname)
+
+ var err error
+ if user.ID == 0 {
+ err = db.db.QueryRowContext(ctx, `
+ INSERT INTO "User" (username, password, admin, realname)
+ VALUES ($1, $2, $3, $4)
+ RETURNING id`,
+ user.Username, password, user.Admin, realname).Scan(&user.ID)
+ } else {
+ _, err = db.db.ExecContext(ctx, `
+ UPDATE "User"
+ SET password = $1, admin = $2, realname = $3
+ WHERE id = $4`,
+ password, user.Admin, realname, user.ID)
+ }
+ return err
+}
+
+func (db *PostgresDB) DeleteUser(ctx context.Context, id int64) error {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ _, err := db.db.ExecContext(ctx, `DELETE FROM "User" WHERE id = $1`, id)
+ return err
+}
+
+func (db *PostgresDB) ListNetworks(ctx context.Context, userID int64) ([]Network, error) {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ rows, err := db.db.QueryContext(ctx, `
+ SELECT id, name, addr, nick, username, realname, pass, connect_commands, sasl_mechanism,
+ sasl_plain_username, sasl_plain_password, sasl_external_cert, sasl_external_key, enabled
+ FROM "Network"
+ WHERE "user" = $1`, userID)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ var networks []Network
+ for rows.Next() {
+ var net Network
+ var name, nick, username, realname, pass, connectCommands sql.NullString
+ var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
+ err := rows.Scan(&net.ID, &name, &net.Addr, &nick, &username, &realname,
+ &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword,
+ &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob, &net.Enabled)
+ if err != nil {
+ return nil, err
+ }
+ net.Name = name.String
+ net.Nick = nick.String
+ net.Username = username.String
+ net.Realname = realname.String
+ net.Pass = pass.String
+ if connectCommands.Valid {
+ net.ConnectCommands = strings.Split(connectCommands.String, "\r\n")
+ }
+ net.SASL.Mechanism = saslMechanism.String
+ net.SASL.Plain.Username = saslPlainUsername.String
+ net.SASL.Plain.Password = saslPlainPassword.String
+ networks = append(networks, net)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+
+ return networks, nil
+}
+
+func (db *PostgresDB) StoreNetwork(ctx context.Context, userID int64, network *Network) error {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ netName := toNullString(network.Name)
+ nick := toNullString(network.Nick)
+ netUsername := toNullString(network.Username)
+ realname := toNullString(network.Realname)
+ pass := toNullString(network.Pass)
+ connectCommands := toNullString(strings.Join(network.ConnectCommands, "\r\n"))
+
+ var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
+ if network.SASL.Mechanism != "" {
+ saslMechanism = toNullString(network.SASL.Mechanism)
+ switch network.SASL.Mechanism {
+ case "PLAIN":
+ saslPlainUsername = toNullString(network.SASL.Plain.Username)
+ saslPlainPassword = toNullString(network.SASL.Plain.Password)
+ network.SASL.External.CertBlob = nil
+ network.SASL.External.PrivKeyBlob = nil
+ case "EXTERNAL":
+ // keep saslPlain* nil
+ default:
+ return fmt.Errorf("suika: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
+ }
+ }
+
+ var err error
+ if network.ID == 0 {
+ err = db.db.QueryRowContext(ctx, `
+ INSERT INTO "Network" ("user", name, addr, nick, username, realname, pass, connect_commands,
+ sasl_mechanism, sasl_plain_username, sasl_plain_password, sasl_external_cert,
+ sasl_external_key, enabled)
+ VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
+ RETURNING id`,
+ userID, netName, network.Addr, nick, netUsername, realname, pass, connectCommands,
+ saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
+ network.SASL.External.PrivKeyBlob, network.Enabled).Scan(&network.ID)
+ } else {
+ _, err = db.db.ExecContext(ctx, `
+ UPDATE "Network"
+ SET name = $2, addr = $3, nick = $4, username = $5, realname = $6, pass = $7,
+ connect_commands = $8, sasl_mechanism = $9, sasl_plain_username = $10,
+ sasl_plain_password = $11, sasl_external_cert = $12, sasl_external_key = $13,
+ enabled = $14
+ WHERE id = $1`,
+ network.ID, netName, network.Addr, nick, netUsername, realname, pass, connectCommands,
+ saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
+ network.SASL.External.PrivKeyBlob, network.Enabled)
+ }
+ return err
+}
+
+func (db *PostgresDB) DeleteNetwork(ctx context.Context, id int64) error {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ _, err := db.db.ExecContext(ctx, `DELETE FROM "Network" WHERE id = $1`, id)
+ return err
+}
+
+func (db *PostgresDB) ListChannels(ctx context.Context, networkID int64) ([]Channel, error) {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ rows, err := db.db.QueryContext(ctx, `
+ SELECT id, name, key, detached, detached_internal_msgid, relay_detached, reattach_on, detach_after,
+ detach_on
+ FROM "Channel"
+ WHERE network = $1`, networkID)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ var channels []Channel
+ for rows.Next() {
+ var ch Channel
+ var key, detachedInternalMsgID sql.NullString
+ var detachAfter int64
+ if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached, &detachedInternalMsgID, &ch.RelayDetached, &ch.ReattachOn, &detachAfter, &ch.DetachOn); err != nil {
+ return nil, err
+ }
+ ch.Key = key.String
+ ch.DetachedInternalMsgID = detachedInternalMsgID.String
+ ch.DetachAfter = time.Duration(detachAfter) * time.Second
+ channels = append(channels, ch)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+
+ return channels, nil
+}
+
+func (db *PostgresDB) StoreChannel(ctx context.Context, networkID int64, ch *Channel) error {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ key := toNullString(ch.Key)
+ detachAfter := int64(math.Ceil(ch.DetachAfter.Seconds()))
+
+ var err error
+ if ch.ID == 0 {
+ err = db.db.QueryRowContext(ctx, `
+ INSERT INTO "Channel" (network, name, key, detached, detached_internal_msgid, relay_detached, reattach_on,
+ detach_after, detach_on)
+ VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
+ RETURNING id`,
+ networkID, ch.Name, key, ch.Detached, toNullString(ch.DetachedInternalMsgID),
+ ch.RelayDetached, ch.ReattachOn, detachAfter, ch.DetachOn).Scan(&ch.ID)
+ } else {
+ _, err = db.db.ExecContext(ctx, `
+ UPDATE "Channel"
+ SET name = $2, key = $3, detached = $4, detached_internal_msgid = $5,
+ relay_detached = $6, reattach_on = $7, detach_after = $8, detach_on = $9
+ WHERE id = $1`,
+ ch.ID, ch.Name, key, ch.Detached, toNullString(ch.DetachedInternalMsgID),
+ ch.RelayDetached, ch.ReattachOn, detachAfter, ch.DetachOn)
+ }
+ return err
+}
+
+func (db *PostgresDB) DeleteChannel(ctx context.Context, id int64) error {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ _, err := db.db.ExecContext(ctx, `DELETE FROM "Channel" WHERE id = $1`, id)
+ return err
+}
+
+func (db *PostgresDB) ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error) {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ rows, err := db.db.QueryContext(ctx, `
+ SELECT id, target, client, internal_msgid
+ FROM "DeliveryReceipt"
+ WHERE network = $1`, networkID)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ var receipts []DeliveryReceipt
+ for rows.Next() {
+ var rcpt DeliveryReceipt
+ if err := rows.Scan(&rcpt.ID, &rcpt.Target, &rcpt.Client, &rcpt.InternalMsgID); err != nil {
+ return nil, err
+ }
+ receipts = append(receipts, rcpt)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+
+ return receipts, nil
+}
+
+func (db *PostgresDB) StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ tx, err := db.db.Begin()
+ if err != nil {
+ return err
+ }
+ defer tx.Rollback()
+
+ _, err = tx.ExecContext(ctx,
+ `DELETE FROM "DeliveryReceipt" WHERE network = $1 AND client = $2`,
+ networkID, client)
+ if err != nil {
+ return err
+ }
+
+ stmt, err := tx.PrepareContext(ctx, `
+ INSERT INTO "DeliveryReceipt" (network, target, client, internal_msgid)
+ VALUES ($1, $2, $3, $4)
+ RETURNING id`)
+ if err != nil {
+ return err
+ }
+ defer stmt.Close()
+
+ for i := range receipts {
+ rcpt := &receipts[i]
+ err := stmt.
+ QueryRowContext(ctx, networkID, rcpt.Target, client, rcpt.InternalMsgID).
+ Scan(&rcpt.ID)
+ if err != nil {
+ return err
+ }
+ }
+
+ return tx.Commit()
+}
+
+func (db *PostgresDB) GetReadReceipt(ctx context.Context, networkID int64, name string) (*ReadReceipt, error) {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ receipt := &ReadReceipt{
+ Target: name,
+ }
+
+ row := db.db.QueryRowContext(ctx,
+ `SELECT id, timestamp FROM "ReadReceipt" WHERE network = $1 AND target = $2`,
+ networkID, name)
+ if err := row.Scan(&receipt.ID, &receipt.Timestamp); err != nil {
+ if err == sql.ErrNoRows {
+ return nil, nil
+ }
+ return nil, err
+ }
+ return receipt, nil
+}
+
+func (db *PostgresDB) StoreReadReceipt(ctx context.Context, networkID int64, receipt *ReadReceipt) error {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ var err error
+ if receipt.ID != 0 {
+ _, err = db.db.ExecContext(ctx, `
+ UPDATE "ReadReceipt"
+ SET timestamp = $1
+ WHERE id = $2`,
+ receipt.Timestamp, receipt.ID)
+ } else {
+ err = db.db.QueryRowContext(ctx, `
+ INSERT INTO "ReadReceipt" (network, target, timestamp)
+ VALUES ($1, $2, $3)
+ RETURNING id`,
+ networkID, receipt.Target, receipt.Timestamp).Scan(&receipt.ID)
+ }
+ return err
+}
--- /dev/null
+package suika
+
+import (
+ "database/sql"
+ "os"
+ "testing"
+)
+
+// PostgreSQL version 0 schema. DO NOT EDIT.
+const postgresV0Schema = `
+CREATE TABLE "Config" (
+ id SMALLINT PRIMARY KEY,
+ version INTEGER NOT NULL,
+ CHECK(id = 1)
+);
+
+INSERT INTO "Config" (id, version) VALUES (1, 1);
+
+CREATE TABLE "User" (
+ id SERIAL PRIMARY KEY,
+ username VARCHAR(255) NOT NULL UNIQUE,
+ password VARCHAR(255),
+ admin BOOLEAN NOT NULL DEFAULT FALSE,
+ realname VARCHAR(255)
+);
+
+CREATE TABLE "Network" (
+ id SERIAL PRIMARY KEY,
+ name VARCHAR(255),
+ "user" INTEGER NOT NULL REFERENCES "User"(id) ON DELETE CASCADE,
+ addr VARCHAR(255) NOT NULL,
+ nick VARCHAR(255) NOT NULL,
+ username VARCHAR(255),
+ realname VARCHAR(255),
+ pass VARCHAR(255),
+ connect_commands VARCHAR(1023),
+ sasl_mechanism VARCHAR(255),
+ sasl_plain_username VARCHAR(255),
+ sasl_plain_password VARCHAR(255),
+ sasl_external_cert BYTEA DEFAULT NULL,
+ sasl_external_key BYTEA DEFAULT NULL,
+ enabled BOOLEAN NOT NULL DEFAULT TRUE,
+ UNIQUE("user", addr, nick),
+ UNIQUE("user", name)
+);
+
+CREATE TABLE "Channel" (
+ id SERIAL PRIMARY KEY,
+ network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
+ name VARCHAR(255) NOT NULL,
+ key VARCHAR(255),
+ detached BOOLEAN NOT NULL DEFAULT FALSE,
+ detached_internal_msgid VARCHAR(255),
+ relay_detached INTEGER NOT NULL DEFAULT 0,
+ reattach_on INTEGER NOT NULL DEFAULT 0,
+ detach_after INTEGER NOT NULL DEFAULT 0,
+ detach_on INTEGER NOT NULL DEFAULT 0,
+ UNIQUE(network, name)
+);
+
+CREATE TABLE "DeliveryReceipt" (
+ id SERIAL PRIMARY KEY,
+ network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
+ target VARCHAR(255) NOT NULL,
+ client VARCHAR(255) NOT NULL DEFAULT '',
+ internal_msgid VARCHAR(255) NOT NULL,
+ UNIQUE(network, target, client)
+);
+`
+
+func openTempPostgresDB(t *testing.T) *sql.DB {
+ source, ok := os.LookupEnv("SOJU_TEST_POSTGRES")
+ if !ok {
+ t.Skip("set SOJU_TEST_POSTGRES to a connection string to execute PostgreSQL tests")
+ }
+
+ db, err := sql.Open("postgres", source)
+ if err != nil {
+ t.Fatalf("failed to connect to PostgreSQL: %v", err)
+ }
+
+ // Store all tables in a temporary schema which will be dropped when the
+ // connection to PostgreSQL is closed.
+ db.SetMaxOpenConns(1)
+ if _, err := db.Exec("SET search_path TO pg_temp"); err != nil {
+ t.Fatalf("failed to set PostgreSQL search_path: %v", err)
+ }
+
+ return db
+}
+
+func TestPostgresMigrations(t *testing.T) {
+ sqlDB := openTempPostgresDB(t)
+ if _, err := sqlDB.Exec(postgresV0Schema); err != nil {
+ t.Fatalf("DB.Exec() failed for v0 schema: %v", err)
+ }
+
+ db := &PostgresDB{db: sqlDB}
+ defer db.Close()
+
+ if err := db.upgrade(); err != nil {
+ t.Fatalf("PostgresDB.Upgrade() failed: %v", err)
+ }
+}
--- /dev/null
+package suika
+
+import (
+ "context"
+ "database/sql"
+ _ "embed"
+ "fmt"
+ "math"
+ "strings"
+ "sync"
+ "time"
+
+ _ "modernc.org/sqlite"
+)
+
+const sqliteQueryTimeout = 5 * time.Second
+
+//go:embed suika_sqlite_schema.sql
+var sqliteSchema string
+
+var sqliteMigrations = []string{
+ "", // migration #0 is reserved for schema initialization
+ "ALTER TABLE Network ADD COLUMN connect_commands VARCHAR(1023)",
+ "ALTER TABLE Channel ADD COLUMN detached INTEGER NOT NULL DEFAULT 0",
+ "ALTER TABLE Network ADD COLUMN sasl_external_cert BLOB DEFAULT NULL",
+ "ALTER TABLE Network ADD COLUMN sasl_external_key BLOB DEFAULT NULL",
+ "ALTER TABLE User ADD COLUMN admin INTEGER NOT NULL DEFAULT 0",
+ `
+ CREATE TABLE IF NOT EXISTS UserNew (
+ id INTEGER PRIMARY KEY,
+ username VARCHAR(255) NOT NULL UNIQUE,
+ password VARCHAR(255),
+ admin INTEGER NOT NULL DEFAULT 0
+ );
+ INSERT INTO UserNew SELECT rowid, username, password, admin FROM User;
+ DROP TABLE User;
+ ALTER TABLE UserNew RENAME TO User;
+ `,
+ `
+ CREATE TABLE IF NOT EXISTS NetworkNew (
+ id INTEGER PRIMARY KEY,
+ name VARCHAR(255),
+ user INTEGER NOT NULL,
+ addr VARCHAR(255) NOT NULL,
+ nick VARCHAR(255) NOT NULL,
+ username VARCHAR(255),
+ realname VARCHAR(255),
+ pass VARCHAR(255),
+ connect_commands VARCHAR(1023),
+ sasl_mechanism VARCHAR(255),
+ sasl_plain_username VARCHAR(255),
+ sasl_plain_password VARCHAR(255),
+ sasl_external_cert BLOB DEFAULT NULL,
+ sasl_external_key BLOB DEFAULT NULL,
+ FOREIGN KEY(user) REFERENCES User(id),
+ UNIQUE(user, addr, nick),
+ UNIQUE(user, name)
+ );
+ INSERT INTO NetworkNew
+ SELECT Network.id, name, User.id as user, addr, nick,
+ Network.username, realname, pass, connect_commands,
+ sasl_mechanism, sasl_plain_username, sasl_plain_password,
+ sasl_external_cert, sasl_external_key
+ FROM Network
+ JOIN User ON Network.user = User.username;
+ DROP TABLE Network;
+ ALTER TABLE NetworkNew RENAME TO Network;
+ `,
+ `
+ ALTER TABLE Channel ADD COLUMN relay_detached INTEGER NOT NULL DEFAULT 0;
+ ALTER TABLE Channel ADD COLUMN reattach_on INTEGER NOT NULL DEFAULT 0;
+ ALTER TABLE Channel ADD COLUMN detach_after INTEGER NOT NULL DEFAULT 0;
+ ALTER TABLE Channel ADD COLUMN detach_on INTEGER NOT NULL DEFAULT 0;
+ `,
+ `
+ CREATE TABLE IF NOT EXISTS DeliveryReceipt (
+ id INTEGER PRIMARY KEY,
+ network INTEGER NOT NULL,
+ target VARCHAR(255) NOT NULL,
+ client VARCHAR(255),
+ internal_msgid VARCHAR(255) NOT NULL,
+ FOREIGN KEY(network) REFERENCES Network(id),
+ UNIQUE(network, target, client)
+ );
+ `,
+ "ALTER TABLE Channel ADD COLUMN detached_internal_msgid VARCHAR(255)",
+ "ALTER TABLE Network ADD COLUMN enabled INTEGER NOT NULL DEFAULT 1",
+ "ALTER TABLE User ADD COLUMN realname VARCHAR(255)",
+ `
+ CREATE TABLE IF NOT EXISTS NetworkNew (
+ id INTEGER PRIMARY KEY,
+ name TEXT,
+ user INTEGER NOT NULL,
+ addr TEXT NOT NULL,
+ nick TEXT,
+ username TEXT,
+ realname TEXT,
+ pass TEXT,
+ connect_commands TEXT,
+ sasl_mechanism TEXT,
+ sasl_plain_username TEXT,
+ sasl_plain_password TEXT,
+ sasl_external_cert BLOB,
+ sasl_external_key BLOB,
+ enabled INTEGER NOT NULL DEFAULT 1,
+ FOREIGN KEY(user) REFERENCES User(id),
+ UNIQUE(user, addr, nick),
+ UNIQUE(user, name)
+ );
+ INSERT INTO NetworkNew
+ SELECT id, name, user, addr, nick, username, realname, pass,
+ connect_commands, sasl_mechanism, sasl_plain_username,
+ sasl_plain_password, sasl_external_cert, sasl_external_key,
+ enabled
+ FROM Network;
+ DROP TABLE Network;
+ ALTER TABLE NetworkNew RENAME TO Network;
+ `,
+ `
+ CREATE TABLE IF NOT EXISTS ReadReceipt (
+ id INTEGER PRIMARY KEY,
+ network INTEGER NOT NULL,
+ target TEXT NOT NULL,
+ timestamp TEXT NOT NULL,
+ FOREIGN KEY(network) REFERENCES Network(id),
+ UNIQUE(network, target)
+ );
+ `,
+}
+
+type SqliteDB struct {
+ lock sync.RWMutex
+ db *sql.DB
+}
+
+func OpenSqliteDB(source string) (Database, error) {
+ sqlSqliteDB, err := sql.Open("sqlite", source)
+ if err != nil {
+ return nil, err
+ }
+
+ db := &SqliteDB{db: sqlSqliteDB}
+ if err := db.upgrade(); err != nil {
+ sqlSqliteDB.Close()
+ return nil, err
+ }
+
+ return db, nil
+}
+
+func (db *SqliteDB) Close() error {
+ db.lock.Lock()
+ defer db.lock.Unlock()
+ return db.db.Close()
+}
+
+func (db *SqliteDB) upgrade() error {
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ var version int
+ if err := db.db.QueryRow("PRAGMA user_version").Scan(&version); err != nil {
+ return fmt.Errorf("failed to query schema version: %v", err)
+ }
+
+ if version == len(sqliteMigrations) {
+ return nil
+ } else if version > len(sqliteMigrations) {
+ return fmt.Errorf("suika (version %d) older than schema (version %d)", len(sqliteMigrations), version)
+ }
+
+ tx, err := db.db.Begin()
+ if err != nil {
+ return err
+ }
+ defer tx.Rollback()
+
+ if version == 0 {
+ if _, err := tx.Exec(sqliteSchema); err != nil {
+ return fmt.Errorf("failed to initialize schema: %v", err)
+ }
+ } else {
+ for i := version; i < len(sqliteMigrations); i++ {
+ if _, err := tx.Exec(sqliteMigrations[i]); err != nil {
+ return fmt.Errorf("failed to execute migration #%v: %v", i, err)
+ }
+ }
+ }
+
+ // For some reason prepared statements don't work here
+ _, err = tx.Exec(fmt.Sprintf("PRAGMA user_version = %d", len(sqliteMigrations)))
+ if err != nil {
+ return fmt.Errorf("failed to bump schema version: %v", err)
+ }
+
+ return tx.Commit()
+}
+
+func (db *SqliteDB) Stats(ctx context.Context) (*DatabaseStats, error) {
+ db.lock.RLock()
+ defer db.lock.RUnlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ var stats DatabaseStats
+ row := db.db.QueryRowContext(ctx, `SELECT
+ (SELECT COUNT(*) FROM User) AS users,
+ (SELECT COUNT(*) FROM Network) AS networks,
+ (SELECT COUNT(*) FROM Channel) AS channels`)
+ if err := row.Scan(&stats.Users, &stats.Networks, &stats.Channels); err != nil {
+ return nil, err
+ }
+
+ return &stats, nil
+}
+
+func toNullString(s string) sql.NullString {
+ return sql.NullString{
+ String: s,
+ Valid: s != "",
+ }
+}
+
+func (db *SqliteDB) ListUsers(ctx context.Context) ([]User, error) {
+ db.lock.RLock()
+ defer db.lock.RUnlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ rows, err := db.db.QueryContext(ctx,
+ "SELECT id, username, password, admin, realname FROM User")
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ var users []User
+ for rows.Next() {
+ var user User
+ var password, realname sql.NullString
+ if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &realname); err != nil {
+ return nil, err
+ }
+ user.Password = password.String
+ user.Realname = realname.String
+ users = append(users, user)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+
+ return users, nil
+}
+
+func (db *SqliteDB) GetUser(ctx context.Context, username string) (*User, error) {
+ db.lock.RLock()
+ defer db.lock.RUnlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ user := &User{Username: username}
+
+ var password, realname sql.NullString
+ row := db.db.QueryRowContext(ctx,
+ "SELECT id, password, admin, realname FROM User WHERE username = ?",
+ username)
+ if err := row.Scan(&user.ID, &password, &user.Admin, &realname); err != nil {
+ return nil, err
+ }
+ user.Password = password.String
+ user.Realname = realname.String
+ return user, nil
+}
+
+func (db *SqliteDB) StoreUser(ctx context.Context, user *User) error {
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ args := []interface{}{
+ sql.Named("username", user.Username),
+ sql.Named("password", toNullString(user.Password)),
+ sql.Named("admin", user.Admin),
+ sql.Named("realname", toNullString(user.Realname)),
+ }
+
+ var err error
+ if user.ID != 0 {
+ _, err = db.db.ExecContext(ctx, `
+ UPDATE User SET password = :password, admin = :admin,
+ realname = :realname WHERE username = :username`,
+ args...)
+ } else {
+ var res sql.Result
+ res, err = db.db.ExecContext(ctx, `
+ INSERT INTO
+ User(username, password, admin, realname)
+ VALUES (:username, :password, :admin, :realname)`,
+ args...)
+ if err != nil {
+ return err
+ }
+ user.ID, err = res.LastInsertId()
+ }
+
+ return err
+}
+
+func (db *SqliteDB) DeleteUser(ctx context.Context, id int64) error {
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ tx, err := db.db.Begin()
+ if err != nil {
+ return err
+ }
+ defer tx.Rollback()
+
+ _, err = tx.ExecContext(ctx, `DELETE FROM DeliveryReceipt
+ WHERE id IN (
+ SELECT DeliveryReceipt.id
+ FROM DeliveryReceipt
+ JOIN Network ON DeliveryReceipt.network = Network.id
+ WHERE Network.user = ?
+ )`, id)
+ if err != nil {
+ return err
+ }
+
+ _, err = tx.ExecContext(ctx, `DELETE FROM ReadReceipt
+ WHERE id IN (
+ SELECT ReadReceipt.id
+ FROM ReadReceipt
+ JOIN Network ON ReadReceipt.network = Network.id
+ WHERE Network.user = ?
+ )`, id)
+ if err != nil {
+ return err
+ }
+
+ _, err = tx.ExecContext(ctx, `DELETE FROM Channel
+ WHERE id IN (
+ SELECT Channel.id
+ FROM Channel
+ JOIN Network ON Channel.network = Network.id
+ WHERE Network.user = ?
+ )`, id)
+ if err != nil {
+ return err
+ }
+
+ _, err = tx.ExecContext(ctx, "DELETE FROM Network WHERE user = ?", id)
+ if err != nil {
+ return err
+ }
+
+ _, err = tx.ExecContext(ctx, "DELETE FROM User WHERE id = ?", id)
+ if err != nil {
+ return err
+ }
+
+ return tx.Commit()
+}
+
+func (db *SqliteDB) ListNetworks(ctx context.Context, userID int64) ([]Network, error) {
+ db.lock.RLock()
+ defer db.lock.RUnlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ rows, err := db.db.QueryContext(ctx, `
+ SELECT id, name, addr, nick, username, realname, pass,
+ connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password,
+ sasl_external_cert, sasl_external_key, enabled
+ FROM Network
+ WHERE user = ?`,
+ userID)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ var networks []Network
+ for rows.Next() {
+ var net Network
+ var name, nick, username, realname, pass, connectCommands sql.NullString
+ var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
+ err := rows.Scan(&net.ID, &name, &net.Addr, &nick, &username, &realname,
+ &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword,
+ &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob, &net.Enabled)
+ if err != nil {
+ return nil, err
+ }
+ net.Name = name.String
+ net.Nick = nick.String
+ net.Username = username.String
+ net.Realname = realname.String
+ net.Pass = pass.String
+ if connectCommands.Valid {
+ net.ConnectCommands = strings.Split(connectCommands.String, "\r\n")
+ }
+ net.SASL.Mechanism = saslMechanism.String
+ net.SASL.Plain.Username = saslPlainUsername.String
+ net.SASL.Plain.Password = saslPlainPassword.String
+ networks = append(networks, net)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+
+ return networks, nil
+}
+
+func (db *SqliteDB) StoreNetwork(ctx context.Context, userID int64, network *Network) error {
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
+ if network.SASL.Mechanism != "" {
+ saslMechanism = toNullString(network.SASL.Mechanism)
+ switch network.SASL.Mechanism {
+ case "PLAIN":
+ saslPlainUsername = toNullString(network.SASL.Plain.Username)
+ saslPlainPassword = toNullString(network.SASL.Plain.Password)
+ network.SASL.External.CertBlob = nil
+ network.SASL.External.PrivKeyBlob = nil
+ case "EXTERNAL":
+ // keep saslPlain* nil
+ default:
+ return fmt.Errorf("suika: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
+ }
+ }
+
+ args := []interface{}{
+ sql.Named("name", toNullString(network.Name)),
+ sql.Named("addr", network.Addr),
+ sql.Named("nick", toNullString(network.Nick)),
+ sql.Named("username", toNullString(network.Username)),
+ sql.Named("realname", toNullString(network.Realname)),
+ sql.Named("pass", toNullString(network.Pass)),
+ sql.Named("connect_commands", toNullString(strings.Join(network.ConnectCommands, "\r\n"))),
+ sql.Named("sasl_mechanism", saslMechanism),
+ sql.Named("sasl_plain_username", saslPlainUsername),
+ sql.Named("sasl_plain_password", saslPlainPassword),
+ sql.Named("sasl_external_cert", network.SASL.External.CertBlob),
+ sql.Named("sasl_external_key", network.SASL.External.PrivKeyBlob),
+ sql.Named("enabled", network.Enabled),
+
+ sql.Named("id", network.ID), // only for UPDATE
+ sql.Named("user", userID), // only for INSERT
+ }
+
+ var err error
+ if network.ID != 0 {
+ _, err = db.db.ExecContext(ctx, `
+ UPDATE Network
+ SET name = :name, addr = :addr, nick = :nick, username = :username,
+ realname = :realname, pass = :pass, connect_commands = :connect_commands,
+ sasl_mechanism = :sasl_mechanism, sasl_plain_username = :sasl_plain_username, sasl_plain_password = :sasl_plain_password,
+ sasl_external_cert = :sasl_external_cert, sasl_external_key = :sasl_external_key,
+ enabled = :enabled
+ WHERE id = :id`, args...)
+ } else {
+ var res sql.Result
+ res, err = db.db.ExecContext(ctx, `
+ INSERT INTO Network(user, name, addr, nick, username, realname, pass,
+ connect_commands, sasl_mechanism, sasl_plain_username,
+ sasl_plain_password, sasl_external_cert, sasl_external_key, enabled)
+ VALUES (:user, :name, :addr, :nick, :username, :realname, :pass,
+ :connect_commands, :sasl_mechanism, :sasl_plain_username,
+ :sasl_plain_password, :sasl_external_cert, :sasl_external_key, :enabled)`,
+ args...)
+ if err != nil {
+ return err
+ }
+ network.ID, err = res.LastInsertId()
+ }
+ return err
+}
+
+func (db *SqliteDB) DeleteNetwork(ctx context.Context, id int64) error {
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ tx, err := db.db.Begin()
+ if err != nil {
+ return err
+ }
+ defer tx.Rollback()
+
+ _, err = tx.ExecContext(ctx, "DELETE FROM DeliveryReceipt WHERE network = ?", id)
+ if err != nil {
+ return err
+ }
+
+ _, err = tx.ExecContext(ctx, "DELETE FROM ReadReceipt WHERE network = ?", id)
+ if err != nil {
+ return err
+ }
+
+ _, err = tx.ExecContext(ctx, "DELETE FROM Channel WHERE network = ?", id)
+ if err != nil {
+ return err
+ }
+
+ _, err = tx.ExecContext(ctx, "DELETE FROM Network WHERE id = ?", id)
+ if err != nil {
+ return err
+ }
+
+ return tx.Commit()
+}
+
+func (db *SqliteDB) ListChannels(ctx context.Context, networkID int64) ([]Channel, error) {
+ db.lock.RLock()
+ defer db.lock.RUnlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ rows, err := db.db.QueryContext(ctx, `SELECT
+ id, name, key, detached, detached_internal_msgid,
+ relay_detached, reattach_on, detach_after, detach_on
+ FROM Channel
+ WHERE network = ?`, networkID)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ var channels []Channel
+ for rows.Next() {
+ var ch Channel
+ var key, detachedInternalMsgID sql.NullString
+ var detachAfter int64
+ if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached, &detachedInternalMsgID, &ch.RelayDetached, &ch.ReattachOn, &detachAfter, &ch.DetachOn); err != nil {
+ return nil, err
+ }
+ ch.Key = key.String
+ ch.DetachedInternalMsgID = detachedInternalMsgID.String
+ ch.DetachAfter = time.Duration(detachAfter) * time.Second
+ channels = append(channels, ch)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+
+ return channels, nil
+}
+
+func (db *SqliteDB) StoreChannel(ctx context.Context, networkID int64, ch *Channel) error {
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ args := []interface{}{
+ sql.Named("network", networkID),
+ sql.Named("name", ch.Name),
+ sql.Named("key", toNullString(ch.Key)),
+ sql.Named("detached", ch.Detached),
+ sql.Named("detached_internal_msgid", toNullString(ch.DetachedInternalMsgID)),
+ sql.Named("relay_detached", ch.RelayDetached),
+ sql.Named("reattach_on", ch.ReattachOn),
+ sql.Named("detach_after", int64(math.Ceil(ch.DetachAfter.Seconds()))),
+ sql.Named("detach_on", ch.DetachOn),
+
+ sql.Named("id", ch.ID), // only for UPDATE
+ }
+
+ var err error
+ if ch.ID != 0 {
+ _, err = db.db.ExecContext(ctx, `UPDATE Channel
+ SET network = :network, name = :name, key = :key, detached = :detached,
+ detached_internal_msgid = :detached_internal_msgid, relay_detached = :relay_detached,
+ reattach_on = :reattach_on, detach_after = :detach_after, detach_on = :detach_on
+ WHERE id = :id`, args...)
+ } else {
+ var res sql.Result
+ res, err = db.db.ExecContext(ctx, `INSERT INTO Channel(network, name, key, detached, detached_internal_msgid, relay_detached, reattach_on, detach_after, detach_on)
+ VALUES (:network, :name, :key, :detached, :detached_internal_msgid, :relay_detached, :reattach_on, :detach_after, :detach_on)`, args...)
+ if err != nil {
+ return err
+ }
+ ch.ID, err = res.LastInsertId()
+ }
+ return err
+}
+
+func (db *SqliteDB) DeleteChannel(ctx context.Context, id int64) error {
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ _, err := db.db.ExecContext(ctx, "DELETE FROM Channel WHERE id = ?", id)
+ return err
+}
+
+func (db *SqliteDB) ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error) {
+ db.lock.RLock()
+ defer db.lock.RUnlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ rows, err := db.db.QueryContext(ctx, `
+ SELECT id, target, client, internal_msgid
+ FROM DeliveryReceipt
+ WHERE network = ?`, networkID)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ var receipts []DeliveryReceipt
+ for rows.Next() {
+ var rcpt DeliveryReceipt
+ var client sql.NullString
+ if err := rows.Scan(&rcpt.ID, &rcpt.Target, &client, &rcpt.InternalMsgID); err != nil {
+ return nil, err
+ }
+ rcpt.Client = client.String
+ receipts = append(receipts, rcpt)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+
+ return receipts, nil
+}
+
+func (db *SqliteDB) StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error {
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ tx, err := db.db.Begin()
+ if err != nil {
+ return err
+ }
+ defer tx.Rollback()
+
+ _, err = tx.ExecContext(ctx, "DELETE FROM DeliveryReceipt WHERE network = ? AND client IS ?",
+ networkID, toNullString(client))
+ if err != nil {
+ return err
+ }
+
+ for i := range receipts {
+ rcpt := &receipts[i]
+
+ res, err := tx.ExecContext(ctx, `
+ INSERT INTO DeliveryReceipt(network, target, client, internal_msgid)
+ VALUES (:network, :target, :client, :internal_msgid)`,
+ sql.Named("network", networkID),
+ sql.Named("target", rcpt.Target),
+ sql.Named("client", toNullString(client)),
+ sql.Named("internal_msgid", rcpt.InternalMsgID))
+ if err != nil {
+ return err
+ }
+ rcpt.ID, err = res.LastInsertId()
+ if err != nil {
+ return err
+ }
+ }
+
+ return tx.Commit()
+}
+
+func (db *SqliteDB) GetReadReceipt(ctx context.Context, networkID int64, name string) (*ReadReceipt, error) {
+ db.lock.RLock()
+ defer db.lock.RUnlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ receipt := &ReadReceipt{
+ Target: name,
+ }
+
+ row := db.db.QueryRowContext(ctx, `
+ SELECT id, timestamp FROM ReadReceipt WHERE network = :network AND target = :target`,
+ sql.Named("network", networkID),
+ sql.Named("target", name),
+ )
+ var timestamp string
+ if err := row.Scan(&receipt.ID, ×tamp); err != nil {
+ if err == sql.ErrNoRows {
+ return nil, nil
+ }
+ return nil, err
+ }
+ if t, err := time.Parse(serverTimeLayout, timestamp); err != nil {
+ return nil, err
+ } else {
+ receipt.Timestamp = t
+ }
+ return receipt, nil
+}
+
+func (db *SqliteDB) StoreReadReceipt(ctx context.Context, networkID int64, receipt *ReadReceipt) error {
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ args := []interface{}{
+ sql.Named("id", receipt.ID),
+ sql.Named("timestamp", formatServerTime(receipt.Timestamp)),
+ sql.Named("network", networkID),
+ sql.Named("target", receipt.Target),
+ }
+
+ var err error
+ if receipt.ID != 0 {
+ _, err = db.db.ExecContext(ctx, `
+ UPDATE ReadReceipt SET timestamp = :timestamp WHERE id = :id`,
+ args...)
+ } else {
+ var res sql.Result
+ res, err = db.db.ExecContext(ctx, `
+ INSERT INTO
+ ReadReceipt(network, target, timestamp)
+ VALUES (:network, :target, :timestamp)`,
+ args...)
+ if err != nil {
+ return err
+ }
+ receipt.ID, err = res.LastInsertId()
+ }
+
+ return err
+}
--- /dev/null
+package suika
+
+import (
+ "database/sql"
+ "testing"
+)
+
+// SQLite version 0 schema. DO NOT EDIT.
+const sqliteV0Schema = `
+CREATE TABLE User (
+ username VARCHAR(255) NOT NULL UNIQUE,
+ password VARCHAR(255)
+);
+
+CREATE TABLE Network (
+ id INTEGER PRIMARY KEY,
+ name VARCHAR(255),
+ user VARCHAR(255) NOT NULL,
+ addr VARCHAR(255) NOT NULL,
+ nick VARCHAR(255) NOT NULL,
+ username VARCHAR(255),
+ realname VARCHAR(255),
+ pass VARCHAR(255),
+ sasl_mechanism VARCHAR(255),
+ sasl_plain_username VARCHAR(255),
+ sasl_plain_password VARCHAR(255),
+ UNIQUE(user, addr, nick),
+ UNIQUE(user, name)
+);
+
+CREATE TABLE Channel (
+ id INTEGER PRIMARY KEY,
+ network INTEGER NOT NULL,
+ name VARCHAR(255) NOT NULL,
+ key VARCHAR(255),
+ FOREIGN KEY(network) REFERENCES Network(id),
+ UNIQUE(network, name)
+);
+
+PRAGMA user_version = 1;
+`
+
+func TestSqliteMigrations(t *testing.T) {
+ sqlDB, err := sql.Open("sqlite", ":memory:")
+ if err != nil {
+ t.Fatalf("failed to create temporary SQLite database: %v", err)
+ }
+
+ if _, err := sqlDB.Exec(sqliteV0Schema); err != nil {
+ t.Fatalf("DB.Exec() failed for v0 schema: %v", err)
+ }
+
+ db := &SqliteDB{db: sqlDB}
+ defer db.Close()
+
+ if err := db.upgrade(); err != nil {
+ t.Fatalf("SqliteDB.Upgrade() failed: %v", err)
+ }
+}
--- /dev/null
+// Package suika is a hard-fork of the 0.3 series of soju, an user-friendly IRC bouncer in Go.
+//
+// # Copyright (C) 2020 The soju Contributors
+// # Copyright (C) 2023-present Izuru Yakumo et al.
+//
+// suika is covered by the AGPLv3 license:
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published
+// by the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <https://www.gnu.org/licenses/>.
+package suika
--- /dev/null
+.Dd $Mdocdate$
+.Dt SUIKA-CONFIG 5
+.Os
+.Sh NAME
+.Nm suika-config
+.Nd Configuration file for the IRC bouncer
+.Sh SYNOPSIS
+.Bk -words
+listen ircs://
+.Pp
+tls cert.pem key.pem
+.Pp
+hostname example.org
+.Ek
+.Sh DESCRIPTION
+This document describes the format of the configuration
+file used by
+.Xr suika 1
+.Sh OPTIONS
+.Bl -tag -width Ds
+.It listen Ar uri
+With this you can control on what
+ports/protocols
+.Xr suika 1
+listens on, it supports
+irc (cleartext IRC), ircs (IRC with TLS), and unix
+(IRC over Unix domain sockets)
+.It hostname Ar hostname
+Server hostname, if unset, the system one is used.
+.It title Ar title
+Server title, this will be sent as the ISUPPORT NETWORK value when
+clients don't select a specific network.
+.It tls Ar cert Ar key
+Enable TLS support, the certificate and key files must be
+PEM-encoded.
+.It db Ar driver Ar path
+Set the database driver for user, network and channel storage.
+By default a SQLite 3 database is opened in
+.Pa ./suika.db
+Supported drivers are sqlite and postgres, the former
+expects a path to the database file, and the latter
+a space-separated list of key=value parameters,
+e.g. host=localhost dbname=suika
+.It log fs Ar path
+Path to the bouncer logs directory, or empty to disable
+logging.
+By default, logging is disabled.
+.It max-user-networks Ar limit
+Maximum number of networks per user, by default
+there is no limit.
+.It motd Ar path
+Path to the MOTD file, its contents are sent to clients
+which aren't bound to a particular network.
+By default, no MOTD is sent.
+.It multi-upstream-mode Ar bool
+Globally enable or disable multi-upstream mode.
+By default, it is enabled.
+.It upstream-user-ip Ar cidr
+Enable per-user-IP addresses.
+One IPv4 and/or one IPv6 range can be specified in CIDR notation.
+One IP address per range will be assigned to each user as the
+source address when connecting to an upstream network.
+This can be useful to avoid having the whole bouncer banned from
+an upstream network because of one malicious user.
+.El
+.Sh AUTHORS
+.An Simon Ser Aq Mt contact@emersion.fr
+.An The soju Contributors
+.Sh MAINTAINERS
+.An Izuru Yakumo Aq Mt yakumo.izuru@chaotic.ninja
--- /dev/null
+.Dd $Mdocdate$
+.Dt SUIKA-ZNC-IMPORT 1
+.Os
+.Sh NAME
+.Nm suika-znc-import
+.Nd Migration utility for moving from ZNC
+.Sh SYNOPSIS
+.Nm
+.Op Fl config Ar suika config file
+.Op Fl user Ar username
+.Op Fl network Ar name
+.Sh DESCRIPTION
+Imports configuration from a ZNC file.
+Users and networks are merged if they already exist in the
+.Xr suika 1
+database.
+ZNC settings overwrite existing
+.Xr suika 1
+settings
+.Sh OPTIONS
+.Bl -tag -width Ds
+.It config Ar suika config file
+Path to
+.Xr suika-config 5
+.It user Ar username
+Limit import to username, may be specified multiple times.
+It network Ar name
+Limit import to network, may be specified multiple times.
+.El
+.Sh AUTHORS
+.An Simon Ser Aq Mt contact@emersion.fr
+.An The soju Contributors
+.Sh MAINTAINERS
+.An Izuru Yakumo Aq Mt yakumo.izuru@chaotic.ninja
--- /dev/null
+.Dd $Mdocdate$
+.Dt SUIKA 1
+.Os
+.Sh NAME
+.Nm suika
+.Nd A drunk as hell IRC bouncer, named after Suika Ibuki from Touhou Project
+.Sh SYNOPSIS
+.Nm
+.Op Fl config Ar path
+.Op Fl debug
+.Op Fl listen Ar uri
+.Sh DESCRIPTION
+.Nm
+is an user-friendly IRC bouncer, it connects to upstream
+IRC servers on behalf of the user to provide extra features.
+.Bl -tag -width 6n
+.It Multiple separate users sharing the same bouncer
+.It Clients connecting to multiple upstream servers (via a single connection)
+.It Sending the backlog with per-client buffers
+.El
+.Pp
+When joining a channel, the channel will be saved
+and automatically joined on the next connection.
+When registering or authenticating with NickServ, the credentials will be saved
+and automatically used on the next connection if the server supports SASL.
+When parting a channel with the reason "detach", the channel will be
+detached instead of being left.
+When all clients are disconnected from the bouncer,
+the user is automatically marked as away.
+.Pp
+.Nm
+supports two connection modes:
+.Bl -tag -width 6n
+.It Single upstream mode
+One downstream connection maps to one upstream connection
+.Pp
+To enable this mode, connect to the bouncer
+with the username "<username>/<network>".
+.Pp
+If the bouncer isn't connected to the upstream server,
+it will get automatically added.
+.Pp
+Then channels can be joined and parted as if
+you were directly connected to the upstream server.
+.It Multiple upstream mode
+One downstream connection maps to multiple upstream connections.
+Channels and nicks are suffixed with the network name.
+To join a channel, you need to use the suffix too: /join #channel/network.
+Same applies to messages sent to users.
+.El
+.Pp
+For per-client history to work, clients need to indicate their name.
+This can be done by adding a "@<client>" suffix to the username.
+.Pp
+.Nm
+will reload the configuration file, the TLS certificate/key and
+the MOTD file when it receives the HUP signal.
+The configuration options listen, db and log cannot be reloaded.
+.Pp
+Administrators can broadcast a message to all bouncer users via
+/notice $<hostname> <text>, or via /notice $<text> in multi-upstream mode.
+All currently connected bouncer users will receive the message
+from the special BouncerServ service.
+.Sh AUTHORS
+.An Simon Ser Aq Mt contact@emersion.fr
+.An The soju Contributors
+.Sh MAINTAINERS
+.An Izuru Yakumo Aq Mt yakumo.izuru@chaotic.ninja
--- /dev/null
+.Dd $Mdocdate$
+.Dt SUIKADB 1
+.Os
+.Sh NAME
+.Nm suikadb
+.Nd Basic user manipulation for
+.Xr suika 1
+.Sh SYNOPSIS
+.Nm
+.Op create-user
+.Op change-password
+.Sh AUTHORS
+.An Simon Ser Aq Mt contact@emersion.fr
+.An The soju Contributors
+.Sh MAINTAINERS
+.An Izuru Yakumo Aq Mt yakumo.izuru@chaotic.ninja
--- /dev/null
+package suika
+
+import (
+ "bytes"
+ "context"
+ "crypto/tls"
+ "encoding/base64"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/emersion/go-sasl"
+ "golang.org/x/crypto/bcrypt"
+ "gopkg.in/irc.v3"
+)
+
+type ircError struct {
+ Message *irc.Message
+}
+
+func (err ircError) Error() string {
+ return err.Message.String()
+}
+
+func newUnknownCommandError(cmd string) ircError {
+ return ircError{&irc.Message{
+ Command: irc.ERR_UNKNOWNCOMMAND,
+ Params: []string{
+ "*",
+ cmd,
+ "Unknown command",
+ },
+ }}
+}
+
+func newNeedMoreParamsError(cmd string) ircError {
+ return ircError{&irc.Message{
+ Command: irc.ERR_NEEDMOREPARAMS,
+ Params: []string{
+ "*",
+ cmd,
+ "Not enough parameters",
+ },
+ }}
+}
+
+func newChatHistoryError(subcommand string, target string) ircError {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"CHATHISTORY", "MESSAGE_ERROR", subcommand, target, "Messages could not be retrieved"},
+ }}
+}
+
+// authError is an authentication error.
+type authError struct {
+ // Internal error cause. This will not be revealed to the user.
+ err error
+ // Error cause which can safely be sent to the user without compromising
+ // security.
+ reason string
+}
+
+func (err *authError) Error() string {
+ return err.err.Error()
+}
+
+func (err *authError) Unwrap() error {
+ return err.err
+}
+
+// authErrorReason returns the user-friendly reason of an authentication
+// failure.
+func authErrorReason(err error) string {
+ if authErr, ok := err.(*authError); ok {
+ return authErr.reason
+ } else {
+ return "Authentication failed"
+ }
+}
+
+func newInvalidUsernameOrPasswordError(err error) error {
+ return &authError{
+ err: err,
+ reason: "Invalid username or password",
+ }
+}
+
+func parseBouncerNetID(subcommand, s string) (int64, error) {
+ id, err := strconv.ParseInt(s, 10, 64)
+ if err != nil {
+ return 0, ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"BOUNCER", "INVALID_NETID", subcommand, s, "Invalid network ID"},
+ }}
+ }
+ return id, nil
+}
+
+func fillNetworkAddrAttrs(attrs irc.Tags, network *Network) {
+ u, err := network.URL()
+ if err != nil {
+ return
+ }
+
+ hasHostPort := true
+ switch u.Scheme {
+ case "ircs":
+ attrs["tls"] = irc.TagValue("1")
+ case "irc":
+ attrs["tls"] = irc.TagValue("0")
+ default: // e.g. unix://
+ hasHostPort = false
+ }
+ if host, port, err := net.SplitHostPort(u.Host); err == nil && hasHostPort {
+ attrs["host"] = irc.TagValue(host)
+ attrs["port"] = irc.TagValue(port)
+ } else if hasHostPort {
+ attrs["host"] = irc.TagValue(u.Host)
+ }
+}
+
+func getNetworkAttrs(network *network) irc.Tags {
+ state := "disconnected"
+ if uc := network.conn; uc != nil {
+ state = "connected"
+ }
+
+ attrs := irc.Tags{
+ "name": irc.TagValue(network.GetName()),
+ "state": irc.TagValue(state),
+ "nickname": irc.TagValue(GetNick(&network.user.User, &network.Network)),
+ }
+
+ if network.Username != "" {
+ attrs["username"] = irc.TagValue(network.Username)
+ }
+ if realname := GetRealname(&network.user.User, &network.Network); realname != "" {
+ attrs["realname"] = irc.TagValue(realname)
+ }
+
+ fillNetworkAddrAttrs(attrs, &network.Network)
+
+ return attrs
+}
+
+func networkAddrFromAttrs(attrs irc.Tags) string {
+ host, ok := attrs.GetTag("host")
+ if !ok {
+ return ""
+ }
+
+ addr := host
+ if port, ok := attrs.GetTag("port"); ok {
+ addr += ":" + port
+ }
+
+ if tlsStr, ok := attrs.GetTag("tls"); ok && tlsStr == "0" {
+ addr = "irc://" + tlsStr
+ }
+
+ return addr
+}
+
+func updateNetworkAttrs(record *Network, attrs irc.Tags, subcommand string) error {
+ addrAttrs := irc.Tags{}
+ fillNetworkAddrAttrs(addrAttrs, record)
+
+ updateAddr := false
+ for k, v := range attrs {
+ s := string(v)
+ switch k {
+ case "host", "port", "tls":
+ updateAddr = true
+ addrAttrs[k] = v
+ case "name":
+ record.Name = s
+ case "nickname":
+ record.Nick = s
+ case "username":
+ record.Username = s
+ case "realname":
+ record.Realname = s
+ case "pass":
+ record.Pass = s
+ default:
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"BOUNCER", "UNKNOWN_ATTRIBUTE", subcommand, k, "Unknown attribute"},
+ }}
+ }
+ }
+
+ if updateAddr {
+ record.Addr = networkAddrFromAttrs(addrAttrs)
+ if record.Addr == "" {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"BOUNCER", "NEED_ATTRIBUTE", subcommand, "host", "Missing required host attribute"},
+ }}
+ }
+ }
+
+ return nil
+}
+
+// illegalNickChars is the list of characters forbidden in a nickname.
+//
+// ' ' and ':' break the IRC message wire format
+// '@' and '!' break prefixes
+// '*' breaks masks and is the reserved nickname for registration
+// '?' breaks masks
+// '$' breaks server masks in PRIVMSG/NOTICE
+// ',' breaks lists
+// '.' is reserved for server names
+const illegalNickChars = " :@!*?$,."
+
+// permanentDownstreamCaps is the list of always-supported downstream
+// capabilities.
+var permanentDownstreamCaps = map[string]string{
+ "batch": "",
+ "cap-notify": "",
+ "echo-message": "",
+ "invite-notify": "",
+ "message-tags": "",
+ "server-time": "",
+ "setname": "",
+
+ "soju.im/bouncer-networks": "",
+ "soju.im/bouncer-networks-notify": "",
+ "soju.im/read": "",
+}
+
+// needAllDownstreamCaps is the list of downstream capabilities that
+// require support from all upstreams to be enabled
+var needAllDownstreamCaps = map[string]string{
+ "account-notify": "",
+ "account-tag": "",
+ "away-notify": "",
+ "extended-join": "",
+ "multi-prefix": "",
+
+ "draft/extended-monitor": "",
+}
+
+// passthroughIsupport is the set of ISUPPORT tokens that are directly passed
+// through from the upstream server to downstream clients.
+//
+// This is only effective in single-upstream mode.
+var passthroughIsupport = map[string]bool{
+ "AWAYLEN": true,
+ "BOT": true,
+ "CHANLIMIT": true,
+ "CHANMODES": true,
+ "CHANNELLEN": true,
+ "CHANTYPES": true,
+ "CLIENTTAGDENY": true,
+ "ELIST": true,
+ "EXCEPTS": true,
+ "EXTBAN": true,
+ "HOSTLEN": true,
+ "INVEX": true,
+ "KICKLEN": true,
+ "MAXLIST": true,
+ "MAXTARGETS": true,
+ "MODES": true,
+ "MONITOR": true,
+ "NAMELEN": true,
+ "NETWORK": true,
+ "NICKLEN": true,
+ "PREFIX": true,
+ "SAFELIST": true,
+ "TARGMAX": true,
+ "TOPICLEN": true,
+ "USERLEN": true,
+ "UTF8ONLY": true,
+ "WHOX": true,
+}
+
+type downstreamSASL struct {
+ server sasl.Server
+ plainUsername, plainPassword string
+ pendingResp bytes.Buffer
+}
+
+type downstreamConn struct {
+ conn
+
+ id uint64
+
+ registered bool
+ user *user
+ nick string
+ nickCM string
+ rawUsername string
+ networkName string
+ clientName string
+ realname string
+ hostname string
+ account string // RPL_LOGGEDIN/OUT state
+ password string // empty after authentication
+ network *network // can be nil
+ isMultiUpstream bool
+
+ negotiatingCaps bool
+ capVersion int
+ supportedCaps map[string]string
+ caps map[string]bool
+ sasl *downstreamSASL
+
+ lastBatchRef uint64
+
+ monitored casemapMap
+}
+
+func newDownstreamConn(srv *Server, ic ircConn, id uint64) *downstreamConn {
+ remoteAddr := ic.RemoteAddr().String()
+ logger := &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", remoteAddr)}
+ options := connOptions{Logger: logger}
+ dc := &downstreamConn{
+ conn: *newConn(srv, ic, &options),
+ id: id,
+ nick: "*",
+ nickCM: "*",
+ supportedCaps: make(map[string]string),
+ caps: make(map[string]bool),
+ monitored: newCasemapMap(0),
+ }
+ dc.hostname = remoteAddr
+ if host, _, err := net.SplitHostPort(dc.hostname); err == nil {
+ dc.hostname = host
+ }
+ for k, v := range permanentDownstreamCaps {
+ dc.supportedCaps[k] = v
+ }
+ dc.supportedCaps["sasl"] = "PLAIN"
+ // TODO: this is racy, we should only enable chathistory after
+ // authentication and then check that user.msgStore implements
+ // chatHistoryMessageStore
+ if srv.Config().LogPath != "" {
+ dc.supportedCaps["draft/chathistory"] = ""
+ }
+ return dc
+}
+
+func (dc *downstreamConn) prefix() *irc.Prefix {
+ return &irc.Prefix{
+ Name: dc.nick,
+ User: dc.user.Username,
+ Host: dc.hostname,
+ }
+}
+
+func (dc *downstreamConn) forEachNetwork(f func(*network)) {
+ if dc.network != nil {
+ f(dc.network)
+ } else if dc.isMultiUpstream {
+ for _, network := range dc.user.networks {
+ f(network)
+ }
+ }
+}
+
+func (dc *downstreamConn) forEachUpstream(f func(*upstreamConn)) {
+ if dc.network == nil && !dc.isMultiUpstream {
+ return
+ }
+ dc.user.forEachUpstream(func(uc *upstreamConn) {
+ if dc.network != nil && uc.network != dc.network {
+ return
+ }
+ f(uc)
+ })
+}
+
+// upstream returns the upstream connection, if any. If there are zero or if
+// there are multiple upstream connections, it returns nil.
+func (dc *downstreamConn) upstream() *upstreamConn {
+ if dc.network == nil {
+ return nil
+ }
+ return dc.network.conn
+}
+
+func isOurNick(net *network, nick string) bool {
+ // TODO: this doesn't account for nick changes
+ if net.conn != nil {
+ return net.casemap(nick) == net.conn.nickCM
+ }
+ // We're not currently connected to the upstream connection, so we don't
+ // know whether this name is our nickname. Best-effort: use the network's
+ // configured nickname and hope it was the one being used when we were
+ // connected.
+ return net.casemap(nick) == net.casemap(GetNick(&net.user.User, &net.Network))
+}
+
+// marshalEntity converts an upstream entity name (ie. channel or nick) into a
+// downstream entity name.
+//
+// This involves adding a "/<network>" suffix if the entity isn't the current
+// user.
+func (dc *downstreamConn) marshalEntity(net *network, name string) string {
+ if isOurNick(net, name) {
+ return dc.nick
+ }
+ name = partialCasemap(net.casemap, name)
+ if dc.network != nil {
+ if dc.network != net {
+ panic("suika: tried to marshal an entity for another network")
+ }
+ return name
+ }
+ return name + "/" + net.GetName()
+}
+
+func (dc *downstreamConn) marshalUserPrefix(net *network, prefix *irc.Prefix) *irc.Prefix {
+ if isOurNick(net, prefix.Name) {
+ return dc.prefix()
+ }
+ prefix.Name = partialCasemap(net.casemap, prefix.Name)
+ if dc.network != nil {
+ if dc.network != net {
+ panic("suika: tried to marshal a user prefix for another network")
+ }
+ return prefix
+ }
+ return &irc.Prefix{
+ Name: prefix.Name + "/" + net.GetName(),
+ User: prefix.User,
+ Host: prefix.Host,
+ }
+}
+
+// unmarshalEntityNetwork converts a downstream entity name (ie. channel or
+// nick) into an upstream entity name.
+//
+// This involves removing the "/<network>" suffix.
+func (dc *downstreamConn) unmarshalEntityNetwork(name string) (*network, string, error) {
+ if dc.network != nil {
+ return dc.network, name, nil
+ }
+ if !dc.isMultiUpstream {
+ return nil, "", ircError{&irc.Message{
+ Command: irc.ERR_NOSUCHCHANNEL,
+ Params: []string{dc.nick, name, "Cannot interact with channels and users on the bouncer connection. Did you mean to use a specific network?"},
+ }}
+ }
+
+ var net *network
+ if i := strings.LastIndexByte(name, '/'); i >= 0 {
+ network := name[i+1:]
+ name = name[:i]
+
+ for _, n := range dc.user.networks {
+ if network == n.GetName() {
+ net = n
+ break
+ }
+ }
+ }
+
+ if net == nil {
+ return nil, "", ircError{&irc.Message{
+ Command: irc.ERR_NOSUCHCHANNEL,
+ Params: []string{dc.nick, name, "Missing network suffix in name"},
+ }}
+ }
+
+ return net, name, nil
+}
+
+// unmarshalEntity is the same as unmarshalEntityNetwork, but returns the
+// upstream connection and fails if the upstream is disconnected.
+func (dc *downstreamConn) unmarshalEntity(name string) (*upstreamConn, string, error) {
+ net, name, err := dc.unmarshalEntityNetwork(name)
+ if err != nil {
+ return nil, "", err
+ }
+
+ if net.conn == nil {
+ return nil, "", ircError{&irc.Message{
+ Command: irc.ERR_NOSUCHCHANNEL,
+ Params: []string{dc.nick, name, "Disconnected from upstream network"},
+ }}
+ }
+
+ return net.conn, name, nil
+}
+
+func (dc *downstreamConn) unmarshalText(uc *upstreamConn, text string) string {
+ if dc.upstream() != nil {
+ return text
+ }
+ // TODO: smarter parsing that ignores URLs
+ return strings.ReplaceAll(text, "/"+uc.network.GetName(), "")
+}
+
+func (dc *downstreamConn) ReadMessage() (*irc.Message, error) {
+ msg, err := dc.conn.ReadMessage()
+ if err != nil {
+ return nil, err
+ }
+ return msg, nil
+}
+
+func (dc *downstreamConn) readMessages(ch chan<- event) error {
+ for {
+ msg, err := dc.ReadMessage()
+ if errors.Is(err, io.EOF) {
+ break
+ } else if err != nil {
+ return fmt.Errorf("failed to read IRC command: %v", err)
+ }
+
+ ch <- eventDownstreamMessage{msg, dc}
+ }
+
+ return nil
+}
+
+// SendMessage sends an outgoing message.
+//
+// This can only called from the user goroutine.
+func (dc *downstreamConn) SendMessage(msg *irc.Message) {
+ if !dc.caps["message-tags"] {
+ if msg.Command == "TAGMSG" {
+ return
+ }
+ msg = msg.Copy()
+ for name := range msg.Tags {
+ supported := false
+ switch name {
+ case "time":
+ supported = dc.caps["server-time"]
+ case "account":
+ supported = dc.caps["account"]
+ }
+ if !supported {
+ delete(msg.Tags, name)
+ }
+ }
+ }
+ if !dc.caps["batch"] && msg.Tags["batch"] != "" {
+ msg = msg.Copy()
+ delete(msg.Tags, "batch")
+ }
+ if msg.Command == "JOIN" && !dc.caps["extended-join"] {
+ msg.Params = msg.Params[:1]
+ }
+ if msg.Command == "SETNAME" && !dc.caps["setname"] {
+ return
+ }
+ if msg.Command == "AWAY" && !dc.caps["away-notify"] {
+ return
+ }
+ if msg.Command == "ACCOUNT" && !dc.caps["account-notify"] {
+ return
+ }
+ if msg.Command == "READ" && !dc.caps["soju.im/read"] {
+ return
+ }
+
+ dc.conn.SendMessage(context.TODO(), msg)
+}
+
+func (dc *downstreamConn) SendBatch(typ string, params []string, tags irc.Tags, f func(batchRef irc.TagValue)) {
+ dc.lastBatchRef++
+ ref := fmt.Sprintf("%v", dc.lastBatchRef)
+
+ if dc.caps["batch"] {
+ dc.SendMessage(&irc.Message{
+ Tags: tags,
+ Prefix: dc.srv.prefix(),
+ Command: "BATCH",
+ Params: append([]string{"+" + ref, typ}, params...),
+ })
+ }
+
+ f(irc.TagValue(ref))
+
+ if dc.caps["batch"] {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "BATCH",
+ Params: []string{"-" + ref},
+ })
+ }
+}
+
+// sendMessageWithID sends an outgoing message with the specified internal ID.
+func (dc *downstreamConn) sendMessageWithID(msg *irc.Message, id string) {
+ dc.SendMessage(msg)
+
+ if id == "" || !dc.messageSupportsBacklog(msg) {
+ return
+ }
+
+ dc.sendPing(id)
+}
+
+// advanceMessageWithID advances history to the specified message ID without
+// sending a message. This is useful e.g. for self-messages when echo-message
+// isn't enabled.
+func (dc *downstreamConn) advanceMessageWithID(msg *irc.Message, id string) {
+ if id == "" || !dc.messageSupportsBacklog(msg) {
+ return
+ }
+
+ dc.sendPing(id)
+}
+
+// ackMsgID acknowledges that a message has been received.
+func (dc *downstreamConn) ackMsgID(id string) {
+ netID, entity, err := parseMsgID(id, nil)
+ if err != nil {
+ dc.logger.Printf("failed to ACK message ID %q: %v", id, err)
+ return
+ }
+
+ network := dc.user.getNetworkByID(netID)
+ if network == nil {
+ return
+ }
+
+ network.delivered.StoreID(entity, dc.clientName, id)
+}
+
+func (dc *downstreamConn) sendPing(msgID string) {
+ token := "suika-msgid-" + msgID
+ dc.SendMessage(&irc.Message{
+ Command: "PING",
+ Params: []string{token},
+ })
+}
+
+func (dc *downstreamConn) handlePong(token string) {
+ if !strings.HasPrefix(token, "suika-msgid-") {
+ dc.logger.Printf("received unrecognized PONG token %q", token)
+ return
+ }
+ msgID := strings.TrimPrefix(token, "suika-msgid-")
+ dc.ackMsgID(msgID)
+}
+
+// marshalMessage re-formats a message coming from an upstream connection so
+// that it's suitable for being sent on this downstream connection. Only
+// messages that may appear in logs are supported, except MODE messages which
+// may only appear in single-upstream mode.
+func (dc *downstreamConn) marshalMessage(msg *irc.Message, net *network) *irc.Message {
+ msg = msg.Copy()
+ msg.Prefix = dc.marshalUserPrefix(net, msg.Prefix)
+
+ if dc.network != nil {
+ return msg
+ }
+
+ switch msg.Command {
+ case "PRIVMSG", "NOTICE", "TAGMSG":
+ msg.Params[0] = dc.marshalEntity(net, msg.Params[0])
+ case "NICK":
+ // Nick change for another user
+ msg.Params[0] = dc.marshalEntity(net, msg.Params[0])
+ case "JOIN", "PART":
+ msg.Params[0] = dc.marshalEntity(net, msg.Params[0])
+ case "KICK":
+ msg.Params[0] = dc.marshalEntity(net, msg.Params[0])
+ msg.Params[1] = dc.marshalEntity(net, msg.Params[1])
+ case "TOPIC":
+ msg.Params[0] = dc.marshalEntity(net, msg.Params[0])
+ case "QUIT", "SETNAME":
+ // This space is intentionally left blank
+ default:
+ panic(fmt.Sprintf("unexpected %q message", msg.Command))
+ }
+
+ return msg
+}
+
+func (dc *downstreamConn) handleMessage(ctx context.Context, msg *irc.Message) error {
+ ctx, cancel := dc.conn.NewContext(ctx)
+ defer cancel()
+
+ ctx, cancel = context.WithTimeout(ctx, handleDownstreamMessageTimeout)
+ defer cancel()
+
+ switch msg.Command {
+ case "QUIT":
+ return dc.Close()
+ default:
+ if dc.registered {
+ return dc.handleMessageRegistered(ctx, msg)
+ } else {
+ return dc.handleMessageUnregistered(ctx, msg)
+ }
+ }
+}
+
+func (dc *downstreamConn) handleMessageUnregistered(ctx context.Context, msg *irc.Message) error {
+ switch msg.Command {
+ case "NICK":
+ var nick string
+ if err := parseMessageParams(msg, &nick); err != nil {
+ return err
+ }
+ if nick == "" || strings.ContainsAny(nick, illegalNickChars) {
+ return ircError{&irc.Message{
+ Command: irc.ERR_ERRONEUSNICKNAME,
+ Params: []string{dc.nick, nick, "contains illegal characters"},
+ }}
+ }
+ nickCM := casemapASCII(nick)
+ if nickCM == serviceNickCM {
+ return ircError{&irc.Message{
+ Command: irc.ERR_NICKNAMEINUSE,
+ Params: []string{dc.nick, nick, "Nickname reserved for bouncer service"},
+ }}
+ }
+ dc.nick = nick
+ dc.nickCM = nickCM
+ case "USER":
+ if err := parseMessageParams(msg, &dc.rawUsername, nil, nil, &dc.realname); err != nil {
+ return err
+ }
+ case "PASS":
+ if err := parseMessageParams(msg, &dc.password); err != nil {
+ return err
+ }
+ case "CAP":
+ var subCmd string
+ if err := parseMessageParams(msg, &subCmd); err != nil {
+ return err
+ }
+ if err := dc.handleCapCommand(subCmd, msg.Params[1:]); err != nil {
+ return err
+ }
+ case "AUTHENTICATE":
+ credentials, err := dc.handleAuthenticateCommand(msg)
+ if err != nil {
+ return err
+ } else if credentials == nil {
+ break
+ }
+
+ if err := dc.authenticate(ctx, credentials.plainUsername, credentials.plainPassword); err != nil {
+ dc.logger.Printf("SASL authentication error for user %q: %v", credentials.plainUsername, err)
+ dc.endSASL(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_SASLFAIL,
+ Params: []string{dc.nick, authErrorReason(err)},
+ })
+ break
+ }
+
+ // Technically we should send RPL_LOGGEDIN here. However we use
+ // RPL_LOGGEDIN to mirror the upstream connection status. Let's
+ // see how many clients that breaks. See:
+ // https://github.com/ircv3/ircv3-specifications/pull/476
+ dc.endSASL(nil)
+ case "BOUNCER":
+ var subcommand string
+ if err := parseMessageParams(msg, &subcommand); err != nil {
+ return err
+ }
+
+ switch strings.ToUpper(subcommand) {
+ case "BIND":
+ var idStr string
+ if err := parseMessageParams(msg, nil, &idStr); err != nil {
+ return err
+ }
+
+ if dc.user == nil {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"BOUNCER", "ACCOUNT_REQUIRED", "BIND", "Authentication needed to bind to bouncer network"},
+ }}
+ }
+
+ id, err := parseBouncerNetID(subcommand, idStr)
+ if err != nil {
+ return err
+ }
+
+ var match *network
+ for _, net := range dc.user.networks {
+ if net.ID == id {
+ match = net
+ break
+ }
+ }
+ if match == nil {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"BOUNCER", "INVALID_NETID", idStr, "Unknown network ID"},
+ }}
+ }
+
+ dc.networkName = match.GetName()
+ }
+ default:
+ dc.logger.Printf("unhandled message: %v", msg)
+ return newUnknownCommandError(msg.Command)
+ }
+ if dc.rawUsername != "" && dc.nick != "*" && !dc.negotiatingCaps {
+ return dc.register(ctx)
+ }
+ return nil
+}
+
+func (dc *downstreamConn) handleCapCommand(cmd string, args []string) error {
+ cmd = strings.ToUpper(cmd)
+
+ switch cmd {
+ case "LS":
+ if len(args) > 0 {
+ var err error
+ if dc.capVersion, err = strconv.Atoi(args[0]); err != nil {
+ return err
+ }
+ }
+ if !dc.registered && dc.capVersion >= 302 {
+ // Let downstream show everything it supports, and trim
+ // down the available capabilities when upstreams are
+ // known.
+ for k, v := range needAllDownstreamCaps {
+ dc.supportedCaps[k] = v
+ }
+ }
+
+ caps := make([]string, 0, len(dc.supportedCaps))
+ for k, v := range dc.supportedCaps {
+ if dc.capVersion >= 302 && v != "" {
+ caps = append(caps, k+"="+v)
+ } else {
+ caps = append(caps, k)
+ }
+ }
+
+ // TODO: multi-line replies
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "CAP",
+ Params: []string{dc.nick, "LS", strings.Join(caps, " ")},
+ })
+
+ if dc.capVersion >= 302 {
+ // CAP version 302 implicitly enables cap-notify
+ dc.caps["cap-notify"] = true
+ }
+
+ if !dc.registered {
+ dc.negotiatingCaps = true
+ }
+ case "LIST":
+ var caps []string
+ for name, enabled := range dc.caps {
+ if enabled {
+ caps = append(caps, name)
+ }
+ }
+
+ // TODO: multi-line replies
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "CAP",
+ Params: []string{dc.nick, "LIST", strings.Join(caps, " ")},
+ })
+ case "REQ":
+ if len(args) == 0 {
+ return ircError{&irc.Message{
+ Command: err_invalidcapcmd,
+ Params: []string{dc.nick, cmd, "Missing argument in CAP REQ command"},
+ }}
+ }
+
+ // TODO: atomically ack/nak the whole capability set
+ caps := strings.Fields(args[0])
+ ack := true
+ for _, name := range caps {
+ name = strings.ToLower(name)
+ enable := !strings.HasPrefix(name, "-")
+ if !enable {
+ name = strings.TrimPrefix(name, "-")
+ }
+
+ if enable == dc.caps[name] {
+ continue
+ }
+
+ _, ok := dc.supportedCaps[name]
+ if !ok {
+ ack = false
+ break
+ }
+
+ if name == "cap-notify" && dc.capVersion >= 302 && !enable {
+ // cap-notify cannot be disabled with CAP version 302
+ ack = false
+ break
+ }
+
+ dc.caps[name] = enable
+ }
+
+ reply := "NAK"
+ if ack {
+ reply = "ACK"
+ }
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "CAP",
+ Params: []string{dc.nick, reply, args[0]},
+ })
+
+ if !dc.registered {
+ dc.negotiatingCaps = true
+ }
+ case "END":
+ dc.negotiatingCaps = false
+ default:
+ return ircError{&irc.Message{
+ Command: err_invalidcapcmd,
+ Params: []string{dc.nick, cmd, "Unknown CAP command"},
+ }}
+ }
+ return nil
+}
+
+func (dc *downstreamConn) handleAuthenticateCommand(msg *irc.Message) (result *downstreamSASL, err error) {
+ defer func() {
+ if err != nil {
+ dc.sasl = nil
+ }
+ }()
+
+ if !dc.caps["sasl"] {
+ return nil, ircError{&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_SASLFAIL,
+ Params: []string{dc.nick, "AUTHENTICATE requires the \"sasl\" capability to be enabled"},
+ }}
+ }
+ if len(msg.Params) == 0 {
+ return nil, ircError{&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_SASLFAIL,
+ Params: []string{dc.nick, "Missing AUTHENTICATE argument"},
+ }}
+ }
+ if msg.Params[0] == "*" {
+ return nil, ircError{&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_SASLABORTED,
+ Params: []string{dc.nick, "SASL authentication aborted"},
+ }}
+ }
+
+ var resp []byte
+ if dc.sasl == nil {
+ mech := strings.ToUpper(msg.Params[0])
+ var server sasl.Server
+ switch mech {
+ case "PLAIN":
+ server = sasl.NewPlainServer(sasl.PlainAuthenticator(func(identity, username, password string) error {
+ dc.sasl.plainUsername = username
+ dc.sasl.plainPassword = password
+ return nil
+ }))
+ default:
+ return nil, ircError{&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_SASLFAIL,
+ Params: []string{dc.nick, fmt.Sprintf("Unsupported SASL mechanism %q", mech)},
+ }}
+ }
+
+ dc.sasl = &downstreamSASL{server: server}
+ } else {
+ chunk := msg.Params[0]
+ if chunk == "+" {
+ chunk = ""
+ }
+
+ if dc.sasl.pendingResp.Len()+len(chunk) > 10*1024 {
+ return nil, ircError{&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_SASLFAIL,
+ Params: []string{dc.nick, "Response too long"},
+ }}
+ }
+
+ dc.sasl.pendingResp.WriteString(chunk)
+
+ if len(chunk) == maxSASLLength {
+ return nil, nil // Multi-line response, wait for the next command
+ }
+
+ resp, err = base64.StdEncoding.DecodeString(dc.sasl.pendingResp.String())
+ if err != nil {
+ return nil, ircError{&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_SASLFAIL,
+ Params: []string{dc.nick, "Invalid base64-encoded response"},
+ }}
+ }
+
+ dc.sasl.pendingResp.Reset()
+ }
+
+ challenge, done, err := dc.sasl.server.Next(resp)
+ if err != nil {
+ return nil, err
+ } else if done {
+ return dc.sasl, nil
+ } else {
+ challengeStr := "+"
+ if len(challenge) > 0 {
+ challengeStr = base64.StdEncoding.EncodeToString(challenge)
+ }
+
+ // TODO: multi-line messages
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "AUTHENTICATE",
+ Params: []string{challengeStr},
+ })
+ return nil, nil
+ }
+}
+
+func (dc *downstreamConn) endSASL(msg *irc.Message) {
+ if dc.sasl == nil {
+ return
+ }
+
+ dc.sasl = nil
+
+ if msg != nil {
+ dc.SendMessage(msg)
+ } else {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_SASLSUCCESS,
+ Params: []string{dc.nick, "SASL authentication successful"},
+ })
+ }
+}
+
+func (dc *downstreamConn) setSupportedCap(name, value string) {
+ prevValue, hasPrev := dc.supportedCaps[name]
+ changed := !hasPrev || prevValue != value
+ dc.supportedCaps[name] = value
+
+ if !dc.caps["cap-notify"] || !changed {
+ return
+ }
+
+ cap := name
+ if value != "" && dc.capVersion >= 302 {
+ cap = name + "=" + value
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "CAP",
+ Params: []string{dc.nick, "NEW", cap},
+ })
+}
+
+func (dc *downstreamConn) unsetSupportedCap(name string) {
+ _, hasPrev := dc.supportedCaps[name]
+ delete(dc.supportedCaps, name)
+ delete(dc.caps, name)
+
+ if !dc.caps["cap-notify"] || !hasPrev {
+ return
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "CAP",
+ Params: []string{dc.nick, "DEL", name},
+ })
+}
+
+func (dc *downstreamConn) updateSupportedCaps() {
+ supportedCaps := make(map[string]bool)
+ for cap := range needAllDownstreamCaps {
+ supportedCaps[cap] = true
+ }
+ dc.forEachUpstream(func(uc *upstreamConn) {
+ for cap, supported := range supportedCaps {
+ supportedCaps[cap] = supported && uc.caps[cap]
+ }
+ })
+
+ for cap, supported := range supportedCaps {
+ if supported {
+ dc.setSupportedCap(cap, needAllDownstreamCaps[cap])
+ } else {
+ dc.unsetSupportedCap(cap)
+ }
+ }
+
+ if uc := dc.upstream(); uc != nil && uc.supportsSASL("PLAIN") {
+ dc.setSupportedCap("sasl", "PLAIN")
+ } else if dc.network != nil {
+ dc.unsetSupportedCap("sasl")
+ }
+
+ if uc := dc.upstream(); uc != nil && uc.caps["draft/account-registration"] {
+ // Strip "before-connect", because we require downstreams to be fully
+ // connected before attempting account registration.
+ values := strings.Split(uc.supportedCaps["draft/account-registration"], ",")
+ for i, v := range values {
+ if v == "before-connect" {
+ values = append(values[:i], values[i+1:]...)
+ break
+ }
+ }
+ dc.setSupportedCap("draft/account-registration", strings.Join(values, ","))
+ } else {
+ dc.unsetSupportedCap("draft/account-registration")
+ }
+
+ if _, ok := dc.user.msgStore.(chatHistoryMessageStore); ok && dc.network != nil {
+ dc.setSupportedCap("draft/event-playback", "")
+ } else {
+ dc.unsetSupportedCap("draft/event-playback")
+ }
+}
+
+func (dc *downstreamConn) updateNick() {
+ if uc := dc.upstream(); uc != nil && uc.nick != dc.nick {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.prefix(),
+ Command: "NICK",
+ Params: []string{uc.nick},
+ })
+ dc.nick = uc.nick
+ dc.nickCM = casemapASCII(dc.nick)
+ }
+}
+
+func (dc *downstreamConn) updateRealname() {
+ if uc := dc.upstream(); uc != nil && uc.realname != dc.realname && dc.caps["setname"] {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.prefix(),
+ Command: "SETNAME",
+ Params: []string{uc.realname},
+ })
+ dc.realname = uc.realname
+ }
+}
+
+func (dc *downstreamConn) updateAccount() {
+ var account string
+ if dc.network == nil {
+ account = dc.user.Username
+ } else if uc := dc.upstream(); uc != nil {
+ account = uc.account
+ } else {
+ return
+ }
+
+ if dc.account == account || !dc.caps["sasl"] {
+ return
+ }
+
+ if account != "" {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_LOGGEDIN,
+ Params: []string{dc.nick, dc.prefix().String(), account, "You are logged in as " + account},
+ })
+ } else {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_LOGGEDOUT,
+ Params: []string{dc.nick, dc.prefix().String(), "You are logged out"},
+ })
+ }
+
+ dc.account = account
+}
+
+func sanityCheckServer(ctx context.Context, addr string) error {
+ ctx, cancel := context.WithTimeout(ctx, 15*time.Second)
+ defer cancel()
+
+ conn, err := new(tls.Dialer).DialContext(ctx, "tcp", addr)
+ if err != nil {
+ return err
+ }
+
+ return conn.Close()
+}
+
+func unmarshalUsername(rawUsername string) (username, client, network string) {
+ username = rawUsername
+
+ i := strings.IndexAny(username, "/@")
+ j := strings.LastIndexAny(username, "/@")
+ if i >= 0 {
+ username = rawUsername[:i]
+ }
+ if j >= 0 {
+ if rawUsername[j] == '@' {
+ client = rawUsername[j+1:]
+ } else {
+ network = rawUsername[j+1:]
+ }
+ }
+ if i >= 0 && j >= 0 && i < j {
+ if rawUsername[i] == '@' {
+ client = rawUsername[i+1 : j]
+ } else {
+ network = rawUsername[i+1 : j]
+ }
+ }
+
+ return username, client, network
+}
+
+func (dc *downstreamConn) authenticate(ctx context.Context, username, password string) error {
+ username, clientName, networkName := unmarshalUsername(username)
+
+ u, err := dc.srv.db.GetUser(ctx, username)
+ if err != nil {
+ return newInvalidUsernameOrPasswordError(fmt.Errorf("user not found: %w", err))
+ }
+
+ // Password auth disabled
+ if u.Password == "" {
+ return newInvalidUsernameOrPasswordError(fmt.Errorf("password auth disabled"))
+ }
+
+ err = bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(password))
+ if err != nil {
+ return newInvalidUsernameOrPasswordError(fmt.Errorf("wrong password"))
+ }
+
+ dc.user = dc.srv.getUser(username)
+ if dc.user == nil {
+ return fmt.Errorf("user not active")
+ }
+ dc.clientName = clientName
+ dc.networkName = networkName
+ return nil
+}
+
+func (dc *downstreamConn) register(ctx context.Context) error {
+ if dc.registered {
+ panic("tried to register twice")
+ }
+
+ if dc.sasl != nil {
+ dc.endSASL(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_SASLABORTED,
+ Params: []string{dc.nick, "SASL authentication aborted"},
+ })
+ }
+
+ password := dc.password
+ dc.password = ""
+ if dc.user == nil {
+ if password == "" {
+ if dc.caps["sasl"] {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"*", "ACCOUNT_REQUIRED", "Authentication required"},
+ }}
+ } else {
+ return ircError{&irc.Message{
+ Command: irc.ERR_PASSWDMISMATCH,
+ Params: []string{dc.nick, "Authentication required"},
+ }}
+ }
+ }
+
+ if err := dc.authenticate(ctx, dc.rawUsername, password); err != nil {
+ dc.logger.Printf("PASS authentication error for user %q: %v", dc.rawUsername, err)
+ return ircError{&irc.Message{
+ Command: irc.ERR_PASSWDMISMATCH,
+ Params: []string{dc.nick, authErrorReason(err)},
+ }}
+ }
+ }
+
+ _, fallbackClientName, fallbackNetworkName := unmarshalUsername(dc.rawUsername)
+ if dc.clientName == "" {
+ dc.clientName = fallbackClientName
+ } else if fallbackClientName != "" && dc.clientName != fallbackClientName {
+ return ircError{&irc.Message{
+ Command: irc.ERR_ERRONEUSNICKNAME,
+ Params: []string{dc.nick, "Client name mismatch in usernames"},
+ }}
+ }
+ if dc.networkName == "" {
+ dc.networkName = fallbackNetworkName
+ } else if fallbackNetworkName != "" && dc.networkName != fallbackNetworkName {
+ return ircError{&irc.Message{
+ Command: irc.ERR_ERRONEUSNICKNAME,
+ Params: []string{dc.nick, "Network name mismatch in usernames"},
+ }}
+ }
+
+ dc.registered = true
+ dc.logger.Printf("registration complete for user %q", dc.user.Username)
+ return nil
+}
+
+func (dc *downstreamConn) loadNetwork(ctx context.Context) error {
+ if dc.networkName == "" {
+ return nil
+ }
+
+ network := dc.user.getNetwork(dc.networkName)
+ if network == nil {
+ addr := dc.networkName
+ if !strings.ContainsRune(addr, ':') {
+ addr = addr + ":6697"
+ }
+
+ dc.logger.Printf("trying to connect to new network %q", addr)
+ if err := sanityCheckServer(ctx, addr); err != nil {
+ dc.logger.Printf("failed to connect to %q: %v", addr, err)
+ return ircError{&irc.Message{
+ Command: irc.ERR_PASSWDMISMATCH,
+ Params: []string{dc.nick, fmt.Sprintf("Failed to connect to %q", dc.networkName)},
+ }}
+ }
+
+ // Some clients only allow specifying the nickname (and use the
+ // nickname as a username too). Strip the network name from the
+ // nickname when auto-saving networks.
+ nick, _, _ := unmarshalUsername(dc.nick)
+
+ dc.logger.Printf("auto-saving network %q", dc.networkName)
+ var err error
+ network, err = dc.user.createNetwork(ctx, &Network{
+ Addr: dc.networkName,
+ Nick: nick,
+ Enabled: true,
+ })
+ if err != nil {
+ return err
+ }
+ }
+
+ dc.network = network
+ return nil
+}
+
+func (dc *downstreamConn) welcome(ctx context.Context) error {
+ if dc.user == nil || !dc.registered {
+ panic("tried to welcome an unregistered connection")
+ }
+
+ remoteAddr := dc.conn.RemoteAddr().String()
+ dc.logger = &prefixLogger{dc.srv.Logger, fmt.Sprintf("user %q: downstream %q: ", dc.user.Username, remoteAddr)}
+
+ // TODO: doing this might take some time. We should do it in dc.register
+ // instead, but we'll potentially be adding a new network and this must be
+ // done in the user goroutine.
+ if err := dc.loadNetwork(ctx); err != nil {
+ return err
+ }
+
+ if dc.network == nil && !dc.caps["soju.im/bouncer-networks"] && dc.srv.Config().MultiUpstream {
+ dc.isMultiUpstream = true
+ }
+
+ dc.updateSupportedCaps()
+
+ isupport := []string{
+ fmt.Sprintf("CHATHISTORY=%v", chatHistoryLimit),
+ "CASEMAPPING=ascii",
+ }
+
+ if dc.network != nil {
+ isupport = append(isupport, fmt.Sprintf("BOUNCER_NETID=%v", dc.network.ID))
+ }
+ if title := dc.srv.Config().Title; dc.network == nil && title != "" {
+ isupport = append(isupport, "NETWORK="+encodeISUPPORT(title))
+ }
+ if dc.network == nil && !dc.isMultiUpstream {
+ isupport = append(isupport, "WHOX")
+ }
+
+ if uc := dc.upstream(); uc != nil {
+ for k := range passthroughIsupport {
+ v, ok := uc.isupport[k]
+ if !ok {
+ continue
+ }
+ if v != nil {
+ isupport = append(isupport, fmt.Sprintf("%v=%v", k, *v))
+ } else {
+ isupport = append(isupport, k)
+ }
+ }
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_WELCOME,
+ Params: []string{dc.nick, "Welcome to suika, " + dc.nick},
+ })
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_YOURHOST,
+ Params: []string{dc.nick, "Your host is " + dc.srv.Config().Hostname},
+ })
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_MYINFO,
+ Params: []string{dc.nick, dc.srv.Config().Hostname, "suika", "aiwroO", "OovaimnqpsrtklbeI"},
+ })
+ for _, msg := range generateIsupport(dc.srv.prefix(), dc.nick, isupport) {
+ dc.SendMessage(msg)
+ }
+ if uc := dc.upstream(); uc != nil {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_UMODEIS,
+ Params: []string{dc.nick, "+" + string(uc.modes)},
+ })
+ }
+ if dc.network == nil && !dc.isMultiUpstream && dc.user.Admin {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_UMODEIS,
+ Params: []string{dc.nick, "+o"},
+ })
+ }
+
+ dc.updateNick()
+ dc.updateRealname()
+ dc.updateAccount()
+
+ if motd := dc.user.srv.Config().MOTD; motd != "" && dc.network == nil {
+ for _, msg := range generateMOTD(dc.srv.prefix(), dc.nick, motd) {
+ dc.SendMessage(msg)
+ }
+ } else {
+ motdHint := "No MOTD"
+ if dc.network != nil {
+ motdHint = "Use /motd to read the message of the day"
+ }
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_NOMOTD,
+ Params: []string{dc.nick, motdHint},
+ })
+ }
+
+ if dc.caps["soju.im/bouncer-networks-notify"] {
+ dc.SendBatch("soju.im/bouncer-networks", nil, nil, func(batchRef irc.TagValue) {
+ for _, network := range dc.user.networks {
+ idStr := fmt.Sprintf("%v", network.ID)
+ attrs := getNetworkAttrs(network)
+ dc.SendMessage(&irc.Message{
+ Tags: irc.Tags{"batch": batchRef},
+ Prefix: dc.srv.prefix(),
+ Command: "BOUNCER",
+ Params: []string{"NETWORK", idStr, attrs.String()},
+ })
+ }
+ })
+ }
+
+ dc.forEachUpstream(func(uc *upstreamConn) {
+ for _, entry := range uc.channels.innerMap {
+ ch := entry.value.(*upstreamChannel)
+ if !ch.complete {
+ continue
+ }
+ record := uc.network.channels.Value(ch.Name)
+ if record != nil && record.Detached {
+ continue
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.prefix(),
+ Command: "JOIN",
+ Params: []string{dc.marshalEntity(ch.conn.network, ch.Name)},
+ })
+
+ forwardChannel(ctx, dc, ch)
+ }
+ })
+
+ dc.forEachNetwork(func(net *network) {
+ if dc.caps["draft/chathistory"] || dc.user.msgStore == nil {
+ return
+ }
+
+ // Only send history if we're the first connected client with that name
+ // for the network
+ firstClient := true
+ dc.user.forEachDownstream(func(c *downstreamConn) {
+ if c != dc && c.clientName == dc.clientName && c.network == dc.network {
+ firstClient = false
+ }
+ })
+ if firstClient {
+ net.delivered.ForEachTarget(func(target string) {
+ lastDelivered := net.delivered.LoadID(target, dc.clientName)
+ if lastDelivered == "" {
+ return
+ }
+
+ dc.sendTargetBacklog(ctx, net, target, lastDelivered)
+
+ // Fast-forward history to last message
+ targetCM := net.casemap(target)
+ lastID, err := dc.user.msgStore.LastMsgID(&net.Network, targetCM, time.Now())
+ if err != nil {
+ dc.logger.Printf("failed to get last message ID: %v", err)
+ return
+ }
+ net.delivered.StoreID(target, dc.clientName, lastID)
+ })
+ }
+ })
+
+ return nil
+}
+
+// messageSupportsBacklog checks whether the provided message can be sent as
+// part of an history batch.
+func (dc *downstreamConn) messageSupportsBacklog(msg *irc.Message) bool {
+ // Don't replay all messages, because that would mess up client
+ // state. For instance we just sent the list of users, sending
+ // PART messages for one of these users would be incorrect.
+ switch msg.Command {
+ case "PRIVMSG", "NOTICE":
+ return true
+ }
+ return false
+}
+
+func (dc *downstreamConn) sendTargetBacklog(ctx context.Context, net *network, target, msgID string) {
+ if dc.caps["draft/chathistory"] || dc.user.msgStore == nil {
+ return
+ }
+
+ ch := net.channels.Value(target)
+
+ ctx, cancel := context.WithTimeout(ctx, backlogTimeout)
+ defer cancel()
+
+ targetCM := net.casemap(target)
+ history, err := dc.user.msgStore.LoadLatestID(ctx, &net.Network, targetCM, msgID, backlogLimit)
+ if err != nil {
+ dc.logger.Printf("failed to send backlog for %q: %v", target, err)
+ return
+ }
+
+ dc.SendBatch("chathistory", []string{dc.marshalEntity(net, target)}, nil, func(batchRef irc.TagValue) {
+ for _, msg := range history {
+ if ch != nil && ch.Detached {
+ if net.detachedMessageNeedsRelay(ch, msg) {
+ dc.relayDetachedMessage(net, msg)
+ }
+ } else {
+ msg.Tags["batch"] = batchRef
+ dc.SendMessage(dc.marshalMessage(msg, net))
+ }
+ }
+ })
+}
+
+func (dc *downstreamConn) relayDetachedMessage(net *network, msg *irc.Message) {
+ if msg.Command != "PRIVMSG" && msg.Command != "NOTICE" {
+ return
+ }
+
+ sender := msg.Prefix.Name
+ target, text := msg.Params[0], msg.Params[1]
+ if net.isHighlight(msg) {
+ sendServiceNOTICE(dc, fmt.Sprintf("highlight in %v: <%v> %v", dc.marshalEntity(net, target), sender, text))
+ } else {
+ sendServiceNOTICE(dc, fmt.Sprintf("message in %v: <%v> %v", dc.marshalEntity(net, target), sender, text))
+ }
+}
+
+func (dc *downstreamConn) runUntilRegistered() error {
+ ctx, cancel := context.WithTimeout(context.TODO(), downstreamRegisterTimeout)
+ defer cancel()
+
+ // Close the connection with an error if the deadline is exceeded
+ go func() {
+ <-ctx.Done()
+ if err := ctx.Err(); err == context.DeadlineExceeded {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "ERROR",
+ Params: []string{"Connection registration timed out"},
+ })
+ dc.Close()
+ }
+ }()
+
+ for !dc.registered {
+ msg, err := dc.ReadMessage()
+ if err != nil {
+ return fmt.Errorf("failed to read IRC command: %w", err)
+ }
+
+ err = dc.handleMessage(ctx, msg)
+ if ircErr, ok := err.(ircError); ok {
+ ircErr.Message.Prefix = dc.srv.prefix()
+ dc.SendMessage(ircErr.Message)
+ } else if err != nil {
+ return fmt.Errorf("failed to handle IRC command %q: %v", msg, err)
+ }
+ }
+
+ return nil
+}
+
+func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.Message) error {
+ switch msg.Command {
+ case "CAP":
+ var subCmd string
+ if err := parseMessageParams(msg, &subCmd); err != nil {
+ return err
+ }
+ if err := dc.handleCapCommand(subCmd, msg.Params[1:]); err != nil {
+ return err
+ }
+ case "PING":
+ var source, destination string
+ if err := parseMessageParams(msg, &source); err != nil {
+ return err
+ }
+ if len(msg.Params) > 1 {
+ destination = msg.Params[1]
+ }
+ hostname := dc.srv.Config().Hostname
+ if destination != "" && destination != hostname {
+ return ircError{&irc.Message{
+ Command: irc.ERR_NOSUCHSERVER,
+ Params: []string{dc.nick, destination, "No such server"},
+ }}
+ }
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "PONG",
+ Params: []string{hostname, source},
+ })
+ return nil
+ case "PONG":
+ if len(msg.Params) == 0 {
+ return newNeedMoreParamsError(msg.Command)
+ }
+ token := msg.Params[len(msg.Params)-1]
+ dc.handlePong(token)
+ case "USER":
+ return ircError{&irc.Message{
+ Command: irc.ERR_ALREADYREGISTERED,
+ Params: []string{dc.nick, "You may not reregister"},
+ }}
+ case "NICK":
+ var rawNick string
+ if err := parseMessageParams(msg, &rawNick); err != nil {
+ return err
+ }
+
+ nick := rawNick
+ var upstream *upstreamConn
+ if dc.upstream() == nil {
+ uc, unmarshaledNick, err := dc.unmarshalEntity(nick)
+ if err == nil { // NICK nick/network: NICK only on a specific upstream
+ upstream = uc
+ nick = unmarshaledNick
+ }
+ }
+
+ if nick == "" || strings.ContainsAny(nick, illegalNickChars) {
+ return ircError{&irc.Message{
+ Command: irc.ERR_ERRONEUSNICKNAME,
+ Params: []string{dc.nick, rawNick, "contains illegal characters"},
+ }}
+ }
+ if casemapASCII(nick) == serviceNickCM {
+ return ircError{&irc.Message{
+ Command: irc.ERR_NICKNAMEINUSE,
+ Params: []string{dc.nick, rawNick, "Nickname reserved for bouncer service"},
+ }}
+ }
+
+ var err error
+ dc.forEachNetwork(func(n *network) {
+ if err != nil || (upstream != nil && upstream.network != n) {
+ return
+ }
+ n.Nick = nick
+ err = dc.srv.db.StoreNetwork(ctx, dc.user.ID, &n.Network)
+ })
+ if err != nil {
+ return err
+ }
+
+ dc.forEachUpstream(func(uc *upstreamConn) {
+ if upstream != nil && upstream != uc {
+ return
+ }
+ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
+ Command: "NICK",
+ Params: []string{nick},
+ })
+ })
+
+ if dc.upstream() == nil && upstream == nil && dc.nick != nick {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.prefix(),
+ Command: "NICK",
+ Params: []string{nick},
+ })
+ dc.nick = nick
+ dc.nickCM = casemapASCII(dc.nick)
+ }
+ case "SETNAME":
+ var realname string
+ if err := parseMessageParams(msg, &realname); err != nil {
+ return err
+ }
+
+ // If the client just resets to the default, just wipe the per-network
+ // preference
+ storeRealname := realname
+ if realname == dc.user.Realname {
+ storeRealname = ""
+ }
+
+ var storeErr error
+ var needUpdate []Network
+ dc.forEachNetwork(func(n *network) {
+ // We only need to call updateNetwork for upstreams that don't
+ // support setname
+ if uc := n.conn; uc != nil && uc.caps["setname"] {
+ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
+ Command: "SETNAME",
+ Params: []string{realname},
+ })
+
+ n.Realname = storeRealname
+ if err := dc.srv.db.StoreNetwork(ctx, dc.user.ID, &n.Network); err != nil {
+ dc.logger.Printf("failed to store network realname: %v", err)
+ storeErr = err
+ }
+ return
+ }
+
+ record := n.Network // copy network record because we'll mutate it
+ record.Realname = storeRealname
+ needUpdate = append(needUpdate, record)
+ })
+
+ // Walk the network list as a second step, because updateNetwork
+ // mutates the original list
+ for _, record := range needUpdate {
+ if _, err := dc.user.updateNetwork(ctx, &record); err != nil {
+ dc.logger.Printf("failed to update network realname: %v", err)
+ storeErr = err
+ }
+ }
+ if storeErr != nil {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"SETNAME", "CANNOT_CHANGE_REALNAME", "Failed to update realname"},
+ }}
+ }
+
+ if dc.upstream() == nil {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.prefix(),
+ Command: "SETNAME",
+ Params: []string{realname},
+ })
+ }
+ case "JOIN":
+ var namesStr string
+ if err := parseMessageParams(msg, &namesStr); err != nil {
+ return err
+ }
+
+ var keys []string
+ if len(msg.Params) > 1 {
+ keys = strings.Split(msg.Params[1], ",")
+ }
+
+ for i, name := range strings.Split(namesStr, ",") {
+ uc, upstreamName, err := dc.unmarshalEntity(name)
+ if err != nil {
+ return err
+ }
+
+ var key string
+ if len(keys) > i {
+ key = keys[i]
+ }
+
+ if !uc.isChannel(upstreamName) {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_NOSUCHCHANNEL,
+ Params: []string{name, "Not a channel name"},
+ })
+ continue
+ }
+
+ // Most servers ignore duplicate JOIN messages. We ignore them here
+ // because some clients automatically send JOIN messages in bulk
+ // when reconnecting to the bouncer. We don't want to flood the
+ // upstream connection with these.
+ if !uc.channels.Has(upstreamName) {
+ params := []string{upstreamName}
+ if key != "" {
+ params = append(params, key)
+ }
+ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
+ Command: "JOIN",
+ Params: params,
+ })
+ }
+
+ ch := uc.network.channels.Value(upstreamName)
+ if ch != nil {
+ // Don't clear the channel key if there's one set
+ // TODO: add a way to unset the channel key
+ if key != "" {
+ ch.Key = key
+ }
+ uc.network.attach(ctx, ch)
+ } else {
+ ch = &Channel{
+ Name: upstreamName,
+ Key: key,
+ }
+ uc.network.channels.SetValue(upstreamName, ch)
+ }
+ if err := dc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil {
+ dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err)
+ }
+ }
+ case "PART":
+ var namesStr string
+ if err := parseMessageParams(msg, &namesStr); err != nil {
+ return err
+ }
+
+ var reason string
+ if len(msg.Params) > 1 {
+ reason = msg.Params[1]
+ }
+
+ for _, name := range strings.Split(namesStr, ",") {
+ uc, upstreamName, err := dc.unmarshalEntity(name)
+ if err != nil {
+ return err
+ }
+
+ if strings.EqualFold(reason, "detach") {
+ ch := uc.network.channels.Value(upstreamName)
+ if ch != nil {
+ uc.network.detach(ch)
+ } else {
+ ch = &Channel{
+ Name: name,
+ Detached: true,
+ }
+ uc.network.channels.SetValue(upstreamName, ch)
+ }
+ if err := dc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil {
+ dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err)
+ }
+ } else {
+ params := []string{upstreamName}
+ if reason != "" {
+ params = append(params, reason)
+ }
+ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
+ Command: "PART",
+ Params: params,
+ })
+
+ if err := uc.network.deleteChannel(ctx, upstreamName); err != nil {
+ dc.logger.Printf("failed to delete channel %q: %v", upstreamName, err)
+ }
+ }
+ }
+ case "KICK":
+ var channelStr, userStr string
+ if err := parseMessageParams(msg, &channelStr, &userStr); err != nil {
+ return err
+ }
+
+ channels := strings.Split(channelStr, ",")
+ users := strings.Split(userStr, ",")
+
+ var reason string
+ if len(msg.Params) > 2 {
+ reason = msg.Params[2]
+ }
+
+ if len(channels) != 1 && len(channels) != len(users) {
+ return ircError{&irc.Message{
+ Command: irc.ERR_BADCHANMASK,
+ Params: []string{dc.nick, channelStr, "Bad channel mask"},
+ }}
+ }
+
+ for i, user := range users {
+ var channel string
+ if len(channels) == 1 {
+ channel = channels[0]
+ } else {
+ channel = channels[i]
+ }
+
+ ucChannel, upstreamChannel, err := dc.unmarshalEntity(channel)
+ if err != nil {
+ return err
+ }
+
+ ucUser, upstreamUser, err := dc.unmarshalEntity(user)
+ if err != nil {
+ return err
+ }
+
+ if ucChannel != ucUser {
+ return ircError{&irc.Message{
+ Command: irc.ERR_USERNOTINCHANNEL,
+ Params: []string{dc.nick, user, channel, "They are on another network"},
+ }}
+ }
+ uc := ucChannel
+
+ params := []string{upstreamChannel, upstreamUser}
+ if reason != "" {
+ params = append(params, reason)
+ }
+ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
+ Command: "KICK",
+ Params: params,
+ })
+ }
+ case "MODE":
+ var name string
+ if err := parseMessageParams(msg, &name); err != nil {
+ return err
+ }
+
+ var modeStr string
+ if len(msg.Params) > 1 {
+ modeStr = msg.Params[1]
+ }
+
+ if casemapASCII(name) == dc.nickCM {
+ if modeStr != "" {
+ if uc := dc.upstream(); uc != nil {
+ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
+ Command: "MODE",
+ Params: []string{uc.nick, modeStr},
+ })
+ } else {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_UMODEUNKNOWNFLAG,
+ Params: []string{dc.nick, "Cannot change user mode in multi-upstream mode"},
+ })
+ }
+ } else {
+ var userMode string
+ if uc := dc.upstream(); uc != nil {
+ userMode = string(uc.modes)
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_UMODEIS,
+ Params: []string{dc.nick, "+" + userMode},
+ })
+ }
+ return nil
+ }
+
+ uc, upstreamName, err := dc.unmarshalEntity(name)
+ if err != nil {
+ return err
+ }
+
+ if !uc.isChannel(upstreamName) {
+ return ircError{&irc.Message{
+ Command: irc.ERR_USERSDONTMATCH,
+ Params: []string{dc.nick, "Cannot change mode for other users"},
+ }}
+ }
+
+ if modeStr != "" {
+ params := []string{upstreamName, modeStr}
+ params = append(params, msg.Params[2:]...)
+ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
+ Command: "MODE",
+ Params: params,
+ })
+ } else {
+ ch := uc.channels.Value(upstreamName)
+ if ch == nil {
+ return ircError{&irc.Message{
+ Command: irc.ERR_NOSUCHCHANNEL,
+ Params: []string{dc.nick, name, "No such channel"},
+ }}
+ }
+
+ if ch.modes == nil {
+ // we haven't received the initial RPL_CHANNELMODEIS yet
+ // ignore the request, we will broadcast the modes later when we receive RPL_CHANNELMODEIS
+ return nil
+ }
+
+ modeStr, modeParams := ch.modes.Format()
+ params := []string{dc.nick, name, modeStr}
+ params = append(params, modeParams...)
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_CHANNELMODEIS,
+ Params: params,
+ })
+ if ch.creationTime != "" {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: rpl_creationtime,
+ Params: []string{dc.nick, name, ch.creationTime},
+ })
+ }
+ }
+ case "TOPIC":
+ var channel string
+ if err := parseMessageParams(msg, &channel); err != nil {
+ return err
+ }
+
+ uc, upstreamName, err := dc.unmarshalEntity(channel)
+ if err != nil {
+ return err
+ }
+
+ if len(msg.Params) > 1 { // setting topic
+ topic := msg.Params[1]
+ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
+ Command: "TOPIC",
+ Params: []string{upstreamName, topic},
+ })
+ } else { // getting topic
+ ch := uc.channels.Value(upstreamName)
+ if ch == nil {
+ return ircError{&irc.Message{
+ Command: irc.ERR_NOSUCHCHANNEL,
+ Params: []string{dc.nick, upstreamName, "No such channel"},
+ }}
+ }
+ sendTopic(dc, ch)
+ }
+ case "LIST":
+ network := dc.network
+ if network == nil && len(msg.Params) > 0 {
+ var err error
+ network, msg.Params[0], err = dc.unmarshalEntityNetwork(msg.Params[0])
+ if err != nil {
+ return err
+ }
+ }
+ if network == nil {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_LISTEND,
+ Params: []string{dc.nick, "LIST without a network suffix is not supported in multi-upstream mode"},
+ })
+ return nil
+ }
+
+ uc := network.conn
+ if uc == nil {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_LISTEND,
+ Params: []string{dc.nick, "Disconnected from upstream server"},
+ })
+ return nil
+ }
+
+ uc.enqueueCommand(dc, msg)
+ case "NAMES":
+ if len(msg.Params) == 0 {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_ENDOFNAMES,
+ Params: []string{dc.nick, "*", "End of /NAMES list"},
+ })
+ return nil
+ }
+
+ channels := strings.Split(msg.Params[0], ",")
+ for _, channel := range channels {
+ uc, upstreamName, err := dc.unmarshalEntity(channel)
+ if err != nil {
+ return err
+ }
+
+ ch := uc.channels.Value(upstreamName)
+ if ch != nil {
+ sendNames(dc, ch)
+ } else {
+ // NAMES on a channel we have not joined, ask upstream
+ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
+ Command: "NAMES",
+ Params: []string{upstreamName},
+ })
+ }
+ }
+ // For WHOX docs, see:
+ // - http://faerion.sourceforge.net/doc/irc/whox.var
+ // - https://github.com/quakenet/snircd/blob/master/doc/readme.who
+ // Note, many features aren't widely implemented, such as flags and mask2
+ case "WHO":
+ if len(msg.Params) == 0 {
+ // TODO: support WHO without parameters
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_ENDOFWHO,
+ Params: []string{dc.nick, "*", "End of /WHO list"},
+ })
+ return nil
+ }
+
+ // Clients will use the first mask to match RPL_ENDOFWHO
+ endOfWhoToken := msg.Params[0]
+
+ // TODO: add support for WHOX mask2
+ mask := msg.Params[0]
+ var options string
+ if len(msg.Params) > 1 {
+ options = msg.Params[1]
+ }
+
+ optionsParts := strings.SplitN(options, "%", 2)
+ // TODO: add support for WHOX flags in optionsParts[0]
+ var fields, whoxToken string
+ if len(optionsParts) == 2 {
+ optionsParts := strings.SplitN(optionsParts[1], ",", 2)
+ fields = strings.ToLower(optionsParts[0])
+ if len(optionsParts) == 2 && strings.Contains(fields, "t") {
+ whoxToken = optionsParts[1]
+ }
+ }
+
+ // TODO: support mixed bouncer/upstream WHO queries
+ maskCM := casemapASCII(mask)
+ if dc.network == nil && maskCM == dc.nickCM {
+ // TODO: support AWAY (H/G) in self WHO reply
+ flags := "H"
+ if dc.user.Admin {
+ flags += "*"
+ }
+ info := whoxInfo{
+ Token: whoxToken,
+ Username: dc.user.Username,
+ Hostname: dc.hostname,
+ Server: dc.srv.Config().Hostname,
+ Nickname: dc.nick,
+ Flags: flags,
+ Account: dc.user.Username,
+ Realname: dc.realname,
+ }
+ dc.SendMessage(generateWHOXReply(dc.srv.prefix(), dc.nick, fields, &info))
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_ENDOFWHO,
+ Params: []string{dc.nick, endOfWhoToken, "End of /WHO list"},
+ })
+ return nil
+ }
+ if maskCM == serviceNickCM {
+ info := whoxInfo{
+ Token: whoxToken,
+ Username: servicePrefix.User,
+ Hostname: servicePrefix.Host,
+ Server: dc.srv.Config().Hostname,
+ Nickname: serviceNick,
+ Flags: "H*",
+ Account: serviceNick,
+ Realname: serviceRealname,
+ }
+ dc.SendMessage(generateWHOXReply(dc.srv.prefix(), dc.nick, fields, &info))
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_ENDOFWHO,
+ Params: []string{dc.nick, endOfWhoToken, "End of /WHO list"},
+ })
+ return nil
+ }
+
+ // TODO: properly support WHO masks
+ uc, upstreamMask, err := dc.unmarshalEntity(mask)
+ if err != nil {
+ return err
+ }
+
+ params := []string{upstreamMask}
+ if options != "" {
+ params = append(params, options)
+ }
+
+ uc.enqueueCommand(dc, &irc.Message{
+ Command: "WHO",
+ Params: params,
+ })
+ case "WHOIS":
+ if len(msg.Params) == 0 {
+ return ircError{&irc.Message{
+ Command: irc.ERR_NONICKNAMEGIVEN,
+ Params: []string{dc.nick, "No nickname given"},
+ }}
+ }
+
+ var target, mask string
+ if len(msg.Params) == 1 {
+ target = ""
+ mask = msg.Params[0]
+ } else {
+ target = msg.Params[0]
+ mask = msg.Params[1]
+ }
+ // TODO: support multiple WHOIS users
+ if i := strings.IndexByte(mask, ','); i >= 0 {
+ mask = mask[:i]
+ }
+
+ if dc.network == nil && casemapASCII(mask) == dc.nickCM {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_WHOISUSER,
+ Params: []string{dc.nick, dc.nick, dc.user.Username, dc.hostname, "*", dc.realname},
+ })
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_WHOISSERVER,
+ Params: []string{dc.nick, dc.nick, dc.srv.Config().Hostname, "suika"},
+ })
+ if dc.user.Admin {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_WHOISOPERATOR,
+ Params: []string{dc.nick, dc.nick, "is a bouncer administrator"},
+ })
+ }
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: rpl_whoisaccount,
+ Params: []string{dc.nick, dc.nick, dc.user.Username, "is logged in as"},
+ })
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_ENDOFWHOIS,
+ Params: []string{dc.nick, dc.nick, "End of /WHOIS list"},
+ })
+ return nil
+ }
+ if casemapASCII(mask) == serviceNickCM {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_WHOISUSER,
+ Params: []string{dc.nick, serviceNick, servicePrefix.User, servicePrefix.Host, "*", serviceRealname},
+ })
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_WHOISSERVER,
+ Params: []string{dc.nick, serviceNick, dc.srv.Config().Hostname, "suika"},
+ })
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_WHOISOPERATOR,
+ Params: []string{dc.nick, serviceNick, "is the bouncer service"},
+ })
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: rpl_whoisaccount,
+ Params: []string{dc.nick, serviceNick, serviceNick, "is logged in as"},
+ })
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_ENDOFWHOIS,
+ Params: []string{dc.nick, serviceNick, "End of /WHOIS list"},
+ })
+ return nil
+ }
+
+ // TODO: support WHOIS masks
+ uc, upstreamNick, err := dc.unmarshalEntity(mask)
+ if err != nil {
+ return err
+ }
+
+ var params []string
+ if target != "" {
+ if target == mask { // WHOIS nick nick
+ params = []string{upstreamNick, upstreamNick}
+ } else {
+ params = []string{target, upstreamNick}
+ }
+ } else {
+ params = []string{upstreamNick}
+ }
+
+ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
+ Command: "WHOIS",
+ Params: params,
+ })
+ case "PRIVMSG", "NOTICE":
+ var targetsStr, text string
+ if err := parseMessageParams(msg, &targetsStr, &text); err != nil {
+ return err
+ }
+ tags := copyClientTags(msg.Tags)
+
+ for _, name := range strings.Split(targetsStr, ",") {
+ if name == "$"+dc.srv.Config().Hostname || (name == "$*" && dc.network == nil) {
+ // "$" means a server mask follows. If it's the bouncer's
+ // hostname, broadcast the message to all bouncer users.
+ if !dc.user.Admin {
+ return ircError{&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_BADMASK,
+ Params: []string{dc.nick, name, "Permission denied to broadcast message to all bouncer users"},
+ }}
+ }
+
+ dc.logger.Printf("broadcasting bouncer-wide %v: %v", msg.Command, text)
+
+ broadcastTags := tags.Copy()
+ broadcastTags["time"] = irc.TagValue(formatServerTime(time.Now()))
+ broadcastMsg := &irc.Message{
+ Tags: broadcastTags,
+ Prefix: servicePrefix,
+ Command: msg.Command,
+ Params: []string{name, text},
+ }
+ dc.srv.forEachUser(func(u *user) {
+ u.events <- eventBroadcast{broadcastMsg}
+ })
+ continue
+ }
+
+ if dc.network == nil && casemapASCII(name) == dc.nickCM {
+ dc.SendMessage(&irc.Message{
+ Tags: msg.Tags.Copy(),
+ Prefix: dc.prefix(),
+ Command: msg.Command,
+ Params: []string{name, text},
+ })
+ continue
+ }
+
+ if msg.Command == "PRIVMSG" && casemapASCII(name) == serviceNickCM {
+ if dc.caps["echo-message"] {
+ echoTags := tags.Copy()
+ echoTags["time"] = irc.TagValue(formatServerTime(time.Now()))
+ dc.SendMessage(&irc.Message{
+ Tags: echoTags,
+ Prefix: dc.prefix(),
+ Command: msg.Command,
+ Params: []string{name, text},
+ })
+ }
+ handleServicePRIVMSG(ctx, dc, text)
+ continue
+ }
+
+ uc, upstreamName, err := dc.unmarshalEntity(name)
+ if err != nil {
+ return err
+ }
+
+ if msg.Command == "PRIVMSG" && uc.network.casemap(upstreamName) == "nickserv" {
+ dc.handleNickServPRIVMSG(ctx, uc, text)
+ }
+
+ unmarshaledText := text
+ if uc.isChannel(upstreamName) {
+ unmarshaledText = dc.unmarshalText(uc, text)
+ }
+ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
+ Tags: tags,
+ Command: msg.Command,
+ Params: []string{upstreamName, unmarshaledText},
+ })
+
+ echoTags := tags.Copy()
+ echoTags["time"] = irc.TagValue(formatServerTime(time.Now()))
+ if uc.account != "" {
+ echoTags["account"] = irc.TagValue(uc.account)
+ }
+ echoMsg := &irc.Message{
+ Tags: echoTags,
+ Prefix: &irc.Prefix{Name: uc.nick},
+ Command: msg.Command,
+ Params: []string{upstreamName, text},
+ }
+ uc.produce(upstreamName, echoMsg, dc)
+
+ uc.updateChannelAutoDetach(upstreamName)
+ }
+ case "TAGMSG":
+ var targetsStr string
+ if err := parseMessageParams(msg, &targetsStr); err != nil {
+ return err
+ }
+ tags := copyClientTags(msg.Tags)
+
+ for _, name := range strings.Split(targetsStr, ",") {
+ if dc.network == nil && casemapASCII(name) == dc.nickCM {
+ dc.SendMessage(&irc.Message{
+ Tags: msg.Tags.Copy(),
+ Prefix: dc.prefix(),
+ Command: "TAGMSG",
+ Params: []string{name},
+ })
+ continue
+ }
+
+ if casemapASCII(name) == serviceNickCM {
+ continue
+ }
+
+ uc, upstreamName, err := dc.unmarshalEntity(name)
+ if err != nil {
+ return err
+ }
+ if _, ok := uc.caps["message-tags"]; !ok {
+ continue
+ }
+
+ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
+ Tags: tags,
+ Command: "TAGMSG",
+ Params: []string{upstreamName},
+ })
+
+ echoTags := tags.Copy()
+ echoTags["time"] = irc.TagValue(formatServerTime(time.Now()))
+ if uc.account != "" {
+ echoTags["account"] = irc.TagValue(uc.account)
+ }
+ echoMsg := &irc.Message{
+ Tags: echoTags,
+ Prefix: &irc.Prefix{Name: uc.nick},
+ Command: "TAGMSG",
+ Params: []string{upstreamName},
+ }
+ uc.produce(upstreamName, echoMsg, dc)
+
+ uc.updateChannelAutoDetach(upstreamName)
+ }
+ case "INVITE":
+ var user, channel string
+ if err := parseMessageParams(msg, &user, &channel); err != nil {
+ return err
+ }
+
+ ucChannel, upstreamChannel, err := dc.unmarshalEntity(channel)
+ if err != nil {
+ return err
+ }
+
+ ucUser, upstreamUser, err := dc.unmarshalEntity(user)
+ if err != nil {
+ return err
+ }
+
+ if ucChannel != ucUser {
+ return ircError{&irc.Message{
+ Command: irc.ERR_USERNOTINCHANNEL,
+ Params: []string{dc.nick, user, channel, "They are on another network"},
+ }}
+ }
+ uc := ucChannel
+
+ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
+ Command: "INVITE",
+ Params: []string{upstreamUser, upstreamChannel},
+ })
+ case "AUTHENTICATE":
+ // Post-connection-registration AUTHENTICATE is unsupported in
+ // multi-upstream mode, or if the upstream doesn't support SASL
+ uc := dc.upstream()
+ if uc == nil || !uc.caps["sasl"] {
+ return ircError{&irc.Message{
+ Command: irc.ERR_SASLFAIL,
+ Params: []string{dc.nick, "Upstream network authentication not supported"},
+ }}
+ }
+
+ credentials, err := dc.handleAuthenticateCommand(msg)
+ if err != nil {
+ return err
+ }
+
+ if credentials != nil {
+ if uc.saslClient != nil {
+ dc.endSASL(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_SASLFAIL,
+ Params: []string{dc.nick, "Another authentication attempt is already in progress"},
+ })
+ return nil
+ }
+
+ uc.logger.Printf("starting post-registration SASL PLAIN authentication with username %q", credentials.plainUsername)
+ uc.saslClient = sasl.NewPlainClient("", credentials.plainUsername, credentials.plainPassword)
+ uc.enqueueCommand(dc, &irc.Message{
+ Command: "AUTHENTICATE",
+ Params: []string{"PLAIN"},
+ })
+ }
+ case "REGISTER", "VERIFY":
+ // Check number of params here, since we'll use that to save the
+ // credentials on command success
+ if (msg.Command == "REGISTER" && len(msg.Params) < 3) || (msg.Command == "VERIFY" && len(msg.Params) < 2) {
+ return newNeedMoreParamsError(msg.Command)
+ }
+
+ uc := dc.upstream()
+ if uc == nil || !uc.caps["draft/account-registration"] {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{msg.Command, "TEMPORARILY_UNAVAILABLE", "*", "Upstream network account registration not supported"},
+ }}
+ }
+
+ uc.logger.Printf("starting %v with account name %v", msg.Command, msg.Params[0])
+ uc.enqueueCommand(dc, msg)
+ case "MONITOR":
+ // MONITOR is unsupported in multi-upstream mode
+ uc := dc.upstream()
+ if uc == nil {
+ return newUnknownCommandError(msg.Command)
+ }
+ if _, ok := uc.isupport["MONITOR"]; !ok {
+ return newUnknownCommandError(msg.Command)
+ }
+
+ var subcommand string
+ if err := parseMessageParams(msg, &subcommand); err != nil {
+ return err
+ }
+
+ switch strings.ToUpper(subcommand) {
+ case "+", "-":
+ var targets string
+ if err := parseMessageParams(msg, nil, &targets); err != nil {
+ return err
+ }
+ for _, target := range strings.Split(targets, ",") {
+ if subcommand == "+" {
+ // Hard limit, just to avoid having downstreams fill our map
+ if len(dc.monitored.innerMap) >= 1000 {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_MONLISTFULL,
+ Params: []string{dc.nick, "1000", target, "Bouncer monitor list is full"},
+ })
+ continue
+ }
+
+ dc.monitored.SetValue(target, nil)
+
+ if uc.monitored.Has(target) {
+ cmd := irc.RPL_MONOFFLINE
+ if online := uc.monitored.Value(target); online {
+ cmd = irc.RPL_MONONLINE
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: cmd,
+ Params: []string{dc.nick, target},
+ })
+ }
+ } else {
+ dc.monitored.Delete(target)
+ }
+ }
+ uc.updateMonitor()
+ case "C": // clear
+ dc.monitored = newCasemapMap(0)
+ uc.updateMonitor()
+ case "L": // list
+ // TODO: be less lazy and pack the list
+ for _, entry := range dc.monitored.innerMap {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_MONLIST,
+ Params: []string{dc.nick, entry.originalKey},
+ })
+ }
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_ENDOFMONLIST,
+ Params: []string{dc.nick, "End of MONITOR list"},
+ })
+ case "S": // status
+ // TODO: be less lazy and pack the lists
+ for _, entry := range dc.monitored.innerMap {
+ target := entry.originalKey
+
+ cmd := irc.RPL_MONOFFLINE
+ if online := uc.monitored.Value(target); online {
+ cmd = irc.RPL_MONONLINE
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: cmd,
+ Params: []string{dc.nick, target},
+ })
+ }
+ }
+ case "CHATHISTORY":
+ var subcommand string
+ if err := parseMessageParams(msg, &subcommand); err != nil {
+ return err
+ }
+ var target, limitStr string
+ var boundsStr [2]string
+ switch subcommand {
+ case "AFTER", "BEFORE", "LATEST":
+ if err := parseMessageParams(msg, nil, &target, &boundsStr[0], &limitStr); err != nil {
+ return err
+ }
+ case "BETWEEN":
+ if err := parseMessageParams(msg, nil, &target, &boundsStr[0], &boundsStr[1], &limitStr); err != nil {
+ return err
+ }
+ case "TARGETS":
+ if dc.network == nil {
+ // Either an unbound bouncer network, in which case we should return no targets,
+ // or a multi-upstream downstream, but we don't support CHATHISTORY TARGETS for those yet.
+ dc.SendBatch("draft/chathistory-targets", nil, nil, func(batchRef irc.TagValue) {})
+ return nil
+ }
+ if err := parseMessageParams(msg, nil, &boundsStr[0], &boundsStr[1], &limitStr); err != nil {
+ return err
+ }
+ default:
+ // TODO: support AROUND
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"CHATHISTORY", "INVALID_PARAMS", subcommand, "Unknown command"},
+ }}
+ }
+
+ // We don't save history for our service
+ if casemapASCII(target) == serviceNickCM {
+ dc.SendBatch("chathistory", []string{target}, nil, func(batchRef irc.TagValue) {})
+ return nil
+ }
+
+ store, ok := dc.user.msgStore.(chatHistoryMessageStore)
+ if !ok {
+ return ircError{&irc.Message{
+ Command: irc.ERR_UNKNOWNCOMMAND,
+ Params: []string{dc.nick, "CHATHISTORY", "Unknown command"},
+ }}
+ }
+
+ network, entity, err := dc.unmarshalEntityNetwork(target)
+ if err != nil {
+ return err
+ }
+ entity = network.casemap(entity)
+
+ // TODO: support msgid criteria
+ var bounds [2]time.Time
+ bounds[0] = parseChatHistoryBound(boundsStr[0])
+ if subcommand == "LATEST" && boundsStr[0] == "*" {
+ bounds[0] = time.Now()
+ } else if bounds[0].IsZero() {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"CHATHISTORY", "INVALID_PARAMS", subcommand, boundsStr[0], "Invalid first bound"},
+ }}
+ }
+
+ if boundsStr[1] != "" {
+ bounds[1] = parseChatHistoryBound(boundsStr[1])
+ if bounds[1].IsZero() {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"CHATHISTORY", "INVALID_PARAMS", subcommand, boundsStr[1], "Invalid second bound"},
+ }}
+ }
+ }
+
+ limit, err := strconv.Atoi(limitStr)
+ if err != nil || limit < 0 || limit > chatHistoryLimit {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"CHATHISTORY", "INVALID_PARAMS", subcommand, limitStr, "Invalid limit"},
+ }}
+ }
+
+ eventPlayback := dc.caps["draft/event-playback"]
+
+ var history []*irc.Message
+ switch subcommand {
+ case "BEFORE", "LATEST":
+ history, err = store.LoadBeforeTime(ctx, &network.Network, entity, bounds[0], time.Time{}, limit, eventPlayback)
+ case "AFTER":
+ history, err = store.LoadAfterTime(ctx, &network.Network, entity, bounds[0], time.Now(), limit, eventPlayback)
+ case "BETWEEN":
+ if bounds[0].Before(bounds[1]) {
+ history, err = store.LoadAfterTime(ctx, &network.Network, entity, bounds[0], bounds[1], limit, eventPlayback)
+ } else {
+ history, err = store.LoadBeforeTime(ctx, &network.Network, entity, bounds[0], bounds[1], limit, eventPlayback)
+ }
+ case "TARGETS":
+ // TODO: support TARGETS in multi-upstream mode
+ targets, err := store.ListTargets(ctx, &network.Network, bounds[0], bounds[1], limit, eventPlayback)
+ if err != nil {
+ dc.logger.Printf("failed fetching targets for chathistory: %v", err)
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"CHATHISTORY", "MESSAGE_ERROR", subcommand, "Failed to retrieve targets"},
+ }}
+ }
+
+ dc.SendBatch("draft/chathistory-targets", nil, nil, func(batchRef irc.TagValue) {
+ for _, target := range targets {
+ if ch := network.channels.Value(target.Name); ch != nil && ch.Detached {
+ continue
+ }
+
+ dc.SendMessage(&irc.Message{
+ Tags: irc.Tags{"batch": batchRef},
+ Prefix: dc.srv.prefix(),
+ Command: "CHATHISTORY",
+ Params: []string{"TARGETS", target.Name, formatServerTime(target.LatestMessage)},
+ })
+ }
+ })
+
+ return nil
+ }
+ if err != nil {
+ dc.logger.Printf("failed fetching %q messages for chathistory: %v", target, err)
+ return newChatHistoryError(subcommand, target)
+ }
+
+ dc.SendBatch("chathistory", []string{target}, nil, func(batchRef irc.TagValue) {
+ for _, msg := range history {
+ msg.Tags["batch"] = batchRef
+ dc.SendMessage(dc.marshalMessage(msg, network))
+ }
+ })
+ case "READ":
+ var target, criteria string
+ if err := parseMessageParams(msg, &target); err != nil {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"READ", "NEED_MORE_PARAMS", "Missing parameters"},
+ }}
+ }
+ if len(msg.Params) > 1 {
+ criteria = msg.Params[1]
+ }
+
+ // We don't save read receipts for our service
+ if casemapASCII(target) == serviceNickCM {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.prefix(),
+ Command: "READ",
+ Params: []string{target, "*"},
+ })
+ return nil
+ }
+
+ uc, entity, err := dc.unmarshalEntity(target)
+ if err != nil {
+ return err
+ }
+ entityCM := uc.network.casemap(entity)
+
+ r, err := dc.srv.db.GetReadReceipt(ctx, uc.network.ID, entityCM)
+ if err != nil {
+ dc.logger.Printf("failed to get the read receipt for %q: %v", entity, err)
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"READ", "INTERNAL_ERROR", target, "Internal error"},
+ }}
+ } else if r == nil {
+ r = &ReadReceipt{
+ Target: entityCM,
+ }
+ }
+
+ broadcast := false
+ if len(criteria) > 0 {
+ // TODO: support msgid criteria
+ criteriaParts := strings.SplitN(criteria, "=", 2)
+ if len(criteriaParts) != 2 || criteriaParts[0] != "timestamp" {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"READ", "INVALID_PARAMS", criteria, "Unknown criteria"},
+ }}
+ }
+
+ timestamp, err := time.Parse(serverTimeLayout, criteriaParts[1])
+ if err != nil {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"READ", "INVALID_PARAMS", criteria, "Invalid criteria"},
+ }}
+ }
+ now := time.Now()
+ if timestamp.After(now) {
+ timestamp = now
+ }
+ if r.Timestamp.Before(timestamp) {
+ r.Timestamp = timestamp
+ if err := dc.srv.db.StoreReadReceipt(ctx, uc.network.ID, r); err != nil {
+ dc.logger.Printf("failed to store receipt for %q: %v", entity, err)
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"READ", "INTERNAL_ERROR", target, "Internal error"},
+ }}
+ }
+ broadcast = true
+ }
+ }
+
+ timestampStr := "*"
+ if !r.Timestamp.IsZero() {
+ timestampStr = fmt.Sprintf("timestamp=%s", formatServerTime(r.Timestamp))
+ }
+ uc.forEachDownstream(func(d *downstreamConn) {
+ if broadcast || dc.id == d.id {
+ d.SendMessage(&irc.Message{
+ Prefix: d.prefix(),
+ Command: "READ",
+ Params: []string{d.marshalEntity(uc.network, entity), timestampStr},
+ })
+ }
+ })
+ case "BOUNCER":
+ var subcommand string
+ if err := parseMessageParams(msg, &subcommand); err != nil {
+ return err
+ }
+
+ switch strings.ToUpper(subcommand) {
+ case "BIND":
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"BOUNCER", "REGISTRATION_IS_COMPLETED", "BIND", "Cannot bind to a network after registration"},
+ }}
+ case "LISTNETWORKS":
+ dc.SendBatch("soju.im/bouncer-networks", nil, nil, func(batchRef irc.TagValue) {
+ for _, network := range dc.user.networks {
+ idStr := fmt.Sprintf("%v", network.ID)
+ attrs := getNetworkAttrs(network)
+ dc.SendMessage(&irc.Message{
+ Tags: irc.Tags{"batch": batchRef},
+ Prefix: dc.srv.prefix(),
+ Command: "BOUNCER",
+ Params: []string{"NETWORK", idStr, attrs.String()},
+ })
+ }
+ })
+ case "ADDNETWORK":
+ var attrsStr string
+ if err := parseMessageParams(msg, nil, &attrsStr); err != nil {
+ return err
+ }
+ attrs := irc.ParseTags(attrsStr)
+
+ record := &Network{Nick: dc.nick, Enabled: true}
+ if err := updateNetworkAttrs(record, attrs, subcommand); err != nil {
+ return err
+ }
+
+ if record.Nick == dc.user.Username {
+ record.Nick = ""
+ }
+ if record.Realname == dc.user.Realname {
+ record.Realname = ""
+ }
+
+ network, err := dc.user.createNetwork(ctx, record)
+ if err != nil {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"BOUNCER", "UNKNOWN_ERROR", subcommand, fmt.Sprintf("Failed to create network: %v", err)},
+ }}
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "BOUNCER",
+ Params: []string{"ADDNETWORK", fmt.Sprintf("%v", network.ID)},
+ })
+ case "CHANGENETWORK":
+ var idStr, attrsStr string
+ if err := parseMessageParams(msg, nil, &idStr, &attrsStr); err != nil {
+ return err
+ }
+ id, err := parseBouncerNetID(subcommand, idStr)
+ if err != nil {
+ return err
+ }
+ attrs := irc.ParseTags(attrsStr)
+
+ net := dc.user.getNetworkByID(id)
+ if net == nil {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"BOUNCER", "INVALID_NETID", subcommand, idStr, "Invalid network ID"},
+ }}
+ }
+
+ record := net.Network // copy network record because we'll mutate it
+ if err := updateNetworkAttrs(&record, attrs, subcommand); err != nil {
+ return err
+ }
+
+ if record.Nick == dc.user.Username {
+ record.Nick = ""
+ }
+ if record.Realname == dc.user.Realname {
+ record.Realname = ""
+ }
+
+ _, err = dc.user.updateNetwork(ctx, &record)
+ if err != nil {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"BOUNCER", "UNKNOWN_ERROR", subcommand, fmt.Sprintf("Failed to update network: %v", err)},
+ }}
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "BOUNCER",
+ Params: []string{"CHANGENETWORK", idStr},
+ })
+ case "DELNETWORK":
+ var idStr string
+ if err := parseMessageParams(msg, nil, &idStr); err != nil {
+ return err
+ }
+ id, err := parseBouncerNetID(subcommand, idStr)
+ if err != nil {
+ return err
+ }
+
+ net := dc.user.getNetworkByID(id)
+ if net == nil {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"BOUNCER", "INVALID_NETID", subcommand, idStr, "Invalid network ID"},
+ }}
+ }
+
+ if err := dc.user.deleteNetwork(ctx, net.ID); err != nil {
+ return err
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "BOUNCER",
+ Params: []string{"DELNETWORK", idStr},
+ })
+ default:
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"BOUNCER", "UNKNOWN_COMMAND", subcommand, "Unknown subcommand"},
+ }}
+ }
+ default:
+ dc.logger.Printf("unhandled message: %v", msg)
+
+ // Only forward unknown commands in single-upstream mode
+ uc := dc.upstream()
+ if uc == nil {
+ return newUnknownCommandError(msg.Command)
+ }
+
+ uc.SendMessageLabeled(ctx, dc.id, msg)
+ }
+ return nil
+}
+
+func (dc *downstreamConn) handleNickServPRIVMSG(ctx context.Context, uc *upstreamConn, text string) {
+ username, password, ok := parseNickServCredentials(text, uc.nick)
+ if ok {
+ uc.network.autoSaveSASLPlain(ctx, username, password)
+ }
+}
+
+func parseNickServCredentials(text, nick string) (username, password string, ok bool) {
+ fields := strings.Fields(text)
+ if len(fields) < 2 {
+ return "", "", false
+ }
+ cmd := strings.ToUpper(fields[0])
+ params := fields[1:]
+ switch cmd {
+ case "REGISTER":
+ username = nick
+ password = params[0]
+ case "IDENTIFY":
+ if len(params) == 1 {
+ username = nick
+ password = params[0]
+ } else {
+ username = params[0]
+ password = params[1]
+ }
+ case "SET":
+ if len(params) == 2 && strings.EqualFold(params[0], "PASSWORD") {
+ username = nick
+ password = params[1]
+ }
+ default:
+ return "", "", false
+ }
+ return username, password, true
+}
--- /dev/null
+module marisa.chaotic.ninja/suika
+
+go 1.20
+
+require (
+ git.sr.ht/~emersion/go-scfg v0.0.0-20211215104734-c2c7a15d6c99
+ git.sr.ht/~sircmpwn/go-bare v0.0.0-20210406120253-ab86bc2846d9
+ github.com/emersion/go-sasl v0.0.0-20220912192320-0145f2c60ead
+ github.com/lib/pq v1.10.7
+ golang.org/x/crypto v0.7.0
+ golang.org/x/term v0.6.0
+ golang.org/x/time v0.3.0
+ gopkg.in/irc.v3 v3.1.4
+ modernc.org/sqlite v1.21.0
+)
+
+require (
+ github.com/dustin/go-humanize v1.0.0 // indirect
+ github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect
+ github.com/google/uuid v1.3.0 // indirect
+ github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect
+ github.com/mattn/go-isatty v0.0.16 // indirect
+ github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
+ github.com/stretchr/testify v1.8.0 // indirect
+ golang.org/x/mod v0.3.0 // indirect
+ golang.org/x/sys v0.6.0 // indirect
+ golang.org/x/tools v0.0.0-20201124115921-2c860bdd6e78 // indirect
+ golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
+ gopkg.in/yaml.v2 v2.4.0 // indirect
+ lukechampine.com/uint128 v1.2.0 // indirect
+ modernc.org/cc/v3 v3.40.0 // indirect
+ modernc.org/ccgo/v3 v3.16.13 // indirect
+ modernc.org/libc v1.22.3 // indirect
+ modernc.org/mathutil v1.5.0 // indirect
+ modernc.org/memory v1.5.0 // indirect
+ modernc.org/opt v0.1.3 // indirect
+ modernc.org/strutil v1.1.3 // indirect
+ modernc.org/token v1.0.1 // indirect
+)
--- /dev/null
+git.sr.ht/~emersion/go-scfg v0.0.0-20211215104734-c2c7a15d6c99 h1:1s8n5uisqkR+BzPgaum6xxIjKmzGrTykJdh+Y3f5Xao=
+git.sr.ht/~emersion/go-scfg v0.0.0-20211215104734-c2c7a15d6c99/go.mod h1:t+Ww6SR24yYnXzEWiNlOY0AFo5E9B73X++10lrSpp4U=
+git.sr.ht/~sircmpwn/getopt v0.0.0-20191230200459-23622cc906b3/go.mod h1:wMEGFFFNuPos7vHmWXfszqImLppbc0wEhh6JBfJIUgw=
+git.sr.ht/~sircmpwn/go-bare v0.0.0-20210406120253-ab86bc2846d9 h1:Ahny8Ud1LjVMMAlt8utUFKhhxJtwBAualvsbc/Sk7cE=
+git.sr.ht/~sircmpwn/go-bare v0.0.0-20210406120253-ab86bc2846d9/go.mod h1:BVJwbDfVjCjoFiKrhkei6NdGcZYpkDkdyCdg1ukytRA=
+github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
+github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo=
+github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
+github.com/emersion/go-sasl v0.0.0-20220912192320-0145f2c60ead h1:fI1Jck0vUrXT8bnphprS1EoVRe2Q5CKCX8iDlpqjQ/Y=
+github.com/emersion/go-sasl v0.0.0-20220912192320-0145f2c60ead/go.mod h1:iL2twTeMvZnrg54ZoPDNfJaJaqy0xIQFuBdrLsmspwQ=
+github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
+github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ=
+github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4=
+github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ=
+github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
+github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
+github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs=
+github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8=
+github.com/lib/pq v1.10.7 h1:p7ZhMD+KsSRozJr34udlUrhboJwWAgCg34+/ZZNvZZw=
+github.com/lib/pq v1.10.7/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
+github.com/mattn/go-isatty v0.0.16 h1:bq3VjFmv/sOjHtdEhmkEV4x1AJtvUvOJ2PFAZ5+peKQ=
+github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
+github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y=
+github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
+github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
+github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
+github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
+github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
+github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
+github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
+github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
+github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
+github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
+github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk=
+github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
+github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
+golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
+golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
+golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
+golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A=
+golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU=
+golang.org/x/mod v0.3.0 h1:RM4zey1++hCTbCVQfnWeKs9/IEsaBLA8vTkd0WVtmH4=
+golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
+golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
+golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
+golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
+golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
+golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ=
+golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/term v0.6.0 h1:clScbb1cHjoCkyRbWwBEUZ5H/tIFu5TAXIqaZD0Gcjw=
+golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U=
+golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
+golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
+golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4=
+golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
+golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
+golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
+golang.org/x/tools v0.0.0-20201124115921-2c860bdd6e78 h1:M8tBwCtWD/cZV9DZpFYRUgaymAYAr+aIUTWzDaM3uPs=
+golang.org/x/tools v0.0.0-20201124115921-2c860bdd6e78/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
+golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
+golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/irc.v3 v3.1.4 h1:DYGMRFbtseXEh+NadmMUFzMraqyuUj4I3iWYFEzDZPc=
+gopkg.in/irc.v3 v3.1.4/go.mod h1:shO2gz8+PVeS+4E6GAny88Z0YVVQSxQghdrMVGQsR9s=
+gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
+gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
+gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
+gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
+gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
+gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+lukechampine.com/uint128 v1.2.0 h1:mBi/5l91vocEN8otkC5bDLhi2KdCticRiwbdB0O+rjI=
+lukechampine.com/uint128 v1.2.0/go.mod h1:c4eWIwlEGaxC/+H1VguhU4PHXNWDCDMUlWdIWl2j1gk=
+modernc.org/cc/v3 v3.40.0 h1:P3g79IUS/93SYhtoeaHW+kRCIrYaxJ27MFPv+7kaTOw=
+modernc.org/cc/v3 v3.40.0/go.mod h1:/bTg4dnWkSXowUO6ssQKnOV0yMVxDYNIsIrzqTFDGH0=
+modernc.org/ccgo/v3 v3.16.13 h1:Mkgdzl46i5F/CNR/Kj80Ri59hC8TKAhZrYSaqvkwzUw=
+modernc.org/ccgo/v3 v3.16.13/go.mod h1:2Quk+5YgpImhPjv2Qsob1DnZ/4som1lJTodubIcoUkY=
+modernc.org/ccorpus v1.11.6 h1:J16RXiiqiCgua6+ZvQot4yUuUy8zxgqbqEEUuGPlISk=
+modernc.org/httpfs v1.0.6 h1:AAgIpFZRXuYnkjftxTAZwMIiwEqAfk8aVB2/oA6nAeM=
+modernc.org/libc v1.22.3 h1:D/g6O5ftAfavceqlLOFwaZuA5KYafKwmr30A6iSqoyY=
+modernc.org/libc v1.22.3/go.mod h1:MQrloYP209xa2zHome2a8HLiLm6k0UT8CoHpV74tOFw=
+modernc.org/mathutil v1.5.0 h1:rV0Ko/6SfM+8G+yKiyI830l3Wuz1zRutdslNoQ0kfiQ=
+modernc.org/mathutil v1.5.0/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E=
+modernc.org/memory v1.5.0 h1:N+/8c5rE6EqugZwHii4IFsaJ7MUhoWX07J5tC/iI5Ds=
+modernc.org/memory v1.5.0/go.mod h1:PkUhL0Mugw21sHPeskwZW4D6VscE/GQJOnIpCnW6pSU=
+modernc.org/opt v0.1.3 h1:3XOZf2yznlhC+ibLltsDGzABUGVx8J6pnFMS3E4dcq4=
+modernc.org/opt v0.1.3/go.mod h1:WdSiB5evDcignE70guQKxYUl14mgWtbClRi5wmkkTX0=
+modernc.org/sqlite v1.21.0 h1:4aP4MdUf15i3R3M2mx6Q90WHKz3nZLoz96zlB6tNdow=
+modernc.org/sqlite v1.21.0/go.mod h1:XwQ0wZPIh1iKb5mkvCJ3szzbhk+tykC8ZWqTRTgYRwI=
+modernc.org/strutil v1.1.3 h1:fNMm+oJklMGYfU9Ylcywl0CO5O6nTfaowNsh2wpPjzY=
+modernc.org/strutil v1.1.3/go.mod h1:MEHNA7PdEnEwLvspRMtWTNnp2nnyvMfkimT1NKNAGbw=
+modernc.org/tcl v1.15.1 h1:mOQwiEK4p7HruMZcwKTZPw/aqtGM4aY00uzWhlKKYws=
+modernc.org/token v1.0.1 h1:A3qvTqOwexpfZZeyI0FeGPDlSWX5pjZu9hF4lU+EKWg=
+modernc.org/token v1.0.1/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=
+modernc.org/z v1.7.0 h1:xkDw/KepgEjeizO2sNco+hqYkU12taxQFqPEmgm1GWE=
--- /dev/null
+package suika
+
+import (
+ "fmt"
+ "sort"
+ "strings"
+ "time"
+ "unicode"
+ "unicode/utf8"
+
+ "gopkg.in/irc.v3"
+)
+
+const (
+ rpl_statsping = "246"
+ rpl_localusers = "265"
+ rpl_globalusers = "266"
+ rpl_creationtime = "329"
+ rpl_topicwhotime = "333"
+ rpl_whospcrpl = "354"
+ rpl_whoisaccount = "330"
+ err_invalidcapcmd = "410"
+)
+
+const (
+ maxMessageLength = 512
+ maxMessageParams = 15
+ maxSASLLength = 400
+)
+
+// The server-time layout, as defined in the IRCv3 spec.
+const serverTimeLayout = "2006-01-02T15:04:05.000Z"
+
+func formatServerTime(t time.Time) string {
+ return t.UTC().Format(serverTimeLayout)
+}
+
+type userModes string
+
+func (ms userModes) Has(c byte) bool {
+ return strings.IndexByte(string(ms), c) >= 0
+}
+
+func (ms *userModes) Add(c byte) {
+ if !ms.Has(c) {
+ *ms += userModes(c)
+ }
+}
+
+func (ms *userModes) Del(c byte) {
+ i := strings.IndexByte(string(*ms), c)
+ if i >= 0 {
+ *ms = (*ms)[:i] + (*ms)[i+1:]
+ }
+}
+
+func (ms *userModes) Apply(s string) error {
+ var plusMinus byte
+ for i := 0; i < len(s); i++ {
+ switch c := s[i]; c {
+ case '+', '-':
+ plusMinus = c
+ default:
+ switch plusMinus {
+ case '+':
+ ms.Add(c)
+ case '-':
+ ms.Del(c)
+ default:
+ return fmt.Errorf("malformed modestring %q: missing plus/minus", s)
+ }
+ }
+ }
+ return nil
+}
+
+type channelModeType byte
+
+// standard channel mode types, as explained in https://modern.ircdocs.horse/#mode-message
+const (
+ // modes that add or remove an address to or from a list
+ modeTypeA channelModeType = iota
+ // modes that change a setting on a channel, and must always have a parameter
+ modeTypeB
+ // modes that change a setting on a channel, and must have a parameter when being set, and no parameter when being unset
+ modeTypeC
+ // modes that change a setting on a channel, and must not have a parameter
+ modeTypeD
+)
+
+var stdChannelModes = map[byte]channelModeType{
+ 'b': modeTypeA, // ban list
+ 'e': modeTypeA, // ban exception list
+ 'I': modeTypeA, // invite exception list
+ 'k': modeTypeB, // channel key
+ 'l': modeTypeC, // channel user limit
+ 'i': modeTypeD, // channel is invite-only
+ 'm': modeTypeD, // channel is moderated
+ 'n': modeTypeD, // channel has no external messages
+ 's': modeTypeD, // channel is secret
+ 't': modeTypeD, // channel has protected topic
+}
+
+type channelModes map[byte]string
+
+// applyChannelModes parses a mode string and mode arguments from a MODE message,
+// and applies the corresponding channel mode and user membership changes on that channel.
+//
+// If ch.modes is nil, channel modes are not updated.
+//
+// needMarshaling is a list of indexes of mode arguments that represent entities
+// that must be marshaled when sent downstream.
+func applyChannelModes(ch *upstreamChannel, modeStr string, arguments []string) (needMarshaling map[int]struct{}, err error) {
+ needMarshaling = make(map[int]struct{}, len(arguments))
+ nextArgument := 0
+ var plusMinus byte
+outer:
+ for i := 0; i < len(modeStr); i++ {
+ mode := modeStr[i]
+ if mode == '+' || mode == '-' {
+ plusMinus = mode
+ continue
+ }
+ if plusMinus != '+' && plusMinus != '-' {
+ return nil, fmt.Errorf("malformed modestring %q: missing plus/minus", modeStr)
+ }
+
+ for _, membership := range ch.conn.availableMemberships {
+ if membership.Mode == mode {
+ if nextArgument >= len(arguments) {
+ return nil, fmt.Errorf("malformed modestring %q: missing mode argument for %c%c", modeStr, plusMinus, mode)
+ }
+ member := arguments[nextArgument]
+ m := ch.Members.Value(member)
+ if m != nil {
+ if plusMinus == '+' {
+ m.Add(ch.conn.availableMemberships, membership)
+ } else {
+ // TODO: for upstreams without multi-prefix, query the user modes again
+ m.Remove(membership)
+ }
+ }
+ needMarshaling[nextArgument] = struct{}{}
+ nextArgument++
+ continue outer
+ }
+ }
+
+ mt, ok := ch.conn.availableChannelModes[mode]
+ if !ok {
+ continue
+ }
+ if mt == modeTypeA {
+ nextArgument++
+ } else if mt == modeTypeB || (mt == modeTypeC && plusMinus == '+') {
+ if plusMinus == '+' {
+ var argument string
+ // some sentitive arguments (such as channel keys) can be omitted for privacy
+ // (this will only happen for RPL_CHANNELMODEIS, never for MODE messages)
+ if nextArgument < len(arguments) {
+ argument = arguments[nextArgument]
+ }
+ if ch.modes != nil {
+ ch.modes[mode] = argument
+ }
+ } else {
+ delete(ch.modes, mode)
+ }
+ nextArgument++
+ } else if mt == modeTypeC || mt == modeTypeD {
+ if plusMinus == '+' {
+ if ch.modes != nil {
+ ch.modes[mode] = ""
+ }
+ } else {
+ delete(ch.modes, mode)
+ }
+ }
+ }
+ return needMarshaling, nil
+}
+
+func (cm channelModes) Format() (modeString string, parameters []string) {
+ var modesWithValues strings.Builder
+ var modesWithoutValues strings.Builder
+ parameters = make([]string, 0, 16)
+ for mode, value := range cm {
+ if value != "" {
+ modesWithValues.WriteString(string(mode))
+ parameters = append(parameters, value)
+ } else {
+ modesWithoutValues.WriteString(string(mode))
+ }
+ }
+ modeString = "+" + modesWithValues.String() + modesWithoutValues.String()
+ return
+}
+
+const stdChannelTypes = "#&+!"
+
+type channelStatus byte
+
+const (
+ channelPublic channelStatus = '='
+ channelSecret channelStatus = '@'
+ channelPrivate channelStatus = '*'
+)
+
+func parseChannelStatus(s string) (channelStatus, error) {
+ if len(s) > 1 {
+ return 0, fmt.Errorf("invalid channel status %q: more than one character", s)
+ }
+ switch cs := channelStatus(s[0]); cs {
+ case channelPublic, channelSecret, channelPrivate:
+ return cs, nil
+ default:
+ return 0, fmt.Errorf("invalid channel status %q: unknown status", s)
+ }
+}
+
+type membership struct {
+ Mode byte
+ Prefix byte
+}
+
+var stdMemberships = []membership{
+ {'q', '~'}, // founder
+ {'a', '&'}, // protected
+ {'o', '@'}, // operator
+ {'h', '%'}, // halfop
+ {'v', '+'}, // voice
+}
+
+// memberships always sorted by descending membership rank
+type memberships []membership
+
+func (m *memberships) Add(availableMemberships []membership, newMembership membership) {
+ l := *m
+ i := 0
+ for _, availableMembership := range availableMemberships {
+ if i >= len(l) {
+ break
+ }
+ if l[i] == availableMembership {
+ if availableMembership == newMembership {
+ // we already have this membership
+ return
+ }
+ i++
+ continue
+ }
+ if availableMembership == newMembership {
+ break
+ }
+ }
+ // insert newMembership at i
+ l = append(l, membership{})
+ copy(l[i+1:], l[i:])
+ l[i] = newMembership
+ *m = l
+}
+
+func (m *memberships) Remove(oldMembership membership) {
+ l := *m
+ for i, currentMembership := range l {
+ if currentMembership == oldMembership {
+ *m = append(l[:i], l[i+1:]...)
+ return
+ }
+ }
+}
+
+func (m memberships) Format(dc *downstreamConn) string {
+ if !dc.caps["multi-prefix"] {
+ if len(m) == 0 {
+ return ""
+ }
+ return string(m[0].Prefix)
+ }
+ prefixes := make([]byte, len(m))
+ for i, membership := range m {
+ prefixes[i] = membership.Prefix
+ }
+ return string(prefixes)
+}
+
+func parseMessageParams(msg *irc.Message, out ...*string) error {
+ if len(msg.Params) < len(out) {
+ return newNeedMoreParamsError(msg.Command)
+ }
+ for i := range out {
+ if out[i] != nil {
+ *out[i] = msg.Params[i]
+ }
+ }
+ return nil
+}
+
+func copyClientTags(tags irc.Tags) irc.Tags {
+ t := make(irc.Tags, len(tags))
+ for k, v := range tags {
+ if strings.HasPrefix(k, "+") {
+ t[k] = v
+ }
+ }
+ return t
+}
+
+type batch struct {
+ Type string
+ Params []string
+ Outer *batch // if not-nil, this batch is nested in Outer
+ Label string
+}
+
+func join(channels, keys []string) []*irc.Message {
+ // Put channels with a key first
+ js := joinSorter{channels, keys}
+ sort.Sort(&js)
+
+ // Two spaces because there are three words (JOIN, channels and keys)
+ maxLength := maxMessageLength - (len("JOIN") + 2)
+
+ var msgs []*irc.Message
+ var channelsBuf, keysBuf strings.Builder
+ for i, channel := range channels {
+ key := keys[i]
+
+ n := channelsBuf.Len() + keysBuf.Len() + 1 + len(channel)
+ if key != "" {
+ n += 1 + len(key)
+ }
+
+ if channelsBuf.Len() > 0 && n > maxLength {
+ // No room for the new channel in this message
+ params := []string{channelsBuf.String()}
+ if keysBuf.Len() > 0 {
+ params = append(params, keysBuf.String())
+ }
+ msgs = append(msgs, &irc.Message{Command: "JOIN", Params: params})
+ channelsBuf.Reset()
+ keysBuf.Reset()
+ }
+
+ if channelsBuf.Len() > 0 {
+ channelsBuf.WriteByte(',')
+ }
+ channelsBuf.WriteString(channel)
+ if key != "" {
+ if keysBuf.Len() > 0 {
+ keysBuf.WriteByte(',')
+ }
+ keysBuf.WriteString(key)
+ }
+ }
+ if channelsBuf.Len() > 0 {
+ params := []string{channelsBuf.String()}
+ if keysBuf.Len() > 0 {
+ params = append(params, keysBuf.String())
+ }
+ msgs = append(msgs, &irc.Message{Command: "JOIN", Params: params})
+ }
+
+ return msgs
+}
+
+func generateIsupport(prefix *irc.Prefix, nick string, tokens []string) []*irc.Message {
+ maxTokens := maxMessageParams - 2 // 2 reserved params: nick + text
+
+ var msgs []*irc.Message
+ for len(tokens) > 0 {
+ var msgTokens []string
+ if len(tokens) > maxTokens {
+ msgTokens = tokens[:maxTokens]
+ tokens = tokens[maxTokens:]
+ } else {
+ msgTokens = tokens
+ tokens = nil
+ }
+
+ msgs = append(msgs, &irc.Message{
+ Prefix: prefix,
+ Command: irc.RPL_ISUPPORT,
+ Params: append(append([]string{nick}, msgTokens...), "are supported"),
+ })
+ }
+
+ return msgs
+}
+
+func generateMOTD(prefix *irc.Prefix, nick string, motd string) []*irc.Message {
+ var msgs []*irc.Message
+ msgs = append(msgs, &irc.Message{
+ Prefix: prefix,
+ Command: irc.RPL_MOTDSTART,
+ Params: []string{nick, fmt.Sprintf("- Message of the Day -")},
+ })
+
+ for _, l := range strings.Split(motd, "\n") {
+ msgs = append(msgs, &irc.Message{
+ Prefix: prefix,
+ Command: irc.RPL_MOTD,
+ Params: []string{nick, l},
+ })
+ }
+
+ msgs = append(msgs, &irc.Message{
+ Prefix: prefix,
+ Command: irc.RPL_ENDOFMOTD,
+ Params: []string{nick, "End of /MOTD command."},
+ })
+
+ return msgs
+}
+
+func generateMonitor(subcmd string, targets []string) []*irc.Message {
+ maxLength := maxMessageLength - len("MONITOR "+subcmd+" ")
+
+ var msgs []*irc.Message
+ var buf []string
+ n := 0
+ for _, target := range targets {
+ if n+len(target)+1 > maxLength {
+ msgs = append(msgs, &irc.Message{
+ Command: "MONITOR",
+ Params: []string{subcmd, strings.Join(buf, ",")},
+ })
+ buf = buf[:0]
+ n = 0
+ }
+
+ buf = append(buf, target)
+ n += len(target) + 1
+ }
+
+ if len(buf) > 0 {
+ msgs = append(msgs, &irc.Message{
+ Command: "MONITOR",
+ Params: []string{subcmd, strings.Join(buf, ",")},
+ })
+ }
+
+ return msgs
+}
+
+type joinSorter struct {
+ channels []string
+ keys []string
+}
+
+func (js *joinSorter) Len() int {
+ return len(js.channels)
+}
+
+func (js *joinSorter) Less(i, j int) bool {
+ if (js.keys[i] != "") != (js.keys[j] != "") {
+ // Only one of the channels has a key
+ return js.keys[i] != ""
+ }
+ return js.channels[i] < js.channels[j]
+}
+
+func (js *joinSorter) Swap(i, j int) {
+ js.channels[i], js.channels[j] = js.channels[j], js.channels[i]
+ js.keys[i], js.keys[j] = js.keys[j], js.keys[i]
+}
+
+// parseCTCPMessage parses a CTCP message. CTCP is defined in
+// https://tools.ietf.org/html/draft-oakley-irc-ctcp-02
+func parseCTCPMessage(msg *irc.Message) (cmd string, params string, ok bool) {
+ if (msg.Command != "PRIVMSG" && msg.Command != "NOTICE") || len(msg.Params) < 2 {
+ return "", "", false
+ }
+ text := msg.Params[1]
+
+ if !strings.HasPrefix(text, "\x01") {
+ return "", "", false
+ }
+ text = strings.Trim(text, "\x01")
+
+ words := strings.SplitN(text, " ", 2)
+ cmd = strings.ToUpper(words[0])
+ if len(words) > 1 {
+ params = words[1]
+ }
+
+ return cmd, params, true
+}
+
+type casemapping func(string) string
+
+func casemapNone(name string) string {
+ return name
+}
+
+// CasemapASCII of name is the canonical representation of name according to the
+// ascii casemapping.
+func casemapASCII(name string) string {
+ nameBytes := []byte(name)
+ for i, r := range nameBytes {
+ if 'A' <= r && r <= 'Z' {
+ nameBytes[i] = r + 'a' - 'A'
+ }
+ }
+ return string(nameBytes)
+}
+
+// casemapRFC1459 of name is the canonical representation of name according to the
+// rfc1459 casemapping.
+func casemapRFC1459(name string) string {
+ nameBytes := []byte(name)
+ for i, r := range nameBytes {
+ if 'A' <= r && r <= 'Z' {
+ nameBytes[i] = r + 'a' - 'A'
+ } else if r == '{' {
+ nameBytes[i] = '['
+ } else if r == '}' {
+ nameBytes[i] = ']'
+ } else if r == '\\' {
+ nameBytes[i] = '|'
+ } else if r == '~' {
+ nameBytes[i] = '^'
+ }
+ }
+ return string(nameBytes)
+}
+
+// casemapRFC1459Strict of name is the canonical representation of name
+// according to the rfc1459-strict casemapping.
+func casemapRFC1459Strict(name string) string {
+ nameBytes := []byte(name)
+ for i, r := range nameBytes {
+ if 'A' <= r && r <= 'Z' {
+ nameBytes[i] = r + 'a' - 'A'
+ } else if r == '{' {
+ nameBytes[i] = '['
+ } else if r == '}' {
+ nameBytes[i] = ']'
+ } else if r == '\\' {
+ nameBytes[i] = '|'
+ }
+ }
+ return string(nameBytes)
+}
+
+func parseCasemappingToken(tokenValue string) (casemap casemapping, ok bool) {
+ switch tokenValue {
+ case "ascii":
+ casemap = casemapASCII
+ case "rfc1459":
+ casemap = casemapRFC1459
+ case "rfc1459-strict":
+ casemap = casemapRFC1459Strict
+ default:
+ return nil, false
+ }
+ return casemap, true
+}
+
+func partialCasemap(higher casemapping, name string) string {
+ nameFullyCM := []byte(higher(name))
+ nameBytes := []byte(name)
+ for i, r := range nameBytes {
+ if !('A' <= r && r <= 'Z') && !('a' <= r && r <= 'z') {
+ nameBytes[i] = nameFullyCM[i]
+ }
+ }
+ return string(nameBytes)
+}
+
+type casemapMap struct {
+ innerMap map[string]casemapEntry
+ casemap casemapping
+}
+
+type casemapEntry struct {
+ originalKey string
+ value interface{}
+}
+
+func newCasemapMap(size int) casemapMap {
+ return casemapMap{
+ innerMap: make(map[string]casemapEntry, size),
+ casemap: casemapNone,
+ }
+}
+
+func (cm *casemapMap) OriginalKey(name string) (key string, ok bool) {
+ entry, ok := cm.innerMap[cm.casemap(name)]
+ if !ok {
+ return "", false
+ }
+ return entry.originalKey, true
+}
+
+func (cm *casemapMap) Has(name string) bool {
+ _, ok := cm.innerMap[cm.casemap(name)]
+ return ok
+}
+
+func (cm *casemapMap) Len() int {
+ return len(cm.innerMap)
+}
+
+func (cm *casemapMap) SetValue(name string, value interface{}) {
+ nameCM := cm.casemap(name)
+ entry, ok := cm.innerMap[nameCM]
+ if !ok {
+ cm.innerMap[nameCM] = casemapEntry{
+ originalKey: name,
+ value: value,
+ }
+ return
+ }
+ entry.value = value
+ cm.innerMap[nameCM] = entry
+}
+
+func (cm *casemapMap) Delete(name string) {
+ delete(cm.innerMap, cm.casemap(name))
+}
+
+func (cm *casemapMap) SetCasemapping(newCasemap casemapping) {
+ cm.casemap = newCasemap
+ newInnerMap := make(map[string]casemapEntry, len(cm.innerMap))
+ for _, entry := range cm.innerMap {
+ newInnerMap[cm.casemap(entry.originalKey)] = entry
+ }
+ cm.innerMap = newInnerMap
+}
+
+type upstreamChannelCasemapMap struct{ casemapMap }
+
+func (cm *upstreamChannelCasemapMap) Value(name string) *upstreamChannel {
+ entry, ok := cm.innerMap[cm.casemap(name)]
+ if !ok {
+ return nil
+ }
+ return entry.value.(*upstreamChannel)
+}
+
+type channelCasemapMap struct{ casemapMap }
+
+func (cm *channelCasemapMap) Value(name string) *Channel {
+ entry, ok := cm.innerMap[cm.casemap(name)]
+ if !ok {
+ return nil
+ }
+ return entry.value.(*Channel)
+}
+
+type membershipsCasemapMap struct{ casemapMap }
+
+func (cm *membershipsCasemapMap) Value(name string) *memberships {
+ entry, ok := cm.innerMap[cm.casemap(name)]
+ if !ok {
+ return nil
+ }
+ return entry.value.(*memberships)
+}
+
+type deliveredCasemapMap struct{ casemapMap }
+
+func (cm *deliveredCasemapMap) Value(name string) deliveredClientMap {
+ entry, ok := cm.innerMap[cm.casemap(name)]
+ if !ok {
+ return nil
+ }
+ return entry.value.(deliveredClientMap)
+}
+
+type monitorCasemapMap struct{ casemapMap }
+
+func (cm *monitorCasemapMap) Value(name string) (online bool) {
+ entry, ok := cm.innerMap[cm.casemap(name)]
+ if !ok {
+ return false
+ }
+ return entry.value.(bool)
+}
+
+func isWordBoundary(r rune) bool {
+ switch r {
+ case '-', '_', '|': // inspired from weechat.look.highlight_regex
+ return false
+ default:
+ return !unicode.IsLetter(r) && !unicode.IsNumber(r)
+ }
+}
+
+func isHighlight(text, nick string) bool {
+ for {
+ i := strings.Index(text, nick)
+ if i < 0 {
+ return false
+ }
+
+ left, _ := utf8.DecodeLastRuneInString(text[:i])
+ right, _ := utf8.DecodeRuneInString(text[i+len(nick):])
+ if isWordBoundary(left) && isWordBoundary(right) {
+ return true
+ }
+
+ text = text[i+len(nick):]
+ }
+}
+
+// parseChatHistoryBound parses the given CHATHISTORY parameter as a bound.
+// The zero time is returned on error.
+func parseChatHistoryBound(param string) time.Time {
+ parts := strings.SplitN(param, "=", 2)
+ if len(parts) != 2 {
+ return time.Time{}
+ }
+ switch parts[0] {
+ case "timestamp":
+ timestamp, err := time.Parse(serverTimeLayout, parts[1])
+ if err != nil {
+ return time.Time{}
+ }
+ return timestamp
+ default:
+ return time.Time{}
+ }
+}
+
+// whoxFields is the list of all WHOX field letters, by order of appearance in
+// RPL_WHOSPCRPL messages.
+var whoxFields = []byte("tcuihsnfdlaor")
+
+type whoxInfo struct {
+ Token string
+ Username string
+ Hostname string
+ Server string
+ Nickname string
+ Flags string
+ Account string
+ Realname string
+}
+
+func (info *whoxInfo) get(field byte) string {
+ switch field {
+ case 't':
+ return info.Token
+ case 'c':
+ return "*"
+ case 'u':
+ return info.Username
+ case 'i':
+ return "255.255.255.255"
+ case 'h':
+ return info.Hostname
+ case 's':
+ return info.Server
+ case 'n':
+ return info.Nickname
+ case 'f':
+ return info.Flags
+ case 'd':
+ return "0"
+ case 'l': // idle time
+ return "0"
+ case 'a':
+ account := "0" // WHOX uses "0" to mean "no account"
+ if info.Account != "" && info.Account != "*" {
+ account = info.Account
+ }
+ return account
+ case 'o':
+ return "0"
+ case 'r':
+ return info.Realname
+ }
+ return ""
+}
+
+func generateWHOXReply(prefix *irc.Prefix, nick, fields string, info *whoxInfo) *irc.Message {
+ if fields == "" {
+ return &irc.Message{
+ Prefix: prefix,
+ Command: irc.RPL_WHOREPLY,
+ Params: []string{nick, "*", info.Username, info.Hostname, info.Server, info.Nickname, info.Flags, "0 " + info.Realname},
+ }
+ }
+
+ fieldSet := make(map[byte]bool)
+ for i := 0; i < len(fields); i++ {
+ fieldSet[fields[i]] = true
+ }
+
+ var values []string
+ for _, field := range whoxFields {
+ if !fieldSet[field] {
+ continue
+ }
+ values = append(values, info.get(field))
+ }
+
+ return &irc.Message{
+ Prefix: prefix,
+ Command: rpl_whospcrpl,
+ Params: append([]string{nick}, values...),
+ }
+}
+
+var isupportEncoder = strings.NewReplacer(" ", "\\x20", "\\", "\\x5C")
+
+func encodeISUPPORT(s string) string {
+ return isupportEncoder.Replace(s)
+}
--- /dev/null
+package suika
+
+import (
+ "testing"
+)
+
+func TestIsHighlight(t *testing.T) {
+ nick := "SojuUser"
+ testCases := []struct {
+ name string
+ text string
+ hl bool
+ }{
+ {"noContains", "hi there Soju User!", false},
+ {"middle", "hi there SojuUser!", true},
+ {"start", "SojuUser: how are you doing?", true},
+ {"end", "maybe ask SojuUser", true},
+ {"inWord", "but OtherSojuUserSan is a different nick", false},
+ {"startWord", "and OtherSojuUser is another different nick", false},
+ {"endWord", "and SojuUserSan is yet a different nick", false},
+ {"underscore", "and SojuUser_san has nothing to do with me", false},
+ {"zeroWidthSpace", "writing S\u200BojuUser shouldn't trigger a highlight", false},
+ }
+
+ for _, tc := range testCases {
+ tc := tc // capture range variable
+ t.Run(tc.name, func(t *testing.T) {
+ hl := isHighlight(tc.text, nick)
+ if hl != tc.hl {
+ t.Errorf("isHighlight(%q, %q) = %v, but want %v", tc.text, nick, hl, tc.hl)
+ }
+ })
+ }
+}
--- /dev/null
+package suika
+
+import (
+ "bytes"
+ "context"
+ "encoding/base64"
+ "fmt"
+ "time"
+
+ "git.sr.ht/~sircmpwn/go-bare"
+ "gopkg.in/irc.v3"
+)
+
+// messageStore is a per-user store for IRC messages.
+type messageStore interface {
+ Close() error
+ // LastMsgID queries the last message ID for the given network, entity and
+ // date. The message ID returned may not refer to a valid message, but can be
+ // used in history queries.
+ LastMsgID(network *Network, entity string, t time.Time) (string, error)
+ // LoadLatestID queries the latest non-event messages for the given network,
+ // entity and date, up to a count of limit messages, sorted from oldest to newest.
+ LoadLatestID(ctx context.Context, network *Network, entity, id string, limit int) ([]*irc.Message, error)
+ Append(network *Network, entity string, msg *irc.Message) (id string, err error)
+}
+
+type chatHistoryTarget struct {
+ Name string
+ LatestMessage time.Time
+}
+
+// chatHistoryMessageStore is a message store that supports chat history
+// operations.
+type chatHistoryMessageStore interface {
+ messageStore
+
+ // ListTargets lists channels and nicknames by time of the latest message.
+ // It returns up to limit targets, starting from start and ending on end,
+ // both excluded. end may be before or after start.
+ // If events is false, only PRIVMSG/NOTICE messages are considered.
+ ListTargets(ctx context.Context, network *Network, start, end time.Time, limit int, events bool) ([]chatHistoryTarget, error)
+ // LoadBeforeTime loads up to limit messages before start down to end. The
+ // returned messages must be between and excluding the provided bounds.
+ // end is before start.
+ // If events is false, only PRIVMSG/NOTICE messages are considered.
+ LoadBeforeTime(ctx context.Context, network *Network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error)
+ // LoadBeforeTime loads up to limit messages after start up to end. The
+ // returned messages must be between and excluding the provided bounds.
+ // end is after start.
+ // If events is false, only PRIVMSG/NOTICE messages are considered.
+ LoadAfterTime(ctx context.Context, network *Network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error)
+}
+
+type msgIDType uint
+
+const (
+ msgIDNone msgIDType = iota
+ msgIDMemory
+ msgIDFS
+)
+
+const msgIDVersion uint = 0
+
+type msgIDHeader struct {
+ Version uint
+ Network bare.Int
+ Target string
+ Type msgIDType
+}
+
+type msgIDBody interface {
+ msgIDType() msgIDType
+}
+
+func formatMsgID(netID int64, target string, body msgIDBody) string {
+ var buf bytes.Buffer
+ w := bare.NewWriter(&buf)
+
+ header := msgIDHeader{
+ Version: msgIDVersion,
+ Network: bare.Int(netID),
+ Target: target,
+ Type: body.msgIDType(),
+ }
+ if err := bare.MarshalWriter(w, &header); err != nil {
+ panic(err)
+ }
+ if err := bare.MarshalWriter(w, body); err != nil {
+ panic(err)
+ }
+ return base64.RawURLEncoding.EncodeToString(buf.Bytes())
+}
+
+func parseMsgID(s string, body msgIDBody) (netID int64, target string, err error) {
+ b, err := base64.RawURLEncoding.DecodeString(s)
+ if err != nil {
+ return 0, "", fmt.Errorf("invalid internal message ID: %v", err)
+ }
+
+ r := bare.NewReader(bytes.NewReader(b))
+
+ var header msgIDHeader
+ if err := bare.UnmarshalBareReader(r, &header); err != nil {
+ return 0, "", fmt.Errorf("invalid internal message ID: %v", err)
+ }
+
+ if header.Version != msgIDVersion {
+ return 0, "", fmt.Errorf("invalid internal message ID: got version %v, want %v", header.Version, msgIDVersion)
+ }
+
+ if body != nil {
+ typ := body.msgIDType()
+ if header.Type != typ {
+ return 0, "", fmt.Errorf("invalid internal message ID: got type %v, want %v", header.Type, typ)
+ }
+
+ if err := bare.UnmarshalBareReader(r, body); err != nil {
+ return 0, "", fmt.Errorf("invalid internal message ID: %v", err)
+ }
+ }
+
+ return int64(header.Network), header.Target, nil
+}
--- /dev/null
+package suika
+
+import (
+ "bufio"
+ "context"
+ "fmt"
+ "io"
+ "os"
+ "path/filepath"
+ "sort"
+ "strings"
+ "time"
+
+ "git.sr.ht/~sircmpwn/go-bare"
+ "gopkg.in/irc.v3"
+)
+
+const (
+ fsMessageStoreMaxFiles = 20
+ fsMessageStoreMaxTries = 100
+)
+
+func escapeFilename(unsafe string) (safe string) {
+ if unsafe == "." {
+ return "-"
+ } else if unsafe == ".." {
+ return "--"
+ } else {
+ return strings.NewReplacer("/", "-", "\\", "-").Replace(unsafe)
+ }
+}
+
+type date struct {
+ Year, Month, Day int
+}
+
+func newDate(t time.Time) date {
+ year, month, day := t.Date()
+ return date{year, int(month), day}
+}
+
+func (d date) Time() time.Time {
+ return time.Date(d.Year, time.Month(d.Month), d.Day, 0, 0, 0, 0, time.Local)
+}
+
+type fsMsgID struct {
+ Date date
+ Offset bare.Int
+}
+
+func (fsMsgID) msgIDType() msgIDType {
+ return msgIDFS
+}
+
+func parseFSMsgID(s string) (netID int64, entity string, t time.Time, offset int64, err error) {
+ var id fsMsgID
+ netID, entity, err = parseMsgID(s, &id)
+ if err != nil {
+ return 0, "", time.Time{}, 0, err
+ }
+ return netID, entity, id.Date.Time(), int64(id.Offset), nil
+}
+
+func formatFSMsgID(netID int64, entity string, t time.Time, offset int64) string {
+ id := fsMsgID{
+ Date: newDate(t),
+ Offset: bare.Int(offset),
+ }
+ return formatMsgID(netID, entity, &id)
+}
+
+type fsMessageStoreFile struct {
+ *os.File
+ lastUse time.Time
+}
+
+// fsMessageStore is a per-user on-disk store for IRC messages.
+//
+// It mimicks the ZNC log layout and format. See the ZNC source:
+// https://github.com/znc/znc/blob/master/modules/log.cpp
+type fsMessageStore struct {
+ root string
+ user *User
+
+ // Write-only files used by Append
+ files map[string]*fsMessageStoreFile // indexed by entity
+}
+
+var _ messageStore = (*fsMessageStore)(nil)
+var _ chatHistoryMessageStore = (*fsMessageStore)(nil)
+
+func newFSMessageStore(root string, user *User) *fsMessageStore {
+ return &fsMessageStore{
+ root: filepath.Join(root, escapeFilename(user.Username)),
+ user: user,
+ files: make(map[string]*fsMessageStoreFile),
+ }
+}
+
+func (ms *fsMessageStore) logPath(network *Network, entity string, t time.Time) string {
+ year, month, day := t.Date()
+ filename := fmt.Sprintf("%04d-%02d-%02d.log", year, month, day)
+ return filepath.Join(ms.root, escapeFilename(network.GetName()), escapeFilename(entity), filename)
+}
+
+// nextMsgID queries the message ID for the next message to be written to f.
+func nextFSMsgID(network *Network, entity string, t time.Time, f *os.File) (string, error) {
+ offset, err := f.Seek(0, io.SeekEnd)
+ if err != nil {
+ return "", fmt.Errorf("failed to query next FS message ID: %v", err)
+ }
+ return formatFSMsgID(network.ID, entity, t, offset), nil
+}
+
+func (ms *fsMessageStore) LastMsgID(network *Network, entity string, t time.Time) (string, error) {
+ p := ms.logPath(network, entity, t)
+ fi, err := os.Stat(p)
+ if os.IsNotExist(err) {
+ return formatFSMsgID(network.ID, entity, t, -1), nil
+ } else if err != nil {
+ return "", fmt.Errorf("failed to query last FS message ID: %v", err)
+ }
+ return formatFSMsgID(network.ID, entity, t, fi.Size()-1), nil
+}
+
+func (ms *fsMessageStore) Append(network *Network, entity string, msg *irc.Message) (string, error) {
+ s := formatMessage(msg)
+ if s == "" {
+ return "", nil
+ }
+
+ var t time.Time
+ if tag, ok := msg.Tags["time"]; ok {
+ var err error
+ t, err = time.Parse(serverTimeLayout, string(tag))
+ if err != nil {
+ return "", fmt.Errorf("failed to parse message time tag: %v", err)
+ }
+ t = t.In(time.Local)
+ } else {
+ t = time.Now()
+ }
+
+ f := ms.files[entity]
+
+ // TODO: handle non-monotonic clock behaviour
+ path := ms.logPath(network, entity, t)
+ if f == nil || f.Name() != path {
+ dir := filepath.Dir(path)
+ if err := os.MkdirAll(dir, 0750); err != nil {
+ return "", fmt.Errorf("failed to create message logs directory %q: %v", dir, err)
+ }
+
+ ff, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0640)
+ if err != nil {
+ return "", fmt.Errorf("failed to open message log file %q: %v", path, err)
+ }
+
+ if f != nil {
+ f.Close()
+ }
+ f = &fsMessageStoreFile{File: ff}
+ ms.files[entity] = f
+ }
+
+ f.lastUse = time.Now()
+
+ if len(ms.files) > fsMessageStoreMaxFiles {
+ entities := make([]string, 0, len(ms.files))
+ for name := range ms.files {
+ entities = append(entities, name)
+ }
+ sort.Slice(entities, func(i, j int) bool {
+ a, b := entities[i], entities[j]
+ return ms.files[a].lastUse.Before(ms.files[b].lastUse)
+ })
+ entities = entities[0 : len(entities)-fsMessageStoreMaxFiles]
+ for _, name := range entities {
+ ms.files[name].Close()
+ delete(ms.files, name)
+ }
+ }
+
+ msgID, err := nextFSMsgID(network, entity, t, f.File)
+ if err != nil {
+ return "", fmt.Errorf("failed to generate message ID: %v", err)
+ }
+
+ _, err = fmt.Fprintf(f, "[%02d:%02d:%02d] %s\n", t.Hour(), t.Minute(), t.Second(), s)
+ if err != nil {
+ return "", fmt.Errorf("failed to log message to %q: %v", f.Name(), err)
+ }
+
+ return msgID, nil
+}
+
+func (ms *fsMessageStore) Close() error {
+ var closeErr error
+ for _, f := range ms.files {
+ if err := f.Close(); err != nil {
+ closeErr = fmt.Errorf("failed to close message store: %v", err)
+ }
+ }
+ return closeErr
+}
+
+// formatMessage formats a message log line. It assumes a well-formed IRC
+// message.
+func formatMessage(msg *irc.Message) string {
+ switch strings.ToUpper(msg.Command) {
+ case "NICK":
+ return fmt.Sprintf("*** %s is now known as %s", msg.Prefix.Name, msg.Params[0])
+ case "JOIN":
+ return fmt.Sprintf("*** Joins: %s (%s@%s)", msg.Prefix.Name, msg.Prefix.User, msg.Prefix.Host)
+ case "PART":
+ var reason string
+ if len(msg.Params) > 1 {
+ reason = msg.Params[1]
+ }
+ return fmt.Sprintf("*** Parts: %s (%s@%s) (%s)", msg.Prefix.Name, msg.Prefix.User, msg.Prefix.Host, reason)
+ case "KICK":
+ nick := msg.Params[1]
+ var reason string
+ if len(msg.Params) > 2 {
+ reason = msg.Params[2]
+ }
+ return fmt.Sprintf("*** %s was kicked by %s (%s)", nick, msg.Prefix.Name, reason)
+ case "QUIT":
+ var reason string
+ if len(msg.Params) > 0 {
+ reason = msg.Params[0]
+ }
+ return fmt.Sprintf("*** Quits: %s (%s@%s) (%s)", msg.Prefix.Name, msg.Prefix.User, msg.Prefix.Host, reason)
+ case "TOPIC":
+ var topic string
+ if len(msg.Params) > 1 {
+ topic = msg.Params[1]
+ }
+ return fmt.Sprintf("*** %s changes topic to '%s'", msg.Prefix.Name, topic)
+ case "MODE":
+ return fmt.Sprintf("*** %s sets mode: %s", msg.Prefix.Name, strings.Join(msg.Params[1:], " "))
+ case "NOTICE":
+ return fmt.Sprintf("-%s- %s", msg.Prefix.Name, msg.Params[1])
+ case "PRIVMSG":
+ if cmd, params, ok := parseCTCPMessage(msg); ok && cmd == "ACTION" {
+ return fmt.Sprintf("* %s %s", msg.Prefix.Name, params)
+ } else {
+ return fmt.Sprintf("<%s> %s", msg.Prefix.Name, msg.Params[1])
+ }
+ default:
+ return ""
+ }
+}
+
+func (ms *fsMessageStore) parseMessage(line string, network *Network, entity string, ref time.Time, events bool) (*irc.Message, time.Time, error) {
+ var hour, minute, second int
+ _, err := fmt.Sscanf(line, "[%02d:%02d:%02d] ", &hour, &minute, &second)
+ if err != nil {
+ return nil, time.Time{}, fmt.Errorf("malformed timestamp prefix: %v", err)
+ }
+ line = line[11:]
+
+ var cmd string
+ var prefix *irc.Prefix
+ var params []string
+ if events && strings.HasPrefix(line, "*** ") {
+ parts := strings.SplitN(line[4:], " ", 2)
+ if len(parts) != 2 {
+ return nil, time.Time{}, nil
+ }
+ switch parts[0] {
+ case "Joins:", "Parts:", "Quits:":
+ args := strings.SplitN(parts[1], " ", 3)
+ if len(args) < 2 {
+ return nil, time.Time{}, nil
+ }
+ nick := args[0]
+ mask := strings.TrimSuffix(strings.TrimPrefix(args[1], "("), ")")
+ maskParts := strings.SplitN(mask, "@", 2)
+ if len(maskParts) != 2 {
+ return nil, time.Time{}, nil
+ }
+ prefix = &irc.Prefix{
+ Name: nick,
+ User: maskParts[0],
+ Host: maskParts[1],
+ }
+ var reason string
+ if len(args) > 2 {
+ reason = strings.TrimSuffix(strings.TrimPrefix(args[2], "("), ")")
+ }
+ switch parts[0] {
+ case "Joins:":
+ cmd = "JOIN"
+ params = []string{entity}
+ case "Parts:":
+ cmd = "PART"
+ if reason != "" {
+ params = []string{entity, reason}
+ } else {
+ params = []string{entity}
+ }
+ case "Quits:":
+ cmd = "QUIT"
+ if reason != "" {
+ params = []string{reason}
+ }
+ }
+ default:
+ nick := parts[0]
+ rem := parts[1]
+ if r := strings.TrimPrefix(rem, "is now known as "); r != rem {
+ cmd = "NICK"
+ prefix = &irc.Prefix{
+ Name: nick,
+ }
+ params = []string{r}
+ } else if r := strings.TrimPrefix(rem, "was kicked by "); r != rem {
+ args := strings.SplitN(r, " ", 2)
+ if len(args) != 2 {
+ return nil, time.Time{}, nil
+ }
+ cmd = "KICK"
+ prefix = &irc.Prefix{
+ Name: args[0],
+ }
+ reason := strings.TrimSuffix(strings.TrimPrefix(args[1], "("), ")")
+ params = []string{entity, nick}
+ if reason != "" {
+ params = append(params, reason)
+ }
+ } else if r := strings.TrimPrefix(rem, "changes topic to "); r != rem {
+ cmd = "TOPIC"
+ prefix = &irc.Prefix{
+ Name: nick,
+ }
+ topic := strings.TrimSuffix(strings.TrimPrefix(r, "'"), "'")
+ params = []string{entity, topic}
+ } else if r := strings.TrimPrefix(rem, "sets mode: "); r != rem {
+ cmd = "MODE"
+ prefix = &irc.Prefix{
+ Name: nick,
+ }
+ params = append([]string{entity}, strings.Split(r, " ")...)
+ } else {
+ return nil, time.Time{}, nil
+ }
+ }
+ } else {
+ var sender, text string
+ if strings.HasPrefix(line, "<") {
+ cmd = "PRIVMSG"
+ parts := strings.SplitN(line[1:], "> ", 2)
+ if len(parts) != 2 {
+ return nil, time.Time{}, nil
+ }
+ sender, text = parts[0], parts[1]
+ } else if strings.HasPrefix(line, "-") {
+ cmd = "NOTICE"
+ parts := strings.SplitN(line[1:], "- ", 2)
+ if len(parts) != 2 {
+ return nil, time.Time{}, nil
+ }
+ sender, text = parts[0], parts[1]
+ } else if strings.HasPrefix(line, "* ") {
+ cmd = "PRIVMSG"
+ parts := strings.SplitN(line[2:], " ", 2)
+ if len(parts) != 2 {
+ return nil, time.Time{}, nil
+ }
+ sender, text = parts[0], "\x01ACTION "+parts[1]+"\x01"
+ } else {
+ return nil, time.Time{}, nil
+ }
+
+ prefix = &irc.Prefix{Name: sender}
+ if entity == sender {
+ // This is a direct message from a user to us. We don't store own
+ // our nickname in the logs, so grab it from the network settings.
+ // Not very accurate since this may not match our nick at the time
+ // the message was received, but we can't do a lot better.
+ entity = GetNick(ms.user, network)
+ }
+ params = []string{entity, text}
+ }
+
+ year, month, day := ref.Date()
+ t := time.Date(year, month, day, hour, minute, second, 0, time.Local)
+
+ msg := &irc.Message{
+ Tags: map[string]irc.TagValue{
+ "time": irc.TagValue(formatServerTime(t)),
+ },
+ Prefix: prefix,
+ Command: cmd,
+ Params: params,
+ }
+ return msg, t, nil
+}
+
+func (ms *fsMessageStore) parseMessagesBefore(network *Network, entity string, ref time.Time, end time.Time, events bool, limit int, afterOffset int64) ([]*irc.Message, error) {
+ path := ms.logPath(network, entity, ref)
+ f, err := os.Open(path)
+ if err != nil {
+ if os.IsNotExist(err) {
+ return nil, nil
+ }
+ return nil, fmt.Errorf("failed to parse messages before ref: %v", err)
+ }
+ defer f.Close()
+
+ historyRing := make([]*irc.Message, limit)
+ cur := 0
+
+ sc := bufio.NewScanner(f)
+
+ if afterOffset >= 0 {
+ if _, err := f.Seek(afterOffset, io.SeekStart); err != nil {
+ return nil, nil
+ }
+ sc.Scan() // skip till next newline
+ }
+
+ for sc.Scan() {
+ msg, t, err := ms.parseMessage(sc.Text(), network, entity, ref, events)
+ if err != nil {
+ return nil, err
+ } else if msg == nil || !t.After(end) {
+ continue
+ } else if !t.Before(ref) {
+ break
+ }
+
+ historyRing[cur%limit] = msg
+ cur++
+ }
+ if sc.Err() != nil {
+ return nil, fmt.Errorf("failed to parse messages before ref: scanner error: %v", sc.Err())
+ }
+
+ n := limit
+ if cur < limit {
+ n = cur
+ }
+ start := (cur - n + limit) % limit
+
+ if start+n <= limit { // ring doesnt wrap
+ return historyRing[start : start+n], nil
+ } else { // ring wraps
+ history := make([]*irc.Message, n)
+ r := copy(history, historyRing[start:])
+ copy(history[r:], historyRing[:n-r])
+ return history, nil
+ }
+}
+
+func (ms *fsMessageStore) parseMessagesAfter(network *Network, entity string, ref time.Time, end time.Time, events bool, limit int) ([]*irc.Message, error) {
+ path := ms.logPath(network, entity, ref)
+ f, err := os.Open(path)
+ if err != nil {
+ if os.IsNotExist(err) {
+ return nil, nil
+ }
+ return nil, fmt.Errorf("failed to parse messages after ref: %v", err)
+ }
+ defer f.Close()
+
+ var history []*irc.Message
+ sc := bufio.NewScanner(f)
+ for sc.Scan() && len(history) < limit {
+ msg, t, err := ms.parseMessage(sc.Text(), network, entity, ref, events)
+ if err != nil {
+ return nil, err
+ } else if msg == nil || !t.After(ref) {
+ continue
+ } else if !t.Before(end) {
+ break
+ }
+
+ history = append(history, msg)
+ }
+ if sc.Err() != nil {
+ return nil, fmt.Errorf("failed to parse messages after ref: scanner error: %v", sc.Err())
+ }
+
+ return history, nil
+}
+
+func (ms *fsMessageStore) LoadBeforeTime(ctx context.Context, network *Network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) {
+ start = start.In(time.Local)
+ end = end.In(time.Local)
+ history := make([]*irc.Message, limit)
+ remaining := limit
+ tries := 0
+ for remaining > 0 && tries < fsMessageStoreMaxTries && end.Before(start) {
+ buf, err := ms.parseMessagesBefore(network, entity, start, end, events, remaining, -1)
+ if err != nil {
+ return nil, err
+ }
+ if len(buf) == 0 {
+ tries++
+ } else {
+ tries = 0
+ }
+ copy(history[remaining-len(buf):], buf)
+ remaining -= len(buf)
+ year, month, day := start.Date()
+ start = time.Date(year, month, day, 0, 0, 0, 0, start.Location()).Add(-1)
+
+ if err := ctx.Err(); err != nil {
+ return nil, err
+ }
+ }
+
+ return history[remaining:], nil
+}
+
+func (ms *fsMessageStore) LoadAfterTime(ctx context.Context, network *Network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) {
+ start = start.In(time.Local)
+ end = end.In(time.Local)
+ var history []*irc.Message
+ remaining := limit
+ tries := 0
+ for remaining > 0 && tries < fsMessageStoreMaxTries && start.Before(end) {
+ buf, err := ms.parseMessagesAfter(network, entity, start, end, events, remaining)
+ if err != nil {
+ return nil, err
+ }
+ if len(buf) == 0 {
+ tries++
+ } else {
+ tries = 0
+ }
+ history = append(history, buf...)
+ remaining -= len(buf)
+ year, month, day := start.Date()
+ start = time.Date(year, month, day+1, 0, 0, 0, 0, start.Location())
+
+ if err := ctx.Err(); err != nil {
+ return nil, err
+ }
+ }
+ return history, nil
+}
+
+func (ms *fsMessageStore) LoadLatestID(ctx context.Context, network *Network, entity, id string, limit int) ([]*irc.Message, error) {
+ var afterTime time.Time
+ var afterOffset int64
+ if id != "" {
+ var idNet int64
+ var idEntity string
+ var err error
+ idNet, idEntity, afterTime, afterOffset, err = parseFSMsgID(id)
+ if err != nil {
+ return nil, err
+ }
+ if idNet != network.ID || idEntity != entity {
+ return nil, fmt.Errorf("cannot find message ID: message ID doesn't match network/entity")
+ }
+ }
+
+ history := make([]*irc.Message, limit)
+ t := time.Now()
+ remaining := limit
+ tries := 0
+ for remaining > 0 && tries < fsMessageStoreMaxTries && !truncateDay(t).Before(afterTime) {
+ var offset int64 = -1
+ if afterOffset >= 0 && truncateDay(t).Equal(afterTime) {
+ offset = afterOffset
+ }
+
+ buf, err := ms.parseMessagesBefore(network, entity, t, time.Time{}, false, remaining, offset)
+ if err != nil {
+ return nil, err
+ }
+ if len(buf) == 0 {
+ tries++
+ } else {
+ tries = 0
+ }
+ copy(history[remaining-len(buf):], buf)
+ remaining -= len(buf)
+ year, month, day := t.Date()
+ t = time.Date(year, month, day, 0, 0, 0, 0, t.Location()).Add(-1)
+
+ if err := ctx.Err(); err != nil {
+ return nil, err
+ }
+ }
+
+ return history[remaining:], nil
+}
+
+func (ms *fsMessageStore) ListTargets(ctx context.Context, network *Network, start, end time.Time, limit int, events bool) ([]chatHistoryTarget, error) {
+ start = start.In(time.Local)
+ end = end.In(time.Local)
+ rootPath := filepath.Join(ms.root, escapeFilename(network.GetName()))
+ root, err := os.Open(rootPath)
+ if os.IsNotExist(err) {
+ return nil, nil
+ } else if err != nil {
+ return nil, err
+ }
+
+ // The returned targets are escaped, and there is no way to un-escape
+ // TODO: switch to ReadDir (Go 1.16+)
+ targetNames, err := root.Readdirnames(0)
+ root.Close()
+ if err != nil {
+ return nil, err
+ }
+
+ var targets []chatHistoryTarget
+ for _, target := range targetNames {
+ // target is already escaped here
+ targetPath := filepath.Join(rootPath, target)
+ targetDir, err := os.Open(targetPath)
+ if err != nil {
+ return nil, err
+ }
+
+ entries, err := targetDir.Readdir(0)
+ targetDir.Close()
+ if err != nil {
+ return nil, err
+ }
+
+ // We use mtime here, which may give imprecise or incorrect results
+ var t time.Time
+ for _, entry := range entries {
+ if entry.ModTime().After(t) {
+ t = entry.ModTime()
+ }
+ }
+
+ // The timestamps we get from logs have second granularity
+ t = truncateSecond(t)
+
+ // Filter out targets that don't fullfil the time bounds
+ if !isTimeBetween(t, start, end) {
+ continue
+ }
+
+ targets = append(targets, chatHistoryTarget{
+ Name: target,
+ LatestMessage: t,
+ })
+
+ if err := ctx.Err(); err != nil {
+ return nil, err
+ }
+ }
+
+ // Sort targets by latest message time, backwards or forwards depending on
+ // the order of the time bounds
+ sort.Slice(targets, func(i, j int) bool {
+ t1, t2 := targets[i].LatestMessage, targets[j].LatestMessage
+ if start.Before(end) {
+ return t1.Before(t2)
+ } else {
+ return !t1.Before(t2)
+ }
+ })
+
+ // Truncate the result if necessary
+ if len(targets) > limit {
+ targets = targets[:limit]
+ }
+
+ return targets, nil
+}
+
+func (ms *fsMessageStore) RenameNetwork(oldNet, newNet *Network) error {
+ oldDir := filepath.Join(ms.root, escapeFilename(oldNet.GetName()))
+ newDir := filepath.Join(ms.root, escapeFilename(newNet.GetName()))
+ // Avoid loosing data by overwriting an existing directory
+ if _, err := os.Stat(newDir); err == nil {
+ return fmt.Errorf("destination %q already exists", newDir)
+ }
+ return os.Rename(oldDir, newDir)
+}
+
+func truncateDay(t time.Time) time.Time {
+ year, month, day := t.Date()
+ return time.Date(year, month, day, 0, 0, 0, 0, t.Location())
+}
+
+func truncateSecond(t time.Time) time.Time {
+ year, month, day := t.Date()
+ return time.Date(year, month, day, t.Hour(), t.Minute(), t.Second(), 0, t.Location())
+}
+
+func isTimeBetween(t, start, end time.Time) bool {
+ if end.Before(start) {
+ end, start = start, end
+ }
+ return start.Before(t) && t.Before(end)
+}
--- /dev/null
+package suika
+
+import (
+ "context"
+ "fmt"
+ "time"
+
+ "git.sr.ht/~sircmpwn/go-bare"
+ "gopkg.in/irc.v3"
+)
+
+const messageRingBufferCap = 4096
+
+type memoryMsgID struct {
+ Seq bare.Uint
+}
+
+func (memoryMsgID) msgIDType() msgIDType {
+ return msgIDMemory
+}
+
+func parseMemoryMsgID(s string) (netID int64, entity string, seq uint64, err error) {
+ var id memoryMsgID
+ netID, entity, err = parseMsgID(s, &id)
+ if err != nil {
+ return 0, "", 0, err
+ }
+ return netID, entity, uint64(id.Seq), nil
+}
+
+func formatMemoryMsgID(netID int64, entity string, seq uint64) string {
+ id := memoryMsgID{bare.Uint(seq)}
+ return formatMsgID(netID, entity, &id)
+}
+
+type ringBufferKey struct {
+ networkID int64
+ entity string
+}
+
+type memoryMessageStore struct {
+ buffers map[ringBufferKey]*messageRingBuffer
+}
+
+var _ messageStore = (*memoryMessageStore)(nil)
+
+func newMemoryMessageStore() *memoryMessageStore {
+ return &memoryMessageStore{
+ buffers: make(map[ringBufferKey]*messageRingBuffer),
+ }
+}
+
+func (ms *memoryMessageStore) Close() error {
+ ms.buffers = nil
+ return nil
+}
+
+func (ms *memoryMessageStore) get(network *Network, entity string) *messageRingBuffer {
+ k := ringBufferKey{networkID: network.ID, entity: entity}
+ if rb, ok := ms.buffers[k]; ok {
+ return rb
+ }
+ rb := newMessageRingBuffer(messageRingBufferCap)
+ ms.buffers[k] = rb
+ return rb
+}
+
+func (ms *memoryMessageStore) LastMsgID(network *Network, entity string, t time.Time) (string, error) {
+ var seq uint64
+ k := ringBufferKey{networkID: network.ID, entity: entity}
+ if rb, ok := ms.buffers[k]; ok {
+ seq = rb.cur
+ }
+ return formatMemoryMsgID(network.ID, entity, seq), nil
+}
+
+func (ms *memoryMessageStore) Append(network *Network, entity string, msg *irc.Message) (string, error) {
+ switch msg.Command {
+ case "PRIVMSG", "NOTICE":
+ // Only append these messages, because LoadLatestID shouldn't return
+ // other kinds of message.
+ default:
+ return "", nil
+ }
+
+ k := ringBufferKey{networkID: network.ID, entity: entity}
+ rb, ok := ms.buffers[k]
+ if !ok {
+ rb = newMessageRingBuffer(messageRingBufferCap)
+ ms.buffers[k] = rb
+ }
+
+ seq := rb.Append(msg)
+ return formatMemoryMsgID(network.ID, entity, seq), nil
+}
+
+func (ms *memoryMessageStore) LoadLatestID(ctx context.Context, network *Network, entity, id string, limit int) ([]*irc.Message, error) {
+ _, _, seq, err := parseMemoryMsgID(id)
+ if err != nil {
+ return nil, err
+ }
+
+ k := ringBufferKey{networkID: network.ID, entity: entity}
+ rb, ok := ms.buffers[k]
+ if !ok {
+ return nil, nil
+ }
+
+ return rb.LoadLatestSeq(seq, limit)
+}
+
+type messageRingBuffer struct {
+ buf []*irc.Message
+ cur uint64
+}
+
+func newMessageRingBuffer(capacity int) *messageRingBuffer {
+ return &messageRingBuffer{
+ buf: make([]*irc.Message, capacity),
+ cur: 1,
+ }
+}
+
+func (rb *messageRingBuffer) cap() uint64 {
+ return uint64(len(rb.buf))
+}
+
+func (rb *messageRingBuffer) Append(msg *irc.Message) uint64 {
+ seq := rb.cur
+ i := int(seq % rb.cap())
+ rb.buf[i] = msg
+ rb.cur++
+ return seq
+}
+
+func (rb *messageRingBuffer) LoadLatestSeq(seq uint64, limit int) ([]*irc.Message, error) {
+ if seq > rb.cur {
+ return nil, fmt.Errorf("loading messages from sequence number (%v) greater than current (%v)", seq, rb.cur)
+ } else if seq == rb.cur {
+ return nil, nil
+ }
+
+ // The query excludes the message with the sequence number seq
+ diff := rb.cur - seq - 1
+ if diff > rb.cap() {
+ // We dropped diff - cap entries
+ diff = rb.cap()
+ }
+ if int(diff) > limit {
+ diff = uint64(limit)
+ }
+
+ l := make([]*irc.Message, int(diff))
+ for i := 0; i < int(diff); i++ {
+ j := int((rb.cur - diff + uint64(i)) % rb.cap())
+ l[i] = rb.buf[j]
+ }
+
+ return l, nil
+}
--- /dev/null
+//go:build !go1.16
+// +build !go1.16
+
+package suika
+
+import (
+ "strings"
+)
+
+func isErrClosed(err error) bool {
+ return err != nil && strings.Contains(err.Error(), "use of closed network connection")
+}
--- /dev/null
+//go:build go1.16
+// +build go1.16
+
+package suika
+
+import (
+ "errors"
+ "net"
+)
+
+func isErrClosed(err error) bool {
+ return errors.Is(err, net.ErrClosed)
+}
--- /dev/null
+package suika
+
+import (
+ "math/rand"
+ "time"
+)
+
+// backoffer implements a simple exponential backoff.
+type backoffer struct {
+ min, max, jitter time.Duration
+ n int64
+}
+
+func newBackoffer(min, max, jitter time.Duration) *backoffer {
+ return &backoffer{min: min, max: max, jitter: jitter}
+}
+
+func (b *backoffer) Reset() {
+ b.n = 0
+}
+
+func (b *backoffer) Next() time.Duration {
+ if b.n == 0 {
+ b.n = 1
+ return 0
+ }
+
+ d := time.Duration(b.n) * b.min
+ if d > b.max {
+ d = b.max
+ } else {
+ b.n *= 2
+ }
+
+ if b.jitter != 0 {
+ d += time.Duration(rand.Int63n(int64(b.jitter)))
+ }
+
+ return d
+}
--- /dev/null
+#!/bin/sh
+# $TheSupernovaDuo$
+# vim: ft=sh
+
+# PROVIDE: suika
+# REQUIRE: DAEMON
+# BEFORE: LOGIN
+# KEYWORD: shutdown
+
+. /etc/rc.subr
+
+name="suika"
+desc="A drunk IRC bouncer"
+rcvar="suika_enable"
+
+: ${suika_user="ircd"}
+
+command="%%PREFIX%%/bin/suika"
+pidfile="/var/run/suika.pid"
+required_files="%%PREFIX%%/etc/suika/config"
+
+start_cmd="suika_start"
+
+suika_start() {
+ /usr/sbin/daemon -f -p ${pidfile} -u ${suika_user} -l daemon ${command} --config ${required_files}
+}
+
+load_rc_config "$name"
+run_rc_command "$1"
--- /dev/null
+# $TheSupernovaDuo$
+cmd: %%PREFIX%%/bin/suika --config %%PREFIX%%/etc/suika/config
+user: ircd
--- /dev/null
+#!/bin/sh
+# $TheSupernovaDuo$
+# vim: ft=sh
+
+# PROVIDE: suika
+# REQUIRE: DAEMON
+# BEFORE: LOGIN
+# KEYWORD: shutdown
+
+. /etc/rc.subr
+
+name="suika"
+rcvar="${name}"
+command="%%PREFIX/bin/${name}"
+command_args="--config %%PREFIX%%/etc/suika/config"
+pidfile="/var/run/${name}.pid"
+start_cmd="${name}_start"
+
+suika_start() {
+ printf "Starting %s..." "${name}"
+ ${command} ${command_args}
+ pgrep -n ${name} > ${pidfile}
+}
+
+load_rc_config ${name}
+run_rc_command "$1"
+
+
--- /dev/null
+#!/bin/ksh
+# $TheSupernovaDuo$
+# vim: ft=sh
+
+daemon="%%PREFIX%%/bin/suika"
+daemon_args="--config %%PREFIX%%/etc/suika/config"
+
+. /etc/rc.d/rc.subr
+
+rc_bg=YES
+
+rc_cmd "$1"
--- /dev/null
+# $TheSupernovaDuo$
+# vim: ft=confini
+[Unit]
+Description=A drunk IRC bouncer
+After=network.target
+Wants=network.target
+StartLimitBurst=5
+StartLimitIntervalSec=1
+[Service]
+Type=simple
+Restart=on-abnormal
+RestartSec=1
+User=suika
+ExecStart=%%PREFIX%%/bin/suika --config %%PREFIX%%/etc/suika/config
+[Install]
+WantedBy=multi-user.target
--- /dev/null
+package suika
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "log"
+ "net"
+ "runtime/debug"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "gopkg.in/irc.v3"
+)
+
+// TODO: make configurable
+var (
+ retryConnectMinDelay = time.Minute
+ retryConnectMaxDelay = 10 * time.Minute
+ retryConnectJitter = time.Minute
+ connectTimeout = 15 * time.Second
+ writeTimeout = 10 * time.Second
+ upstreamMessageDelay = 2 * time.Second
+ upstreamMessageBurst = 10
+ backlogTimeout = 10 * time.Second
+ handleDownstreamMessageTimeout = 10 * time.Second
+ downstreamRegisterTimeout = 30 * time.Second
+ chatHistoryLimit = 1000
+ backlogLimit = 4000
+)
+
+type Logger interface {
+ Printf(format string, v ...interface{})
+ Debugf(format string, v ...interface{})
+}
+
+type logger struct {
+ *log.Logger
+ debug bool
+}
+
+func (l logger) Debugf(format string, v ...interface{}) {
+ if !l.debug {
+ return
+ }
+ l.Logger.Printf(format, v...)
+}
+
+func NewLogger(out io.Writer, debug bool) Logger {
+ return logger{
+ Logger: log.New(log.Writer(), "", log.LstdFlags),
+ debug: debug,
+ }
+}
+
+type prefixLogger struct {
+ logger Logger
+ prefix string
+}
+
+var _ Logger = (*prefixLogger)(nil)
+
+func (l *prefixLogger) Printf(format string, v ...interface{}) {
+ v = append([]interface{}{l.prefix}, v...)
+ l.logger.Printf("%v"+format, v...)
+}
+
+func (l *prefixLogger) Debugf(format string, v ...interface{}) {
+ v = append([]interface{}{l.prefix}, v...)
+ l.logger.Debugf("%v"+format, v...)
+}
+
+type int64Gauge struct {
+ v int64 // atomic
+}
+
+func (g *int64Gauge) Add(delta int64) {
+ atomic.AddInt64(&g.v, delta)
+}
+
+func (g *int64Gauge) Value() int64 {
+ return atomic.LoadInt64(&g.v)
+}
+
+func (g *int64Gauge) Float64() float64 {
+ return float64(g.Value())
+}
+
+type retryListener struct {
+ net.Listener
+ Logger Logger
+
+ delay time.Duration
+}
+
+func (ln *retryListener) Accept() (net.Conn, error) {
+ for {
+ conn, err := ln.Listener.Accept()
+ if ne, ok := err.(net.Error); ok && ne.Temporary() {
+ if ln.delay == 0 {
+ ln.delay = 5 * time.Millisecond
+ } else {
+ ln.delay *= 2
+ }
+ if max := 1 * time.Second; ln.delay > max {
+ ln.delay = max
+ }
+ if ln.Logger != nil {
+ ln.Logger.Printf("accept error (retrying in %v): %v", ln.delay, err)
+ }
+ time.Sleep(ln.delay)
+ } else {
+ ln.delay = 0
+ return conn, err
+ }
+ }
+}
+
+type Config struct {
+ Hostname string
+ Title string
+ LogPath string
+ MaxUserNetworks int
+ MultiUpstream bool
+ MOTD string
+ UpstreamUserIPs []*net.IPNet
+}
+
+type Server struct {
+ Logger Logger
+
+ config atomic.Value // *Config
+ db Database
+ stopWG sync.WaitGroup
+
+ lock sync.Mutex
+ listeners map[net.Listener]struct{}
+ users map[string]*user
+}
+
+func NewServer(db Database) *Server {
+ srv := &Server{
+ Logger: NewLogger(log.Writer(), true),
+ db: db,
+ listeners: make(map[net.Listener]struct{}),
+ users: make(map[string]*user),
+ }
+ srv.config.Store(&Config{
+ Hostname: "localhost",
+ MaxUserNetworks: -1,
+ MultiUpstream: true,
+ })
+ return srv
+}
+
+func (s *Server) prefix() *irc.Prefix {
+ return &irc.Prefix{Name: s.Config().Hostname}
+}
+
+func (s *Server) Config() *Config {
+ return s.config.Load().(*Config)
+}
+
+func (s *Server) SetConfig(cfg *Config) {
+ s.config.Store(cfg)
+}
+
+func (s *Server) Start() error {
+ users, err := s.db.ListUsers(context.TODO())
+ if err != nil {
+ return err
+ }
+
+ s.lock.Lock()
+ for i := range users {
+ s.addUserLocked(&users[i])
+ }
+ s.lock.Unlock()
+
+ return nil
+}
+
+func (s *Server) Shutdown() {
+ s.lock.Lock()
+ for ln := range s.listeners {
+ if err := ln.Close(); err != nil {
+ s.Logger.Printf("failed to stop listener: %v", err)
+ }
+ }
+ for _, u := range s.users {
+ u.events <- eventStop{}
+ }
+ s.lock.Unlock()
+
+ s.stopWG.Wait()
+
+ if err := s.db.Close(); err != nil {
+ s.Logger.Printf("failed to close DB: %v", err)
+ }
+}
+
+func (s *Server) createUser(ctx context.Context, user *User) (*user, error) {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+
+ if _, ok := s.users[user.Username]; ok {
+ return nil, fmt.Errorf("user %q already exists", user.Username)
+ }
+
+ err := s.db.StoreUser(ctx, user)
+ if err != nil {
+ return nil, fmt.Errorf("could not create user in db: %v", err)
+ }
+
+ return s.addUserLocked(user), nil
+}
+
+func (s *Server) forEachUser(f func(*user)) {
+ s.lock.Lock()
+ for _, u := range s.users {
+ f(u)
+ }
+ s.lock.Unlock()
+}
+
+func (s *Server) getUser(name string) *user {
+ s.lock.Lock()
+ u := s.users[name]
+ s.lock.Unlock()
+ return u
+}
+
+func (s *Server) addUserLocked(user *User) *user {
+ s.Logger.Printf("starting bouncer for user %q", user.Username)
+ u := newUser(s, user)
+ s.users[u.Username] = u
+
+ s.stopWG.Add(1)
+
+ go func() {
+ defer func() {
+ if err := recover(); err != nil {
+ s.Logger.Printf("panic serving user %q: %v\n%v", user.Username, err, debug.Stack())
+ }
+
+ s.lock.Lock()
+ delete(s.users, u.Username)
+ s.lock.Unlock()
+
+ s.stopWG.Done()
+ }()
+
+ u.run()
+ }()
+
+ return u
+}
+
+var lastDownstreamID uint64 = 0
+
+func (s *Server) handle(ic ircConn) {
+ defer func() {
+ if err := recover(); err != nil {
+ s.Logger.Printf("panic serving downstream %q: %v\n%v", ic.RemoteAddr(), err, debug.Stack())
+ }
+ }()
+
+ id := atomic.AddUint64(&lastDownstreamID, 1)
+ dc := newDownstreamConn(s, ic, id)
+ if err := dc.runUntilRegistered(); err != nil {
+ if !errors.Is(err, io.EOF) {
+ dc.logger.Printf("%v", err)
+ }
+ } else {
+ dc.user.events <- eventDownstreamConnected{dc}
+ if err := dc.readMessages(dc.user.events); err != nil {
+ dc.logger.Printf("%v", err)
+ }
+ dc.user.events <- eventDownstreamDisconnected{dc}
+ }
+ dc.Close()
+}
+
+func (s *Server) Serve(ln net.Listener) error {
+ ln = &retryListener{
+ Listener: ln,
+ Logger: &prefixLogger{logger: s.Logger, prefix: fmt.Sprintf("listener %v: ", ln.Addr())},
+ }
+
+ s.lock.Lock()
+ s.listeners[ln] = struct{}{}
+ s.lock.Unlock()
+
+ s.stopWG.Add(1)
+
+ defer func() {
+ s.lock.Lock()
+ delete(s.listeners, ln)
+ s.lock.Unlock()
+
+ s.stopWG.Done()
+ }()
+
+ for {
+ conn, err := ln.Accept()
+ if isErrClosed(err) {
+ return nil
+ } else if err != nil {
+ return fmt.Errorf("failed to accept connection: %v", err)
+ }
+
+ go s.handle(newNetIRCConn(conn))
+ }
+}
+
+type ServerStats struct {
+ Users int
+ Downstreams int64
+ Upstreams int64
+}
+
+func (s *Server) Stats() *ServerStats {
+ var stats ServerStats
+ s.lock.Lock()
+ stats.Users = len(s.users)
+ s.lock.Unlock()
+ return &stats
+}
--- /dev/null
+package suika
+
+import (
+ "context"
+ "net"
+ "testing"
+
+ "golang.org/x/crypto/bcrypt"
+ "gopkg.in/irc.v3"
+)
+
+var testServerPrefix = &irc.Prefix{Name: "suika-test-server"}
+
+const (
+ testUsername = "suika-test-user"
+ testPassword = testUsername
+)
+
+func createTempSqliteDB(t *testing.T) Database {
+ db, err := OpenDB("sqlite3", ":memory:")
+ if err != nil {
+ t.Fatalf("failed to create temporary SQLite database: %v", err)
+ }
+ // :memory: will open a separate database for each new connection. Make
+ // sure the sql package only uses a single connection. An alternative
+ // solution is to use "file::memory:?cache=shared".
+ db.(*SqliteDB).db.SetMaxOpenConns(1)
+ return db
+}
+
+func createTempPostgresDB(t *testing.T) Database {
+ db := &PostgresDB{db: openTempPostgresDB(t)}
+ if err := db.upgrade(); err != nil {
+ t.Fatalf("failed to upgrade PostgreSQL database: %v", err)
+ }
+
+ return db
+}
+
+func createTestUser(t *testing.T, db Database) *User {
+ hashed, err := bcrypt.GenerateFromPassword([]byte(testPassword), bcrypt.DefaultCost)
+ if err != nil {
+ t.Fatalf("failed to generate bcrypt hash: %v", err)
+ }
+
+ record := &User{Username: testUsername, Password: string(hashed)}
+ if err := db.StoreUser(context.Background(), record); err != nil {
+ t.Fatalf("failed to store test user: %v", err)
+ }
+
+ return record
+}
+
+func createTestDownstream(t *testing.T, srv *Server) ircConn {
+ c1, c2 := net.Pipe()
+ go srv.handle(newNetIRCConn(c1))
+ return newNetIRCConn(c2)
+}
+
+func createTestUpstream(t *testing.T, db Database, user *User) (*Network, net.Listener) {
+ ln, err := net.Listen("tcp", "localhost:0")
+ if err != nil {
+ t.Fatalf("failed to create TCP listener: %v", err)
+ }
+
+ network := &Network{
+ Name: "testnet",
+ Addr: "irc://" + ln.Addr().String(),
+ Nick: user.Username,
+ Enabled: true,
+ }
+ if err := db.StoreNetwork(context.Background(), user.ID, network); err != nil {
+ t.Fatalf("failed to store test network: %v", err)
+ }
+
+ return network, ln
+}
+
+func mustAccept(t *testing.T, ln net.Listener) ircConn {
+ c, err := ln.Accept()
+ if err != nil {
+ t.Fatalf("failed accepting connection: %v", err)
+ }
+ return newNetIRCConn(c)
+}
+
+func expectMessage(t *testing.T, c ircConn, cmd string) *irc.Message {
+ msg, err := c.ReadMessage()
+ if err != nil {
+ t.Fatalf("failed to read IRC message (want %q): %v", cmd, err)
+ }
+ if msg.Command != cmd {
+ t.Fatalf("invalid message received: want %q, got: %v", cmd, msg)
+ }
+ return msg
+}
+
+func registerDownstreamConn(t *testing.T, c ircConn, network *Network) {
+ c.WriteMessage(&irc.Message{
+ Command: "PASS",
+ Params: []string{testPassword},
+ })
+ c.WriteMessage(&irc.Message{
+ Command: "NICK",
+ Params: []string{testUsername},
+ })
+ c.WriteMessage(&irc.Message{
+ Command: "USER",
+ Params: []string{testUsername + "/" + network.Name, "0", "*", testUsername},
+ })
+
+ expectMessage(t, c, irc.RPL_WELCOME)
+}
+
+func registerUpstreamConn(t *testing.T, c ircConn) {
+ msg := expectMessage(t, c, "CAP")
+ if msg.Params[0] != "LS" {
+ t.Fatalf("invalid CAP LS: got: %v", msg)
+ }
+ msg = expectMessage(t, c, "NICK")
+ nick := msg.Params[0]
+ if nick != testUsername {
+ t.Fatalf("invalid NICK: want %q, got: %v", testUsername, msg)
+ }
+ expectMessage(t, c, "USER")
+
+ c.WriteMessage(&irc.Message{
+ Prefix: testServerPrefix,
+ Command: irc.RPL_WELCOME,
+ Params: []string{nick, "Welcome!"},
+ })
+ c.WriteMessage(&irc.Message{
+ Prefix: testServerPrefix,
+ Command: irc.RPL_YOURHOST,
+ Params: []string{nick, "Your host is suika-test-server"},
+ })
+ c.WriteMessage(&irc.Message{
+ Prefix: testServerPrefix,
+ Command: irc.RPL_CREATED,
+ Params: []string{nick, "Who cares when the server was created?"},
+ })
+ c.WriteMessage(&irc.Message{
+ Prefix: testServerPrefix,
+ Command: irc.RPL_MYINFO,
+ Params: []string{nick, testServerPrefix.Name, "suika", "aiwroO", "OovaimnqpsrtklbeI"},
+ })
+ c.WriteMessage(&irc.Message{
+ Prefix: testServerPrefix,
+ Command: irc.ERR_NOMOTD,
+ Params: []string{nick, "No MOTD"},
+ })
+}
+
+func testServer(t *testing.T, db Database) {
+ user := createTestUser(t, db)
+ network, upstream := createTestUpstream(t, db, user)
+ defer upstream.Close()
+
+ srv := NewServer(db)
+ if err := srv.Start(); err != nil {
+ t.Fatalf("failed to start server: %v", err)
+ }
+ defer srv.Shutdown()
+
+ uc := mustAccept(t, upstream)
+ defer uc.Close()
+ registerUpstreamConn(t, uc)
+
+ dc := createTestDownstream(t, srv)
+ defer dc.Close()
+ registerDownstreamConn(t, dc, network)
+
+ noticeText := "This is a very important server notice."
+ uc.WriteMessage(&irc.Message{
+ Prefix: testServerPrefix,
+ Command: "NOTICE",
+ Params: []string{testUsername, noticeText},
+ })
+
+ var msg *irc.Message
+ for {
+ var err error
+ msg, err = dc.ReadMessage()
+ if err != nil {
+ t.Fatalf("failed to read IRC message: %v", err)
+ }
+ if msg.Command == "NOTICE" {
+ break
+ }
+ }
+
+ if msg.Params[1] != noticeText {
+ t.Fatalf("invalid NOTICE text: want %q, got: %v", noticeText, msg)
+ }
+}
+
+func TestServer(t *testing.T) {
+ t.Run("sqlite", func(t *testing.T) {
+ db := createTempSqliteDB(t)
+ testServer(t, db)
+ })
+
+ t.Run("postgres", func(t *testing.T) {
+ db := createTempPostgresDB(t)
+ testServer(t, db)
+ })
+}
--- /dev/null
+package suika
+
+import (
+ "context"
+ "crypto/sha1"
+ "crypto/sha256"
+ "crypto/sha512"
+ "encoding/hex"
+ "flag"
+ "fmt"
+ "io/ioutil"
+ "sort"
+ "strconv"
+ "strings"
+ "time"
+ "unicode"
+
+ "golang.org/x/crypto/bcrypt"
+ "gopkg.in/irc.v3"
+)
+
+const (
+ serviceNick = "BouncerServ"
+ serviceNickCM = "bouncerserv"
+ serviceRealname = "suika bouncer service"
+)
+
+// maxRSABits is the maximum number of RSA key bits used when generating a new
+// private key.
+const maxRSABits = 8192
+
+var servicePrefix = &irc.Prefix{
+ Name: serviceNick,
+ User: serviceNick,
+ Host: serviceNick,
+}
+
+type serviceCommandSet map[string]*serviceCommand
+
+type serviceCommand struct {
+ usage string
+ desc string
+ handle func(ctx context.Context, dc *downstreamConn, params []string) error
+ children serviceCommandSet
+ admin bool
+}
+
+func sendServiceNOTICE(dc *downstreamConn, text string) {
+ dc.SendMessage(&irc.Message{
+ Prefix: servicePrefix,
+ Command: "NOTICE",
+ Params: []string{dc.nick, text},
+ })
+}
+
+func sendServicePRIVMSG(dc *downstreamConn, text string) {
+ dc.SendMessage(&irc.Message{
+ Prefix: servicePrefix,
+ Command: "PRIVMSG",
+ Params: []string{dc.nick, text},
+ })
+}
+
+func splitWords(s string) ([]string, error) {
+ var words []string
+ var lastWord strings.Builder
+ escape := false
+ prev := ' '
+ wordDelim := ' '
+
+ for _, r := range s {
+ if escape {
+ // last char was a backslash, write the byte as-is.
+ lastWord.WriteRune(r)
+ escape = false
+ } else if r == '\\' {
+ escape = true
+ } else if wordDelim == ' ' && unicode.IsSpace(r) {
+ // end of last word
+ if !unicode.IsSpace(prev) {
+ words = append(words, lastWord.String())
+ lastWord.Reset()
+ }
+ } else if r == wordDelim {
+ // wordDelim is either " or ', switch back to
+ // space-delimited words.
+ wordDelim = ' '
+ } else if r == '"' || r == '\'' {
+ if wordDelim == ' ' {
+ // start of (double-)quoted word
+ wordDelim = r
+ } else {
+ // either wordDelim is " and r is ' or vice-versa
+ lastWord.WriteRune(r)
+ }
+ } else {
+ lastWord.WriteRune(r)
+ }
+
+ prev = r
+ }
+
+ if !unicode.IsSpace(prev) {
+ words = append(words, lastWord.String())
+ }
+
+ if wordDelim != ' ' {
+ return nil, fmt.Errorf("unterminated quoted string")
+ }
+ if escape {
+ return nil, fmt.Errorf("unterminated backslash sequence")
+ }
+
+ return words, nil
+}
+
+func handleServicePRIVMSG(ctx context.Context, dc *downstreamConn, text string) {
+ words, err := splitWords(text)
+ if err != nil {
+ sendServicePRIVMSG(dc, fmt.Sprintf(`error: failed to parse command: %v`, err))
+ return
+ }
+
+ cmd, params, err := serviceCommands.Get(words)
+ if err != nil {
+ sendServicePRIVMSG(dc, fmt.Sprintf(`error: %v (type "help" for a list of commands)`, err))
+ return
+ }
+ if cmd.admin && !dc.user.Admin {
+ sendServicePRIVMSG(dc, "error: you must be an admin to use this command")
+ return
+ }
+
+ if cmd.handle == nil {
+ if len(cmd.children) > 0 {
+ var l []string
+ appendServiceCommandSetHelp(cmd.children, words, dc.user.Admin, &l)
+ sendServicePRIVMSG(dc, "available commands: "+strings.Join(l, ", "))
+ } else {
+ // Pretend the command does not exist if it has neither children nor handler.
+ // This is obviously a bug but it is better to not die anyway.
+ dc.logger.Printf("command without handler and subcommands invoked:", words[0])
+ sendServicePRIVMSG(dc, fmt.Sprintf("command %q not found", words[0]))
+ }
+ return
+ }
+
+ if err := cmd.handle(ctx, dc, params); err != nil {
+ sendServicePRIVMSG(dc, fmt.Sprintf("error: %v", err))
+ }
+}
+
+func (cmds serviceCommandSet) Get(params []string) (*serviceCommand, []string, error) {
+ if len(params) == 0 {
+ return nil, nil, fmt.Errorf("no command specified")
+ }
+
+ name := params[0]
+ params = params[1:]
+
+ cmd, ok := cmds[name]
+ if !ok {
+ for k := range cmds {
+ if !strings.HasPrefix(k, name) {
+ continue
+ }
+ if cmd != nil {
+ return nil, params, fmt.Errorf("command %q is ambiguous", name)
+ }
+ cmd = cmds[k]
+ }
+ }
+ if cmd == nil {
+ return nil, params, fmt.Errorf("command %q not found", name)
+ }
+
+ if len(params) == 0 || len(cmd.children) == 0 {
+ return cmd, params, nil
+ }
+ return cmd.children.Get(params)
+}
+
+func (cmds serviceCommandSet) Names() []string {
+ l := make([]string, 0, len(cmds))
+ for name := range cmds {
+ l = append(l, name)
+ }
+ sort.Strings(l)
+ return l
+}
+
+var serviceCommands serviceCommandSet
+
+func init() {
+ serviceCommands = serviceCommandSet{
+ "help": {
+ usage: "[command]",
+ desc: "print help message",
+ handle: handleServiceHelp,
+ },
+ "network": {
+ children: serviceCommandSet{
+ "create": {
+ usage: "-addr <addr> [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick] [-enabled enabled] [-connect-command command]...",
+ desc: "add a new network",
+ handle: handleServiceNetworkCreate,
+ },
+ "status": {
+ desc: "show a list of saved networks and their current status",
+ handle: handleServiceNetworkStatus,
+ },
+ "update": {
+ usage: "[name] [-addr addr] [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick] [-enabled enabled] [-connect-command command]...",
+ desc: "update a network",
+ handle: handleServiceNetworkUpdate,
+ },
+ "delete": {
+ usage: "[name]",
+ desc: "delete a network",
+ handle: handleServiceNetworkDelete,
+ },
+ "quote": {
+ usage: "[name] <command>",
+ desc: "send a raw line to a network",
+ handle: handleServiceNetworkQuote,
+ },
+ },
+ },
+ "certfp": {
+ children: serviceCommandSet{
+ "generate": {
+ usage: "[-key-type rsa|ecdsa|ed25519] [-bits N] [-network name]",
+ desc: "generate a new self-signed certificate, defaults to using RSA-3072 key",
+ handle: handleServiceCertFPGenerate,
+ },
+ "fingerprint": {
+ usage: "[-network name]",
+ desc: "show fingerprints of certificate",
+ handle: handleServiceCertFPFingerprints,
+ },
+ },
+ },
+ "sasl": {
+ children: serviceCommandSet{
+ "status": {
+ usage: "[-network name]",
+ desc: "show SASL status",
+ handle: handleServiceSASLStatus,
+ },
+ "set-plain": {
+ usage: "[-network name] <username> <password>",
+ desc: "set SASL PLAIN credentials",
+ handle: handleServiceSASLSetPlain,
+ },
+ "reset": {
+ usage: "[-network name]",
+ desc: "disable SASL authentication and remove stored credentials",
+ handle: handleServiceSASLReset,
+ },
+ },
+ },
+ "user": {
+ children: serviceCommandSet{
+ "create": {
+ usage: "-username <username> -password <password> [-realname <realname>] [-admin]",
+ desc: "create a new suika user",
+ handle: handleUserCreate,
+ admin: true,
+ },
+ "update": {
+ usage: "[-password <password>] [-realname <realname>]",
+ desc: "update the current user",
+ handle: handleUserUpdate,
+ },
+ "delete": {
+ usage: "<username>",
+ desc: "delete a user",
+ handle: handleUserDelete,
+ admin: true,
+ },
+ },
+ },
+ "channel": {
+ children: serviceCommandSet{
+ "status": {
+ usage: "[-network name]",
+ desc: "show a list of saved channels and their current status",
+ handle: handleServiceChannelStatus,
+ },
+ "update": {
+ usage: "<name> [-relay-detached <default|none|highlight|message>] [-reattach-on <default|none|highlight|message>] [-detach-after <duration>] [-detach-on <default|none|highlight|message>]",
+ desc: "update a channel",
+ handle: handleServiceChannelUpdate,
+ },
+ },
+ },
+ "server": {
+ children: serviceCommandSet{
+ "status": {
+ desc: "show server statistics",
+ handle: handleServiceServerStatus,
+ admin: true,
+ },
+ "notice": {
+ desc: "broadcast a notice to all connected bouncer users",
+ handle: handleServiceServerNotice,
+ admin: true,
+ },
+ },
+ admin: true,
+ },
+ }
+}
+
+func appendServiceCommandSetHelp(cmds serviceCommandSet, prefix []string, admin bool, l *[]string) {
+ for _, name := range cmds.Names() {
+ cmd := cmds[name]
+ if cmd.admin && !admin {
+ continue
+ }
+ words := append(prefix, name)
+ if len(cmd.children) == 0 {
+ s := strings.Join(words, " ")
+ *l = append(*l, s)
+ } else {
+ appendServiceCommandSetHelp(cmd.children, words, admin, l)
+ }
+ }
+}
+
+func handleServiceHelp(ctx context.Context, dc *downstreamConn, params []string) error {
+ if len(params) > 0 {
+ cmd, rest, err := serviceCommands.Get(params)
+ if err != nil {
+ return err
+ }
+ words := params[:len(params)-len(rest)]
+
+ if len(cmd.children) > 0 {
+ var l []string
+ appendServiceCommandSetHelp(cmd.children, words, dc.user.Admin, &l)
+ sendServicePRIVMSG(dc, "available commands: "+strings.Join(l, ", "))
+ } else {
+ text := strings.Join(words, " ")
+ if cmd.usage != "" {
+ text += " " + cmd.usage
+ }
+ text += ": " + cmd.desc
+
+ sendServicePRIVMSG(dc, text)
+ }
+ } else {
+ var l []string
+ appendServiceCommandSetHelp(serviceCommands, nil, dc.user.Admin, &l)
+ sendServicePRIVMSG(dc, "available commands: "+strings.Join(l, ", "))
+ }
+ return nil
+}
+
+func newFlagSet() *flag.FlagSet {
+ fs := flag.NewFlagSet("", flag.ContinueOnError)
+ fs.SetOutput(ioutil.Discard)
+ return fs
+}
+
+type stringSliceFlag []string
+
+func (v *stringSliceFlag) String() string {
+ return fmt.Sprint([]string(*v))
+}
+
+func (v *stringSliceFlag) Set(s string) error {
+ *v = append(*v, s)
+ return nil
+}
+
+// stringPtrFlag is a flag value populating a string pointer. This allows to
+// disambiguate between a flag that hasn't been set and a flag that has been
+// set to an empty string.
+type stringPtrFlag struct {
+ ptr **string
+}
+
+func (f stringPtrFlag) String() string {
+ if f.ptr == nil || *f.ptr == nil {
+ return ""
+ }
+ return **f.ptr
+}
+
+func (f stringPtrFlag) Set(s string) error {
+ *f.ptr = &s
+ return nil
+}
+
+type boolPtrFlag struct {
+ ptr **bool
+}
+
+func (f boolPtrFlag) String() string {
+ if f.ptr == nil || *f.ptr == nil {
+ return "<nil>"
+ }
+ return strconv.FormatBool(**f.ptr)
+}
+
+func (f boolPtrFlag) Set(s string) error {
+ v, err := strconv.ParseBool(s)
+ if err != nil {
+ return err
+ }
+ *f.ptr = &v
+ return nil
+}
+
+func getNetworkFromArg(dc *downstreamConn, params []string) (*network, []string, error) {
+ name, params := popArg(params)
+ if name == "" {
+ if dc.network == nil {
+ return nil, params, fmt.Errorf("no network selected, a name argument is required")
+ }
+ return dc.network, params, nil
+ } else {
+ net := dc.user.getNetwork(name)
+ if net == nil {
+ return nil, params, fmt.Errorf("unknown network %q", name)
+ }
+ return net, params, nil
+ }
+}
+
+type networkFlagSet struct {
+ *flag.FlagSet
+ Addr, Name, Nick, Username, Pass, Realname *string
+ Enabled *bool
+ ConnectCommands []string
+}
+
+func newNetworkFlagSet() *networkFlagSet {
+ fs := &networkFlagSet{FlagSet: newFlagSet()}
+ fs.Var(stringPtrFlag{&fs.Addr}, "addr", "")
+ fs.Var(stringPtrFlag{&fs.Name}, "name", "")
+ fs.Var(stringPtrFlag{&fs.Nick}, "nick", "")
+ fs.Var(stringPtrFlag{&fs.Username}, "username", "")
+ fs.Var(stringPtrFlag{&fs.Pass}, "pass", "")
+ fs.Var(stringPtrFlag{&fs.Realname}, "realname", "")
+ fs.Var(boolPtrFlag{&fs.Enabled}, "enabled", "")
+ fs.Var((*stringSliceFlag)(&fs.ConnectCommands), "connect-command", "")
+ return fs
+}
+
+func (fs *networkFlagSet) update(network *Network) error {
+ if fs.Addr != nil {
+ if addrParts := strings.SplitN(*fs.Addr, "://", 2); len(addrParts) == 2 {
+ scheme := addrParts[0]
+ switch scheme {
+ case "ircs", "irc", "unix":
+ default:
+ return fmt.Errorf("unknown scheme %q (supported schemes: ircs, irc, unix)", scheme)
+ }
+ }
+ network.Addr = *fs.Addr
+ }
+ if fs.Name != nil {
+ network.Name = *fs.Name
+ }
+ if fs.Nick != nil {
+ network.Nick = *fs.Nick
+ }
+ if fs.Username != nil {
+ network.Username = *fs.Username
+ }
+ if fs.Pass != nil {
+ network.Pass = *fs.Pass
+ }
+ if fs.Realname != nil {
+ network.Realname = *fs.Realname
+ }
+ if fs.Enabled != nil {
+ network.Enabled = *fs.Enabled
+ }
+ if fs.ConnectCommands != nil {
+ if len(fs.ConnectCommands) == 1 && fs.ConnectCommands[0] == "" {
+ network.ConnectCommands = nil
+ } else {
+ for _, command := range fs.ConnectCommands {
+ _, err := irc.ParseMessage(command)
+ if err != nil {
+ return fmt.Errorf("flag -connect-command must be a valid raw irc command string: %q: %v", command, err)
+ }
+ }
+ network.ConnectCommands = fs.ConnectCommands
+ }
+ }
+ return nil
+}
+
+func handleServiceNetworkCreate(ctx context.Context, dc *downstreamConn, params []string) error {
+ fs := newNetworkFlagSet()
+ if err := fs.Parse(params); err != nil {
+ return err
+ }
+ if fs.Addr == nil {
+ return fmt.Errorf("flag -addr is required")
+ }
+
+ record := &Network{
+ Addr: *fs.Addr,
+ Enabled: true,
+ }
+ if err := fs.update(record); err != nil {
+ return err
+ }
+
+ network, err := dc.user.createNetwork(ctx, record)
+ if err != nil {
+ return fmt.Errorf("could not create network: %v", err)
+ }
+
+ sendServicePRIVMSG(dc, fmt.Sprintf("created network %q", network.GetName()))
+ return nil
+}
+
+func handleServiceNetworkStatus(ctx context.Context, dc *downstreamConn, params []string) error {
+ n := 0
+ for _, net := range dc.user.networks {
+ var statuses []string
+ var details string
+ if uc := net.conn; uc != nil {
+ if dc.nick != uc.nick {
+ statuses = append(statuses, "connected as "+uc.nick)
+ } else {
+ statuses = append(statuses, "connected")
+ }
+ details = fmt.Sprintf("%v channels", uc.channels.Len())
+ } else if !net.Enabled {
+ statuses = append(statuses, "disabled")
+ } else {
+ statuses = append(statuses, "disconnected")
+ if net.lastError != nil {
+ details = net.lastError.Error()
+ }
+ }
+
+ if net == dc.network {
+ statuses = append(statuses, "current")
+ }
+
+ name := net.GetName()
+ if name != net.Addr {
+ name = fmt.Sprintf("%v (%v)", name, net.Addr)
+ }
+
+ s := fmt.Sprintf("%v [%v]", name, strings.Join(statuses, ", "))
+ if details != "" {
+ s += ": " + details
+ }
+ sendServicePRIVMSG(dc, s)
+
+ n++
+ }
+
+ if n == 0 {
+ sendServicePRIVMSG(dc, `No network configured, add one with "network create".`)
+ }
+
+ return nil
+}
+
+func handleServiceNetworkUpdate(ctx context.Context, dc *downstreamConn, params []string) error {
+ net, params, err := getNetworkFromArg(dc, params)
+ if err != nil {
+ return err
+ }
+
+ fs := newNetworkFlagSet()
+ if err := fs.Parse(params); err != nil {
+ return err
+ }
+
+ record := net.Network // copy network record because we'll mutate it
+ if err := fs.update(&record); err != nil {
+ return err
+ }
+
+ network, err := dc.user.updateNetwork(ctx, &record)
+ if err != nil {
+ return fmt.Errorf("could not update network: %v", err)
+ }
+
+ sendServicePRIVMSG(dc, fmt.Sprintf("updated network %q", network.GetName()))
+ return nil
+}
+
+func handleServiceNetworkDelete(ctx context.Context, dc *downstreamConn, params []string) error {
+ net, params, err := getNetworkFromArg(dc, params)
+ if err != nil {
+ return err
+ }
+
+ if err := dc.user.deleteNetwork(ctx, net.ID); err != nil {
+ return err
+ }
+
+ sendServicePRIVMSG(dc, fmt.Sprintf("deleted network %q", net.GetName()))
+ return nil
+}
+
+func handleServiceNetworkQuote(ctx context.Context, dc *downstreamConn, params []string) error {
+ if len(params) != 1 && len(params) != 2 {
+ return fmt.Errorf("expected one or two arguments")
+ }
+
+ raw := params[len(params)-1]
+ params = params[:len(params)-1]
+
+ net, params, err := getNetworkFromArg(dc, params)
+ if err != nil {
+ return err
+ }
+
+ uc := net.conn
+ if uc == nil {
+ return fmt.Errorf("network %q is not currently connected", net.GetName())
+ }
+
+ m, err := irc.ParseMessage(raw)
+ if err != nil {
+ return fmt.Errorf("failed to parse command %q: %v", raw, err)
+ }
+ uc.SendMessage(ctx, m)
+
+ sendServicePRIVMSG(dc, fmt.Sprintf("sent command to %q", net.GetName()))
+ return nil
+}
+
+func sendCertfpFingerprints(dc *downstreamConn, cert []byte) {
+ sha1Sum := sha1.Sum(cert)
+ sendServicePRIVMSG(dc, "SHA-1 fingerprint: "+hex.EncodeToString(sha1Sum[:]))
+ sha256Sum := sha256.Sum256(cert)
+ sendServicePRIVMSG(dc, "SHA-256 fingerprint: "+hex.EncodeToString(sha256Sum[:]))
+ sha512Sum := sha512.Sum512(cert)
+ sendServicePRIVMSG(dc, "SHA-512 fingerprint: "+hex.EncodeToString(sha512Sum[:]))
+}
+
+func getNetworkFromFlag(dc *downstreamConn, name string) (*network, error) {
+ if name == "" {
+ if dc.network == nil {
+ return nil, fmt.Errorf("no network selected, -network is required")
+ }
+ return dc.network, nil
+ } else {
+ net := dc.user.getNetwork(name)
+ if net == nil {
+ return nil, fmt.Errorf("unknown network %q", name)
+ }
+ return net, nil
+ }
+}
+
+func handleServiceCertFPGenerate(ctx context.Context, dc *downstreamConn, params []string) error {
+ fs := newFlagSet()
+ netName := fs.String("network", "", "select a network")
+ keyType := fs.String("key-type", "rsa", "key type to generate (rsa, ecdsa, ed25519)")
+ bits := fs.Int("bits", 3072, "size of key to generate, meaningful only for RSA")
+
+ if err := fs.Parse(params); err != nil {
+ return err
+ }
+
+ if *bits <= 0 || *bits > maxRSABits {
+ return fmt.Errorf("invalid value for -bits")
+ }
+
+ net, err := getNetworkFromFlag(dc, *netName)
+ if err != nil {
+ return err
+ }
+
+ privKey, cert, err := generateCertFP(*keyType, *bits)
+ if err != nil {
+ return err
+ }
+
+ net.SASL.External.CertBlob = cert
+ net.SASL.External.PrivKeyBlob = privKey
+ net.SASL.Mechanism = "EXTERNAL"
+
+ if err := dc.srv.db.StoreNetwork(ctx, dc.user.ID, &net.Network); err != nil {
+ return err
+ }
+
+ sendServicePRIVMSG(dc, "certificate generated")
+ sendCertfpFingerprints(dc, cert)
+ return nil
+}
+
+func handleServiceCertFPFingerprints(ctx context.Context, dc *downstreamConn, params []string) error {
+ fs := newFlagSet()
+ netName := fs.String("network", "", "select a network")
+
+ if err := fs.Parse(params); err != nil {
+ return err
+ }
+
+ net, err := getNetworkFromFlag(dc, *netName)
+ if err != nil {
+ return err
+ }
+
+ if net.SASL.Mechanism != "EXTERNAL" {
+ return fmt.Errorf("CertFP not set up")
+ }
+
+ sendCertfpFingerprints(dc, net.SASL.External.CertBlob)
+ return nil
+}
+
+func handleServiceSASLStatus(ctx context.Context, dc *downstreamConn, params []string) error {
+ fs := newFlagSet()
+ netName := fs.String("network", "", "select a network")
+
+ if err := fs.Parse(params); err != nil {
+ return err
+ }
+
+ net, err := getNetworkFromFlag(dc, *netName)
+ if err != nil {
+ return err
+ }
+
+ switch net.SASL.Mechanism {
+ case "PLAIN":
+ sendServicePRIVMSG(dc, fmt.Sprintf("SASL PLAIN enabled with username %q", net.SASL.Plain.Username))
+ case "EXTERNAL":
+ sendServicePRIVMSG(dc, "SASL EXTERNAL (CertFP) enabled")
+ case "":
+ sendServicePRIVMSG(dc, "SASL is disabled")
+ }
+
+ if uc := net.conn; uc != nil {
+ if uc.account != "" {
+ sendServicePRIVMSG(dc, fmt.Sprintf("Authenticated on upstream network with account %q", uc.account))
+ } else {
+ sendServicePRIVMSG(dc, "Unauthenticated on upstream network")
+ }
+ } else {
+ sendServicePRIVMSG(dc, "Disconnected from upstream network")
+ }
+
+ return nil
+}
+
+func handleServiceSASLSetPlain(ctx context.Context, dc *downstreamConn, params []string) error {
+ fs := newFlagSet()
+ netName := fs.String("network", "", "select a network")
+
+ if err := fs.Parse(params); err != nil {
+ return err
+ }
+
+ if len(fs.Args()) != 2 {
+ return fmt.Errorf("expected exactly 2 arguments")
+ }
+
+ net, err := getNetworkFromFlag(dc, *netName)
+ if err != nil {
+ return err
+ }
+
+ net.SASL.Plain.Username = fs.Arg(0)
+ net.SASL.Plain.Password = fs.Arg(1)
+ net.SASL.Mechanism = "PLAIN"
+
+ if err := dc.srv.db.StoreNetwork(ctx, dc.user.ID, &net.Network); err != nil {
+ return err
+ }
+
+ sendServicePRIVMSG(dc, "credentials saved")
+ return nil
+}
+
+func handleServiceSASLReset(ctx context.Context, dc *downstreamConn, params []string) error {
+ fs := newFlagSet()
+ netName := fs.String("network", "", "select a network")
+
+ if err := fs.Parse(params); err != nil {
+ return err
+ }
+
+ net, err := getNetworkFromFlag(dc, *netName)
+ if err != nil {
+ return err
+ }
+
+ net.SASL.Plain.Username = ""
+ net.SASL.Plain.Password = ""
+ net.SASL.External.CertBlob = nil
+ net.SASL.External.PrivKeyBlob = nil
+ net.SASL.Mechanism = ""
+
+ if err := dc.srv.db.StoreNetwork(ctx, dc.user.ID, &net.Network); err != nil {
+ return err
+ }
+
+ sendServicePRIVMSG(dc, "credentials reset")
+ return nil
+}
+
+func handleUserCreate(ctx context.Context, dc *downstreamConn, params []string) error {
+ fs := newFlagSet()
+ username := fs.String("username", "", "")
+ password := fs.String("password", "", "")
+ realname := fs.String("realname", "", "")
+ admin := fs.Bool("admin", false, "")
+
+ if err := fs.Parse(params); err != nil {
+ return err
+ }
+ if *username == "" {
+ return fmt.Errorf("flag -username is required")
+ }
+ if *password == "" {
+ return fmt.Errorf("flag -password is required")
+ }
+
+ hashed, err := bcrypt.GenerateFromPassword([]byte(*password), bcrypt.DefaultCost)
+ if err != nil {
+ return fmt.Errorf("failed to hash password: %v", err)
+ }
+
+ user := &User{
+ Username: *username,
+ Password: string(hashed),
+ Realname: *realname,
+ Admin: *admin,
+ }
+ if _, err := dc.srv.createUser(ctx, user); err != nil {
+ return fmt.Errorf("could not create user: %v", err)
+ }
+
+ sendServicePRIVMSG(dc, fmt.Sprintf("created user %q", *username))
+ return nil
+}
+
+func popArg(params []string) (string, []string) {
+ if len(params) > 0 && !strings.HasPrefix(params[0], "-") {
+ return params[0], params[1:]
+ }
+ return "", params
+}
+
+func handleUserUpdate(ctx context.Context, dc *downstreamConn, params []string) error {
+ var password, realname *string
+ var admin *bool
+ fs := newFlagSet()
+ fs.Var(stringPtrFlag{&password}, "password", "")
+ fs.Var(stringPtrFlag{&realname}, "realname", "")
+ fs.Var(boolPtrFlag{&admin}, "admin", "")
+
+ username, params := popArg(params)
+ if err := fs.Parse(params); err != nil {
+ return err
+ }
+ if len(fs.Args()) > 0 {
+ return fmt.Errorf("unexpected argument")
+ }
+
+ var hashed *string
+ if password != nil {
+ hashedBytes, err := bcrypt.GenerateFromPassword([]byte(*password), bcrypt.DefaultCost)
+ if err != nil {
+ return fmt.Errorf("failed to hash password: %v", err)
+ }
+ hashedStr := string(hashedBytes)
+ hashed = &hashedStr
+ }
+
+ if username != "" && username != dc.user.Username {
+ if !dc.user.Admin {
+ return fmt.Errorf("you must be an admin to update other users")
+ }
+ if realname != nil {
+ return fmt.Errorf("cannot update -realname of other user")
+ }
+
+ u := dc.srv.getUser(username)
+ if u == nil {
+ return fmt.Errorf("unknown username %q", username)
+ }
+
+ done := make(chan error, 1)
+ event := eventUserUpdate{
+ password: hashed,
+ admin: admin,
+ done: done,
+ }
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ case u.events <- event:
+ }
+ // TODO: send context to the other side
+ if err := <-done; err != nil {
+ return err
+ }
+
+ sendServicePRIVMSG(dc, fmt.Sprintf("updated user %q", username))
+ } else {
+ // copy the user record because we'll mutate it
+ record := dc.user.User
+
+ if hashed != nil {
+ record.Password = *hashed
+ }
+ if realname != nil {
+ record.Realname = *realname
+ }
+ if admin != nil {
+ return fmt.Errorf("cannot update -admin of own user")
+ }
+
+ if err := dc.user.updateUser(ctx, &record); err != nil {
+ return err
+ }
+
+ sendServicePRIVMSG(dc, fmt.Sprintf("updated user %q", dc.user.Username))
+ }
+
+ return nil
+}
+
+func handleUserDelete(ctx context.Context, dc *downstreamConn, params []string) error {
+ if len(params) != 1 {
+ return fmt.Errorf("expected exactly one argument")
+ }
+ username := params[0]
+
+ u := dc.srv.getUser(username)
+ if u == nil {
+ return fmt.Errorf("unknown username %q", username)
+ }
+
+ u.stop()
+
+ if err := dc.srv.db.DeleteUser(ctx, u.ID); err != nil {
+ return fmt.Errorf("failed to delete user: %v", err)
+ }
+
+ sendServicePRIVMSG(dc, fmt.Sprintf("deleted user %q", username))
+ return nil
+}
+
+func handleServiceChannelStatus(ctx context.Context, dc *downstreamConn, params []string) error {
+ var defaultNetworkName string
+ if dc.network != nil {
+ defaultNetworkName = dc.network.GetName()
+ }
+
+ fs := newFlagSet()
+ networkName := fs.String("network", defaultNetworkName, "")
+
+ if err := fs.Parse(params); err != nil {
+ return err
+ }
+
+ n := 0
+
+ sendNetwork := func(net *network) {
+ var channels []*Channel
+ for _, entry := range net.channels.innerMap {
+ channels = append(channels, entry.value.(*Channel))
+ }
+
+ sort.Slice(channels, func(i, j int) bool {
+ return strings.ReplaceAll(channels[i].Name, "#", "") <
+ strings.ReplaceAll(channels[j].Name, "#", "")
+ })
+
+ for _, ch := range channels {
+ var uch *upstreamChannel
+ if net.conn != nil {
+ uch = net.conn.channels.Value(ch.Name)
+ }
+
+ name := ch.Name
+ if *networkName == "" {
+ name += "/" + net.GetName()
+ }
+
+ var status string
+ if uch != nil {
+ status = "joined"
+ } else if net.conn != nil {
+ status = "parted"
+ } else {
+ status = "disconnected"
+ }
+
+ if ch.Detached {
+ status += ", detached"
+ }
+
+ s := fmt.Sprintf("%v [%v]", name, status)
+ sendServicePRIVMSG(dc, s)
+
+ n++
+ }
+ }
+
+ if *networkName == "" {
+ for _, net := range dc.user.networks {
+ sendNetwork(net)
+ }
+ } else {
+ net := dc.user.getNetwork(*networkName)
+ if net == nil {
+ return fmt.Errorf("unknown network %q", *networkName)
+ }
+ sendNetwork(net)
+ }
+
+ if n == 0 {
+ sendServicePRIVMSG(dc, "No channel configured.")
+ }
+
+ return nil
+}
+
+type channelFlagSet struct {
+ *flag.FlagSet
+ RelayDetached, ReattachOn, DetachAfter, DetachOn *string
+}
+
+func newChannelFlagSet() *channelFlagSet {
+ fs := &channelFlagSet{FlagSet: newFlagSet()}
+ fs.Var(stringPtrFlag{&fs.RelayDetached}, "relay-detached", "")
+ fs.Var(stringPtrFlag{&fs.ReattachOn}, "reattach-on", "")
+ fs.Var(stringPtrFlag{&fs.DetachAfter}, "detach-after", "")
+ fs.Var(stringPtrFlag{&fs.DetachOn}, "detach-on", "")
+ return fs
+}
+
+func (fs *channelFlagSet) update(channel *Channel) error {
+ if fs.RelayDetached != nil {
+ filter, err := parseFilter(*fs.RelayDetached)
+ if err != nil {
+ return err
+ }
+ channel.RelayDetached = filter
+ }
+ if fs.ReattachOn != nil {
+ filter, err := parseFilter(*fs.ReattachOn)
+ if err != nil {
+ return err
+ }
+ channel.ReattachOn = filter
+ }
+ if fs.DetachAfter != nil {
+ dur, err := time.ParseDuration(*fs.DetachAfter)
+ if err != nil || dur < 0 {
+ return fmt.Errorf("unknown duration for -detach-after %q (duration format: 0, 300s, 22h30m, ...)", *fs.DetachAfter)
+ }
+ channel.DetachAfter = dur
+ }
+ if fs.DetachOn != nil {
+ filter, err := parseFilter(*fs.DetachOn)
+ if err != nil {
+ return err
+ }
+ channel.DetachOn = filter
+ }
+ return nil
+}
+
+func handleServiceChannelUpdate(ctx context.Context, dc *downstreamConn, params []string) error {
+ if len(params) < 1 {
+ return fmt.Errorf("expected at least one argument")
+ }
+ name := params[0]
+
+ fs := newChannelFlagSet()
+ if err := fs.Parse(params[1:]); err != nil {
+ return err
+ }
+
+ uc, upstreamName, err := dc.unmarshalEntity(name)
+ if err != nil {
+ return fmt.Errorf("unknown channel %q", name)
+ }
+
+ ch := uc.network.channels.Value(upstreamName)
+ if ch == nil {
+ return fmt.Errorf("unknown channel %q", name)
+ }
+
+ if err := fs.update(ch); err != nil {
+ return err
+ }
+
+ uc.updateChannelAutoDetach(upstreamName)
+
+ if err := dc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil {
+ return fmt.Errorf("failed to update channel: %v", err)
+ }
+
+ sendServicePRIVMSG(dc, fmt.Sprintf("updated channel %q", name))
+ return nil
+}
+func handleServiceServerStatus(ctx context.Context, dc *downstreamConn, params []string) error {
+ dbStats, err := dc.user.srv.db.Stats(ctx)
+ if err != nil {
+ return err
+ }
+ serverStats := dc.user.srv.Stats()
+ sendServicePRIVMSG(dc, fmt.Sprintf("%v/%v users, %v downstreams, %v upstreams, %v networks, %v channels", serverStats.Users, dbStats.Users, serverStats.Downstreams, serverStats.Upstreams, dbStats.Networks, dbStats.Channels))
+ return nil
+}
+
+func handleServiceServerNotice(ctx context.Context, dc *downstreamConn, params []string) error {
+ if len(params) != 1 {
+ return fmt.Errorf("expected exactly one argument")
+ }
+ text := params[0]
+
+ dc.logger.Printf("broadcasting bouncer-wide NOTICE: %v", text)
+
+ broadcastMsg := &irc.Message{
+ Prefix: servicePrefix,
+ Command: "NOTICE",
+ Params: []string{"$" + dc.srv.Config().Hostname, text},
+ }
+ var err error
+ sent := 0
+ total := 0
+ dc.srv.forEachUser(func(u *user) {
+ total++
+ select {
+ case <-ctx.Done():
+ err = ctx.Err()
+ case u.events <- eventBroadcast{broadcastMsg}:
+ sent++
+ }
+ })
+
+ dc.logger.Printf("broadcast bouncer-wide NOTICE to %v/%v downstreams", sent, total)
+ sendServicePRIVMSG(dc, fmt.Sprintf("sent to %v/%v downstream connections", sent, total))
+
+ return err
+}
--- /dev/null
+package suika
+
+import (
+ "testing"
+)
+
+func assertSplit(t *testing.T, input string, expected []string) {
+ actual, err := splitWords(input)
+ if err != nil {
+ t.Errorf("%q: %v", input, err)
+ return
+ }
+ if len(actual) != len(expected) {
+ t.Errorf("%q: expected %d words, got %d\nexpected: %v\ngot: %v", input, len(expected), len(actual), expected, actual)
+ return
+ }
+ for i := 0; i < len(actual); i++ {
+ if actual[i] != expected[i] {
+ t.Errorf("%q: expected word #%d to be %q, got %q\nexpected: %v\ngot: %v", input, i, expected[i], actual[i], expected, actual)
+ }
+ }
+}
+
+func TestSplit(t *testing.T) {
+ assertSplit(t, " ch 'up' #suika 'relay'-det\"ache\"d message ", []string{
+ "ch",
+ "up",
+ "#suika",
+ "relay-detached",
+ "message",
+ })
+ assertSplit(t, "net update \\\"free\\\"node -pass 'political \"stance\" desu!' -realname '' -nick lee", []string{
+ "net",
+ "update",
+ "\"free\"node",
+ "-pass",
+ "political \"stance\" desu!",
+ "-realname",
+ "",
+ "-nick",
+ "lee",
+ })
+ assertSplit(t, "Omedeto,\\ Yui! ''", []string{
+ "Omedeto, Yui!",
+ "",
+ })
+
+ if _, err := splitWords("end of 'file"); err == nil {
+ t.Errorf("expected error on unterminated single quote")
+ }
+ if _, err := splitWords("end of backquote \\"); err == nil {
+ t.Errorf("expected error on unterminated backquote sequence")
+ }
+}
--- /dev/null
+CREATE TABLE IF NOT EXISTS "User" (
+ id SERIAL PRIMARY KEY,
+ username VARCHAR(255) NOT NULL UNIQUE,
+ password VARCHAR(255),
+ admin BOOLEAN NOT NULL DEFAULT FALSE,
+ realname VARCHAR(255)
+);
+
+CREATE TYPE sasl_mechanism AS ENUM ('PLAIN', 'EXTERNAL');
+
+CREATE TABLE IF NOT EXISTS "Network" (
+ id SERIAL PRIMARY KEY,
+ name VARCHAR(255),
+ "user" INTEGER NOT NULL REFERENCES "User"(id) ON DELETE CASCADE,
+ addr VARCHAR(255) NOT NULL,
+ nick VARCHAR(255),
+ username VARCHAR(255),
+ realname VARCHAR(255),
+ pass VARCHAR(255),
+ connect_commands VARCHAR(1023),
+ sasl_mechanism sasl_mechanism,
+ sasl_plain_username VARCHAR(255),
+ sasl_plain_password VARCHAR(255),
+ sasl_external_cert BYTEA,
+ sasl_external_key BYTEA,
+ enabled BOOLEAN NOT NULL DEFAULT TRUE,
+ UNIQUE("user", addr, nick),
+ UNIQUE("user", name)
+);
+CREATE TABLE IF NOT EXISTS "Channel" (
+ id SERIAL PRIMARY KEY,
+ network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
+ name VARCHAR(255) NOT NULL,
+ key VARCHAR(255),
+ detached BOOLEAN NOT NULL DEFAULT FALSE,
+ detached_internal_msgid VARCHAR(255),
+ relay_detached INTEGER NOT NULL DEFAULT 0,
+ reattach_on INTEGER NOT NULL DEFAULT 0,
+ detach_after INTEGER NOT NULL DEFAULT 0,
+ detach_on INTEGER NOT NULL DEFAULT 0,
+ UNIQUE(network, name)
+);
+CREATE TABLE IF NOT EXISTS "DeliveryReceipt" (
+ id SERIAL PRIMARY KEY,
+ network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
+ target VARCHAR(255) NOT NULL,
+ client VARCHAR(255) NOT NULL DEFAULT '',
+ internal_msgid VARCHAR(255) NOT NULL,
+ UNIQUE(network, target, client)
+);
+CREATE TABLE IF NOT EXISTS "ReadReceipt" (
+ id SERIAL PRIMARY KEY,
+ network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
+ target VARCHAR(255) NOT NULL,
+ timestamp TIMESTAMP WITH TIME ZONE NOT NULL,
+ UNIQUE(network, target)
+);
+
--- /dev/null
+CREATE TABLE IF NOT EXISTS User (
+ id INTEGER PRIMARY KEY,
+ username TEXT NOT NULL UNIQUE,
+ password TEXT,
+ admin INTEGER NOT NULL DEFAULT 0,
+ realname TEXT
+);
+CREATE TABLE IF NOT EXISTS Network (
+ id INTEGER PRIMARY KEY,
+ name TEXT,
+ user INTEGER NOT NULL,
+ addr TEXT NOT NULL,
+ nick TEXT,
+ username TEXT,
+ realname TEXT,
+ pass TEXT,
+ connect_commands TEXT,
+ sasl_mechanism TEXT,
+ sasl_plain_username TEXT,
+ sasl_plain_password TEXT,
+ sasl_external_cert BLOB,
+ sasl_external_key BLOB,
+ enabled INTEGER NOT NULL DEFAULT 1,
+ FOREIGN KEY(user) REFERENCES User(id),
+ UNIQUE(user, addr, nick),
+ UNIQUE(user, name)
+);
+CREATE TABLE IF NOT EXISTS Channel (
+ id INTEGER PRIMARY KEY,
+ network INTEGER NOT NULL,
+ name TEXT NOT NULL,
+ key TEXT,
+ detached INTEGER NOT NULL DEFAULT 0,
+ detached_internal_msgid TEXT,
+ relay_detached INTEGER NOT NULL DEFAULT 0,
+ reattach_on INTEGER NOT NULL DEFAULT 0,
+ detach_after INTEGER NOT NULL DEFAULT 0,
+ detach_on INTEGER NOT NULL DEFAULT 0,
+ FOREIGN KEY(network) REFERENCES Network(id),
+ UNIQUE(network, name)
+);
+
+CREATE TABLE IF NOT EXISTS DeliveryReceipt (
+ id INTEGER PRIMARY KEY,
+ network INTEGER NOT NULL,
+ target TEXT NOT NULL,
+ client TEXT,
+ internal_msgid TEXT NOT NULL,
+ FOREIGN KEY(network) REFERENCES Network(id),
+ UNIQUE(network, target, client)
+);
+
+CREATE TABLE IF NOT EXISTS ReadReceipt (
+ id INTEGER PRIMARY KEY,
+ network INTEGER NOT NULL,
+ target TEXT NOT NULL,
+ timestamp TEXT NOT NULL,
+ FOREIGN KEY(network) REFERENCES Network(id),
+ UNIQUE(network, target)
+);
+
--- /dev/null
+package suika
+
+import (
+ "context"
+ "crypto"
+ "crypto/sha256"
+ "crypto/tls"
+ "crypto/x509"
+ "encoding/base64"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/emersion/go-sasl"
+ "gopkg.in/irc.v3"
+)
+
+// permanentUpstreamCaps is the static list of upstream capabilities always
+// requested when supported.
+var permanentUpstreamCaps = map[string]bool{
+ "account-notify": true,
+ "account-tag": true,
+ "away-notify": true,
+ "batch": true,
+ "extended-join": true,
+ "invite-notify": true,
+ "labeled-response": true,
+ "message-tags": true,
+ "multi-prefix": true,
+ "sasl": true,
+ "server-time": true,
+ "setname": true,
+
+ "draft/account-registration": true,
+ "draft/extended-monitor": true,
+}
+
+type registrationError struct {
+ *irc.Message
+}
+
+func (err registrationError) Error() string {
+ return fmt.Sprintf("registration error (%v): %v", err.Command, err.Reason())
+}
+
+func (err registrationError) Reason() string {
+ if len(err.Params) > 0 {
+ return err.Params[len(err.Params)-1]
+ }
+ return err.Command
+}
+
+func (err registrationError) Temporary() bool {
+ // Only return false if we're 100% sure that fixing the error requires a
+ // network configuration change
+ switch err.Command {
+ case irc.ERR_PASSWDMISMATCH, irc.ERR_ERRONEUSNICKNAME:
+ return false
+ case "FAIL":
+ return err.Params[1] != "ACCOUNT_REQUIRED"
+ default:
+ return true
+ }
+}
+
+type upstreamChannel struct {
+ Name string
+ conn *upstreamConn
+ Topic string
+ TopicWho *irc.Prefix
+ TopicTime time.Time
+ Status channelStatus
+ modes channelModes
+ creationTime string
+ Members membershipsCasemapMap
+ complete bool
+ detachTimer *time.Timer
+}
+
+func (uc *upstreamChannel) updateAutoDetach(dur time.Duration) {
+ if uc.detachTimer != nil {
+ uc.detachTimer.Stop()
+ uc.detachTimer = nil
+ }
+
+ if dur == 0 {
+ return
+ }
+
+ uc.detachTimer = time.AfterFunc(dur, func() {
+ uc.conn.network.user.events <- eventChannelDetach{
+ uc: uc.conn,
+ name: uc.Name,
+ }
+ })
+}
+
+type pendingUpstreamCommand struct {
+ downstreamID uint64
+ msg *irc.Message
+}
+
+type upstreamConn struct {
+ conn
+
+ network *network
+ user *user
+
+ serverName string
+ availableUserModes string
+ availableChannelModes map[byte]channelModeType
+ availableChannelTypes string
+ availableMemberships []membership
+ isupport map[string]*string
+
+ registered bool
+ nick string
+ nickCM string
+ username string
+ realname string
+ modes userModes
+ channels upstreamChannelCasemapMap
+ supportedCaps map[string]string
+ caps map[string]bool
+ batches map[string]batch
+ away bool
+ account string
+ nextLabelID uint64
+ monitored monitorCasemapMap
+
+ saslClient sasl.Client
+ saslStarted bool
+
+ casemapIsSet bool
+
+ // Queue of commands in progress, indexed by type. The first entry has been
+ // sent to the server and is awaiting reply. The following entries have not
+ // been sent yet.
+ pendingCmds map[string][]pendingUpstreamCommand
+
+ gotMotd bool
+}
+
+func connectToUpstream(ctx context.Context, network *network) (*upstreamConn, error) {
+ logger := &prefixLogger{network.user.logger, fmt.Sprintf("upstream %q: ", network.GetName())}
+
+ dialer := net.Dialer{Timeout: connectTimeout}
+
+ u, err := network.URL()
+ if err != nil {
+ return nil, err
+ }
+
+ var netConn net.Conn
+ switch u.Scheme {
+ case "ircs":
+ addr := u.Host
+ host, _, err := net.SplitHostPort(u.Host)
+ if err != nil {
+ host = u.Host
+ addr = u.Host + ":6697"
+ }
+
+ dialer.LocalAddr, err = network.user.localTCPAddrForHost(ctx, host)
+ if err != nil {
+ return nil, fmt.Errorf("failed to pick local IP for remote host %q: %v", host, err)
+ }
+
+ logger.Printf("connecting to TLS server at address %q", addr)
+
+ tlsConfig := &tls.Config{ServerName: host, NextProtos: []string{"irc"}}
+ if network.SASL.Mechanism == "EXTERNAL" {
+ if network.SASL.External.CertBlob == nil {
+ return nil, fmt.Errorf("missing certificate for authentication")
+ }
+ if network.SASL.External.PrivKeyBlob == nil {
+ return nil, fmt.Errorf("missing private key for authentication")
+ }
+ key, err := x509.ParsePKCS8PrivateKey(network.SASL.External.PrivKeyBlob)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse private key: %v", err)
+ }
+ tlsConfig.Certificates = []tls.Certificate{
+ {
+ Certificate: [][]byte{network.SASL.External.CertBlob},
+ PrivateKey: key.(crypto.PrivateKey),
+ },
+ }
+ logger.Printf("using TLS client certificate %x", sha256.Sum256(network.SASL.External.CertBlob))
+ }
+
+ netConn, err = dialer.DialContext(ctx, "tcp", addr)
+ if err != nil {
+ return nil, fmt.Errorf("failed to dial %q: %v", addr, err)
+ }
+
+ // Don't do the TLS handshake immediately, because we need to register
+ // the new connection with identd ASAP.
+ netConn = tls.Client(netConn, tlsConfig)
+ case "irc":
+ addr := u.Host
+ host, _, err := net.SplitHostPort(addr)
+ if err != nil {
+ host = u.Host
+ addr = u.Host + ":6667"
+ }
+
+ dialer.LocalAddr, err = network.user.localTCPAddrForHost(ctx, host)
+ if err != nil {
+ return nil, fmt.Errorf("failed to pick local IP for remote host %q: %v", host, err)
+ }
+
+ logger.Printf("connecting to plain-text server at address %q", addr)
+ netConn, err = dialer.DialContext(ctx, "tcp", addr)
+ if err != nil {
+ return nil, fmt.Errorf("failed to dial %q: %v", addr, err)
+ }
+ case "irc+unix", "unix":
+ logger.Printf("connecting to Unix socket at path %q", u.Path)
+ netConn, err = dialer.DialContext(ctx, "unix", u.Path)
+ if err != nil {
+ return nil, fmt.Errorf("failed to connect to Unix socket %q: %v", u.Path, err)
+ }
+ default:
+ return nil, fmt.Errorf("failed to dial %q: unknown scheme: %v", network.Addr, u.Scheme)
+ }
+
+ options := connOptions{
+ Logger: logger,
+ RateLimitDelay: upstreamMessageDelay,
+ RateLimitBurst: upstreamMessageBurst,
+ }
+
+ uc := &upstreamConn{
+ conn: *newConn(network.user.srv, newNetIRCConn(netConn), &options),
+ network: network,
+ user: network.user,
+ channels: upstreamChannelCasemapMap{newCasemapMap(0)},
+ supportedCaps: make(map[string]string),
+ caps: make(map[string]bool),
+ batches: make(map[string]batch),
+ availableChannelTypes: stdChannelTypes,
+ availableChannelModes: stdChannelModes,
+ availableMemberships: stdMemberships,
+ isupport: make(map[string]*string),
+ pendingCmds: make(map[string][]pendingUpstreamCommand),
+ monitored: monitorCasemapMap{newCasemapMap(0)},
+ }
+ return uc, nil
+}
+
+func (uc *upstreamConn) forEachDownstream(f func(*downstreamConn)) {
+ uc.network.forEachDownstream(f)
+}
+
+func (uc *upstreamConn) forEachDownstreamByID(id uint64, f func(*downstreamConn)) {
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ if id != 0 && id != dc.id {
+ return
+ }
+ f(dc)
+ })
+}
+
+func (uc *upstreamConn) downstreamByID(id uint64) *downstreamConn {
+ for _, dc := range uc.user.downstreamConns {
+ if dc.id == id {
+ return dc
+ }
+ }
+ return nil
+}
+
+func (uc *upstreamConn) getChannel(name string) (*upstreamChannel, error) {
+ ch := uc.channels.Value(name)
+ if ch == nil {
+ return nil, fmt.Errorf("unknown channel %q", name)
+ }
+ return ch, nil
+}
+
+func (uc *upstreamConn) isChannel(entity string) bool {
+ return strings.ContainsRune(uc.availableChannelTypes, rune(entity[0]))
+}
+
+func (uc *upstreamConn) isOurNick(nick string) bool {
+ return uc.nickCM == uc.network.casemap(nick)
+}
+
+func (uc *upstreamConn) abortPendingCommands() {
+ for _, l := range uc.pendingCmds {
+ for _, pendingCmd := range l {
+ dc := uc.downstreamByID(pendingCmd.downstreamID)
+ if dc == nil {
+ continue
+ }
+
+ switch pendingCmd.msg.Command {
+ case "LIST":
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_LISTEND,
+ Params: []string{dc.nick, "Command aborted"},
+ })
+ case "WHO":
+ mask := "*"
+ if len(pendingCmd.msg.Params) > 0 {
+ mask = pendingCmd.msg.Params[0]
+ }
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_ENDOFWHO,
+ Params: []string{dc.nick, mask, "Command aborted"},
+ })
+ case "AUTHENTICATE":
+ dc.endSASL(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_SASLABORTED,
+ Params: []string{dc.nick, "SASL authentication aborted"},
+ })
+ case "REGISTER", "VERIFY":
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "FAIL",
+ Params: []string{pendingCmd.msg.Command, "TEMPORARILY_UNAVAILABLE", pendingCmd.msg.Params[0], "Command aborted"},
+ })
+ default:
+ panic(fmt.Errorf("Unsupported pending command %q", pendingCmd.msg.Command))
+ }
+ }
+ }
+
+ uc.pendingCmds = make(map[string][]pendingUpstreamCommand)
+}
+
+func (uc *upstreamConn) sendNextPendingCommand(cmd string) {
+ if len(uc.pendingCmds[cmd]) == 0 {
+ return
+ }
+ uc.SendMessage(context.TODO(), uc.pendingCmds[cmd][0].msg)
+}
+
+func (uc *upstreamConn) enqueueCommand(dc *downstreamConn, msg *irc.Message) {
+ switch msg.Command {
+ case "LIST", "WHO", "AUTHENTICATE", "REGISTER", "VERIFY":
+ // Supported
+ default:
+ panic(fmt.Errorf("Unsupported pending command %q", msg.Command))
+ }
+
+ uc.pendingCmds[msg.Command] = append(uc.pendingCmds[msg.Command], pendingUpstreamCommand{
+ downstreamID: dc.id,
+ msg: msg,
+ })
+
+ if len(uc.pendingCmds[msg.Command]) == 1 {
+ uc.sendNextPendingCommand(msg.Command)
+ }
+}
+
+func (uc *upstreamConn) currentPendingCommand(cmd string) (*downstreamConn, *irc.Message) {
+ if len(uc.pendingCmds[cmd]) == 0 {
+ return nil, nil
+ }
+
+ pendingCmd := uc.pendingCmds[cmd][0]
+ return uc.downstreamByID(pendingCmd.downstreamID), pendingCmd.msg
+}
+
+func (uc *upstreamConn) dequeueCommand(cmd string) (*downstreamConn, *irc.Message) {
+ dc, msg := uc.currentPendingCommand(cmd)
+
+ if len(uc.pendingCmds[cmd]) > 0 {
+ copy(uc.pendingCmds[cmd], uc.pendingCmds[cmd][1:])
+ uc.pendingCmds[cmd] = uc.pendingCmds[cmd][:len(uc.pendingCmds[cmd])-1]
+ }
+
+ uc.sendNextPendingCommand(cmd)
+
+ return dc, msg
+}
+
+func (uc *upstreamConn) cancelPendingCommandsByDownstreamID(downstreamID uint64) {
+ for cmd := range uc.pendingCmds {
+ // We can't cancel the currently running command stored in
+ // uc.pendingCmds[cmd][0]
+ for i := len(uc.pendingCmds[cmd]) - 1; i >= 1; i-- {
+ if uc.pendingCmds[cmd][i].downstreamID == downstreamID {
+ uc.pendingCmds[cmd] = append(uc.pendingCmds[cmd][:i], uc.pendingCmds[cmd][i+1:]...)
+ }
+ }
+ }
+}
+
+func (uc *upstreamConn) parseMembershipPrefix(s string) (ms *memberships, nick string) {
+ memberships := make(memberships, 0, 4)
+ i := 0
+ for _, m := range uc.availableMemberships {
+ if i >= len(s) {
+ break
+ }
+ if s[i] == m.Prefix {
+ memberships = append(memberships, m)
+ i++
+ }
+ }
+ return &memberships, s[i:]
+}
+
+func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) error {
+ var label string
+ if l, ok := msg.GetTag("label"); ok {
+ label = l
+ delete(msg.Tags, "label")
+ }
+
+ var msgBatch *batch
+ if batchName, ok := msg.GetTag("batch"); ok {
+ b, ok := uc.batches[batchName]
+ if !ok {
+ return fmt.Errorf("unexpected batch reference: batch was not defined: %q", batchName)
+ }
+ msgBatch = &b
+ if label == "" {
+ label = msgBatch.Label
+ }
+ delete(msg.Tags, "batch")
+ }
+
+ var downstreamID uint64 = 0
+ if label != "" {
+ var labelOffset uint64
+ n, err := fmt.Sscanf(label, "sd-%d-%d", &downstreamID, &labelOffset)
+ if err == nil && n < 2 {
+ err = errors.New("not enough arguments")
+ }
+ if err != nil {
+ return fmt.Errorf("unexpected message label: invalid downstream reference for label %q: %v", label, err)
+ }
+ }
+
+ if _, ok := msg.Tags["time"]; !ok {
+ msg.Tags["time"] = irc.TagValue(formatServerTime(time.Now()))
+ }
+
+ switch msg.Command {
+ case "PING":
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "PONG",
+ Params: msg.Params,
+ })
+ return nil
+ case "NOTICE", "PRIVMSG", "TAGMSG":
+ if msg.Prefix == nil {
+ return fmt.Errorf("expected a prefix")
+ }
+
+ var entity, text string
+ if msg.Command != "TAGMSG" {
+ if err := parseMessageParams(msg, &entity, &text); err != nil {
+ return err
+ }
+ } else {
+ if err := parseMessageParams(msg, &entity); err != nil {
+ return err
+ }
+ }
+
+ if msg.Prefix.Name == serviceNick {
+ uc.logger.Printf("skipping %v from suika's service: %v", msg.Command, msg)
+ break
+ }
+ if entity == serviceNick {
+ uc.logger.Printf("skipping %v to suika's service: %v", msg.Command, msg)
+ break
+ }
+
+ if msg.Prefix.User == "" && msg.Prefix.Host == "" { // server message
+ uc.produce("", msg, nil)
+ } else { // regular user message
+ target := entity
+ if uc.isOurNick(target) {
+ target = msg.Prefix.Name
+ }
+
+ ch := uc.network.channels.Value(target)
+ if ch != nil && msg.Command != "TAGMSG" {
+ if ch.Detached {
+ uc.handleDetachedMessage(ctx, ch, msg)
+ }
+
+ highlight := uc.network.isHighlight(msg)
+ if ch.DetachOn == FilterMessage || ch.DetachOn == FilterDefault || (ch.DetachOn == FilterHighlight && highlight) {
+ uc.updateChannelAutoDetach(target)
+ }
+ }
+
+ uc.produce(target, msg, nil)
+ }
+ case "CAP":
+ var subCmd string
+ if err := parseMessageParams(msg, nil, &subCmd); err != nil {
+ return err
+ }
+ subCmd = strings.ToUpper(subCmd)
+ subParams := msg.Params[2:]
+ switch subCmd {
+ case "LS":
+ if len(subParams) < 1 {
+ return newNeedMoreParamsError(msg.Command)
+ }
+ caps := subParams[len(subParams)-1]
+ more := len(subParams) >= 2 && msg.Params[len(subParams)-2] == "*"
+
+ uc.handleSupportedCaps(caps)
+
+ if more {
+ break // wait to receive all capabilities
+ }
+
+ uc.requestCaps()
+
+ if uc.requestSASL() {
+ break // we'll send CAP END after authentication is completed
+ }
+
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "CAP",
+ Params: []string{"END"},
+ })
+ case "ACK", "NAK":
+ if len(subParams) < 1 {
+ return newNeedMoreParamsError(msg.Command)
+ }
+ caps := strings.Fields(subParams[0])
+
+ for _, name := range caps {
+ if err := uc.handleCapAck(ctx, strings.ToLower(name), subCmd == "ACK"); err != nil {
+ return err
+ }
+ }
+
+ if uc.registered {
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.updateSupportedCaps()
+ })
+ }
+ case "NEW":
+ if len(subParams) < 1 {
+ return newNeedMoreParamsError(msg.Command)
+ }
+ uc.handleSupportedCaps(subParams[0])
+ uc.requestCaps()
+ case "DEL":
+ if len(subParams) < 1 {
+ return newNeedMoreParamsError(msg.Command)
+ }
+ caps := strings.Fields(subParams[0])
+
+ for _, c := range caps {
+ delete(uc.supportedCaps, c)
+ delete(uc.caps, c)
+ }
+
+ if uc.registered {
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.updateSupportedCaps()
+ })
+ }
+ default:
+ uc.logger.Printf("unhandled message: %v", msg)
+ }
+ case "AUTHENTICATE":
+ if uc.saslClient == nil {
+ return fmt.Errorf("received unexpected AUTHENTICATE message")
+ }
+
+ // TODO: if a challenge is 400 bytes long, buffer it
+ var challengeStr string
+ if err := parseMessageParams(msg, &challengeStr); err != nil {
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "AUTHENTICATE",
+ Params: []string{"*"},
+ })
+ return err
+ }
+
+ var challenge []byte
+ if challengeStr != "+" {
+ var err error
+ challenge, err = base64.StdEncoding.DecodeString(challengeStr)
+ if err != nil {
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "AUTHENTICATE",
+ Params: []string{"*"},
+ })
+ return err
+ }
+ }
+
+ var resp []byte
+ var err error
+ if !uc.saslStarted {
+ _, resp, err = uc.saslClient.Start()
+ uc.saslStarted = true
+ } else {
+ resp, err = uc.saslClient.Next(challenge)
+ }
+ if err != nil {
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "AUTHENTICATE",
+ Params: []string{"*"},
+ })
+ return err
+ }
+
+ // <= instead of < because we need to send a final empty response if
+ // the last chunk is exactly 400 bytes long
+ for i := 0; i <= len(resp); i += maxSASLLength {
+ j := i + maxSASLLength
+ if j > len(resp) {
+ j = len(resp)
+ }
+
+ chunk := resp[i:j]
+
+ var respStr = "+"
+ if len(chunk) != 0 {
+ respStr = base64.StdEncoding.EncodeToString(chunk)
+ }
+
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "AUTHENTICATE",
+ Params: []string{respStr},
+ })
+ }
+ case irc.RPL_LOGGEDIN:
+ if err := parseMessageParams(msg, nil, nil, &uc.account); err != nil {
+ return err
+ }
+ uc.logger.Printf("logged in with account %q", uc.account)
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.updateAccount()
+ })
+ case irc.RPL_LOGGEDOUT:
+ uc.account = ""
+ uc.logger.Printf("logged out")
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.updateAccount()
+ })
+ case irc.ERR_NICKLOCKED, irc.RPL_SASLSUCCESS, irc.ERR_SASLFAIL, irc.ERR_SASLTOOLONG, irc.ERR_SASLABORTED:
+ var info string
+ if err := parseMessageParams(msg, nil, &info); err != nil {
+ return err
+ }
+ switch msg.Command {
+ case irc.ERR_NICKLOCKED:
+ uc.logger.Printf("invalid nick used with SASL authentication: %v", info)
+ case irc.ERR_SASLFAIL:
+ uc.logger.Printf("SASL authentication failed: %v", info)
+ case irc.ERR_SASLTOOLONG:
+ uc.logger.Printf("SASL message too long: %v", info)
+ }
+
+ uc.saslClient = nil
+ uc.saslStarted = false
+
+ if dc, _ := uc.dequeueCommand("AUTHENTICATE"); dc != nil && dc.sasl != nil {
+ if msg.Command == irc.RPL_SASLSUCCESS {
+ uc.network.autoSaveSASLPlain(ctx, dc.sasl.plainUsername, dc.sasl.plainPassword)
+ }
+
+ dc.endSASL(msg)
+ }
+
+ if !uc.registered {
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "CAP",
+ Params: []string{"END"},
+ })
+ }
+ case "REGISTER", "VERIFY":
+ if dc, cmd := uc.dequeueCommand(msg.Command); dc != nil {
+ if msg.Command == "REGISTER" {
+ var account, password string
+ if err := parseMessageParams(msg, nil, &account); err != nil {
+ return err
+ }
+ if err := parseMessageParams(cmd, nil, nil, &password); err != nil {
+ return err
+ }
+ uc.network.autoSaveSASLPlain(ctx, account, password)
+ }
+
+ dc.SendMessage(msg)
+ }
+ case irc.RPL_WELCOME:
+ if err := parseMessageParams(msg, &uc.nick); err != nil {
+ return err
+ }
+
+ uc.registered = true
+ uc.nickCM = uc.network.casemap(uc.nick)
+ uc.logger.Printf("connection registered with nick %q", uc.nick)
+
+ if uc.network.channels.Len() > 0 {
+ var channels, keys []string
+ for _, entry := range uc.network.channels.innerMap {
+ ch := entry.value.(*Channel)
+ channels = append(channels, ch.Name)
+ keys = append(keys, ch.Key)
+ }
+
+ for _, msg := range join(channels, keys) {
+ uc.SendMessage(ctx, msg)
+ }
+ }
+ case irc.RPL_MYINFO:
+ if err := parseMessageParams(msg, nil, &uc.serverName, nil, &uc.availableUserModes, nil); err != nil {
+ return err
+ }
+ case irc.RPL_ISUPPORT:
+ if err := parseMessageParams(msg, nil, nil); err != nil {
+ return err
+ }
+
+ var downstreamIsupport []string
+ for _, token := range msg.Params[1 : len(msg.Params)-1] {
+ parameter := token
+ var negate, hasValue bool
+ var value string
+ if strings.HasPrefix(token, "-") {
+ negate = true
+ token = token[1:]
+ } else if i := strings.IndexByte(token, '='); i >= 0 {
+ parameter = token[:i]
+ value = token[i+1:]
+ hasValue = true
+ }
+
+ if hasValue {
+ uc.isupport[parameter] = &value
+ } else if !negate {
+ uc.isupport[parameter] = nil
+ } else {
+ delete(uc.isupport, parameter)
+ }
+
+ var err error
+ switch parameter {
+ case "CASEMAPPING":
+ casemap, ok := parseCasemappingToken(value)
+ if !ok {
+ casemap = casemapRFC1459
+ }
+ uc.network.updateCasemapping(casemap)
+ uc.nickCM = uc.network.casemap(uc.nick)
+ uc.casemapIsSet = true
+ case "CHANMODES":
+ if !negate {
+ err = uc.handleChanModes(value)
+ } else {
+ uc.availableChannelModes = stdChannelModes
+ }
+ case "CHANTYPES":
+ if !negate {
+ uc.availableChannelTypes = value
+ } else {
+ uc.availableChannelTypes = stdChannelTypes
+ }
+ case "PREFIX":
+ if !negate {
+ err = uc.handleMemberships(value)
+ } else {
+ uc.availableMemberships = stdMemberships
+ }
+ }
+ if err != nil {
+ return err
+ }
+
+ if passthroughIsupport[parameter] {
+ downstreamIsupport = append(downstreamIsupport, token)
+ }
+ }
+
+ uc.updateMonitor()
+
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ if dc.network == nil {
+ return
+ }
+ msgs := generateIsupport(dc.srv.prefix(), dc.nick, downstreamIsupport)
+ for _, msg := range msgs {
+ dc.SendMessage(msg)
+ }
+ })
+ case irc.ERR_NOMOTD, irc.RPL_ENDOFMOTD:
+ if !uc.casemapIsSet {
+ // upstream did not send any CASEMAPPING token, thus
+ // we assume it implements the old RFCs with rfc1459.
+ uc.casemapIsSet = true
+ uc.network.updateCasemapping(casemapRFC1459)
+ uc.nickCM = uc.network.casemap(uc.nick)
+ }
+
+ if !uc.gotMotd {
+ // Ignore the initial MOTD upon connection, but forward
+ // subsequent MOTD messages downstream
+ uc.gotMotd = true
+ return nil
+ }
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ dc.SendMessage(&irc.Message{
+ Prefix: uc.srv.prefix(),
+ Command: msg.Command,
+ Params: msg.Params,
+ })
+ })
+ case "BATCH":
+ var tag string
+ if err := parseMessageParams(msg, &tag); err != nil {
+ return err
+ }
+
+ if strings.HasPrefix(tag, "+") {
+ tag = tag[1:]
+ if _, ok := uc.batches[tag]; ok {
+ return fmt.Errorf("unexpected BATCH reference tag: batch was already defined: %q", tag)
+ }
+ var batchType string
+ if err := parseMessageParams(msg, nil, &batchType); err != nil {
+ return err
+ }
+ label := label
+ if label == "" && msgBatch != nil {
+ label = msgBatch.Label
+ }
+ uc.batches[tag] = batch{
+ Type: batchType,
+ Params: msg.Params[2:],
+ Outer: msgBatch,
+ Label: label,
+ }
+ } else if strings.HasPrefix(tag, "-") {
+ tag = tag[1:]
+ if _, ok := uc.batches[tag]; !ok {
+ return fmt.Errorf("unknown BATCH reference tag: %q", tag)
+ }
+ delete(uc.batches, tag)
+ } else {
+ return fmt.Errorf("unexpected BATCH reference tag: missing +/- prefix: %q", tag)
+ }
+ case "NICK":
+ if msg.Prefix == nil {
+ return fmt.Errorf("expected a prefix")
+ }
+
+ var newNick string
+ if err := parseMessageParams(msg, &newNick); err != nil {
+ return err
+ }
+
+ me := false
+ if uc.isOurNick(msg.Prefix.Name) {
+ uc.logger.Printf("changed nick from %q to %q", uc.nick, newNick)
+ me = true
+ uc.nick = newNick
+ uc.nickCM = uc.network.casemap(uc.nick)
+ }
+
+ for _, entry := range uc.channels.innerMap {
+ ch := entry.value.(*upstreamChannel)
+ memberships := ch.Members.Value(msg.Prefix.Name)
+ if memberships != nil {
+ ch.Members.Delete(msg.Prefix.Name)
+ ch.Members.SetValue(newNick, memberships)
+ uc.appendLog(ch.Name, msg)
+ }
+ }
+
+ if !me {
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.SendMessage(dc.marshalMessage(msg, uc.network))
+ })
+ } else {
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.updateNick()
+ })
+ uc.updateMonitor()
+ }
+ case "SETNAME":
+ if msg.Prefix == nil {
+ return fmt.Errorf("expected a prefix")
+ }
+
+ var newRealname string
+ if err := parseMessageParams(msg, &newRealname); err != nil {
+ return err
+ }
+
+ // TODO: consider appending this message to logs
+
+ if uc.isOurNick(msg.Prefix.Name) {
+ uc.logger.Printf("changed realname from %q to %q", uc.realname, newRealname)
+ uc.realname = newRealname
+
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.updateRealname()
+ })
+ } else {
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.SendMessage(dc.marshalMessage(msg, uc.network))
+ })
+ }
+ case "JOIN":
+ if msg.Prefix == nil {
+ return fmt.Errorf("expected a prefix")
+ }
+
+ var channels string
+ if err := parseMessageParams(msg, &channels); err != nil {
+ return err
+ }
+
+ for _, ch := range strings.Split(channels, ",") {
+ if uc.isOurNick(msg.Prefix.Name) {
+ uc.logger.Printf("joined channel %q", ch)
+ members := membershipsCasemapMap{newCasemapMap(0)}
+ members.casemap = uc.network.casemap
+ uc.channels.SetValue(ch, &upstreamChannel{
+ Name: ch,
+ conn: uc,
+ Members: members,
+ })
+ uc.updateChannelAutoDetach(ch)
+
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "MODE",
+ Params: []string{ch},
+ })
+ } else {
+ ch, err := uc.getChannel(ch)
+ if err != nil {
+ return err
+ }
+ ch.Members.SetValue(msg.Prefix.Name, &memberships{})
+ }
+
+ chMsg := msg.Copy()
+ chMsg.Params[0] = ch
+ uc.produce(ch, chMsg, nil)
+ }
+ case "PART":
+ if msg.Prefix == nil {
+ return fmt.Errorf("expected a prefix")
+ }
+
+ var channels string
+ if err := parseMessageParams(msg, &channels); err != nil {
+ return err
+ }
+
+ for _, ch := range strings.Split(channels, ",") {
+ if uc.isOurNick(msg.Prefix.Name) {
+ uc.logger.Printf("parted channel %q", ch)
+ uch := uc.channels.Value(ch)
+ if uch != nil {
+ uc.channels.Delete(ch)
+ uch.updateAutoDetach(0)
+ }
+ } else {
+ ch, err := uc.getChannel(ch)
+ if err != nil {
+ return err
+ }
+ ch.Members.Delete(msg.Prefix.Name)
+ }
+
+ chMsg := msg.Copy()
+ chMsg.Params[0] = ch
+ uc.produce(ch, chMsg, nil)
+ }
+ case "KICK":
+ if msg.Prefix == nil {
+ return fmt.Errorf("expected a prefix")
+ }
+
+ var channel, user string
+ if err := parseMessageParams(msg, &channel, &user); err != nil {
+ return err
+ }
+
+ if uc.isOurNick(user) {
+ uc.logger.Printf("kicked from channel %q by %s", channel, msg.Prefix.Name)
+ uc.channels.Delete(channel)
+ } else {
+ ch, err := uc.getChannel(channel)
+ if err != nil {
+ return err
+ }
+ ch.Members.Delete(user)
+ }
+
+ uc.produce(channel, msg, nil)
+ case "QUIT":
+ if msg.Prefix == nil {
+ return fmt.Errorf("expected a prefix")
+ }
+
+ if uc.isOurNick(msg.Prefix.Name) {
+ uc.logger.Printf("quit")
+ }
+
+ for _, entry := range uc.channels.innerMap {
+ ch := entry.value.(*upstreamChannel)
+ if ch.Members.Has(msg.Prefix.Name) {
+ ch.Members.Delete(msg.Prefix.Name)
+
+ uc.appendLog(ch.Name, msg)
+ }
+ }
+
+ if msg.Prefix.Name != uc.nick {
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.SendMessage(dc.marshalMessage(msg, uc.network))
+ })
+ }
+ case irc.RPL_TOPIC, irc.RPL_NOTOPIC:
+ var name, topic string
+ if err := parseMessageParams(msg, nil, &name, &topic); err != nil {
+ return err
+ }
+ ch, err := uc.getChannel(name)
+ if err != nil {
+ return err
+ }
+ if msg.Command == irc.RPL_TOPIC {
+ ch.Topic = topic
+ } else {
+ ch.Topic = ""
+ }
+ case "TOPIC":
+ if msg.Prefix == nil {
+ return fmt.Errorf("expected a prefix")
+ }
+
+ var name string
+ if err := parseMessageParams(msg, &name); err != nil {
+ return err
+ }
+ ch, err := uc.getChannel(name)
+ if err != nil {
+ return err
+ }
+ if len(msg.Params) > 1 {
+ ch.Topic = msg.Params[1]
+ ch.TopicWho = msg.Prefix.Copy()
+ ch.TopicTime = time.Now() // TODO use msg.Tags["time"]
+ } else {
+ ch.Topic = ""
+ }
+ uc.produce(ch.Name, msg, nil)
+ case "MODE":
+ var name, modeStr string
+ if err := parseMessageParams(msg, &name, &modeStr); err != nil {
+ return err
+ }
+
+ if !uc.isChannel(name) { // user mode change
+ if name != uc.nick {
+ return fmt.Errorf("received MODE message for unknown nick %q", name)
+ }
+
+ if err := uc.modes.Apply(modeStr); err != nil {
+ return err
+ }
+
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ if dc.upstream() == nil {
+ return
+ }
+
+ dc.SendMessage(msg)
+ })
+ } else { // channel mode change
+ ch, err := uc.getChannel(name)
+ if err != nil {
+ return err
+ }
+
+ needMarshaling, err := applyChannelModes(ch, modeStr, msg.Params[2:])
+ if err != nil {
+ return err
+ }
+
+ uc.appendLog(ch.Name, msg)
+
+ c := uc.network.channels.Value(name)
+ if c == nil || !c.Detached {
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ params := make([]string, len(msg.Params))
+ params[0] = dc.marshalEntity(uc.network, name)
+ params[1] = modeStr
+
+ copy(params[2:], msg.Params[2:])
+ for i, modeParam := range params[2:] {
+ if _, ok := needMarshaling[i]; ok {
+ params[2+i] = dc.marshalEntity(uc.network, modeParam)
+ }
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.marshalUserPrefix(uc.network, msg.Prefix),
+ Command: "MODE",
+ Params: params,
+ })
+ })
+ }
+ }
+ case irc.RPL_UMODEIS:
+ if err := parseMessageParams(msg, nil); err != nil {
+ return err
+ }
+ modeStr := ""
+ if len(msg.Params) > 1 {
+ modeStr = msg.Params[1]
+ }
+
+ uc.modes = ""
+ if err := uc.modes.Apply(modeStr); err != nil {
+ return err
+ }
+
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ if dc.upstream() == nil {
+ return
+ }
+
+ dc.SendMessage(msg)
+ })
+ case irc.RPL_CHANNELMODEIS:
+ var channel string
+ if err := parseMessageParams(msg, nil, &channel); err != nil {
+ return err
+ }
+ modeStr := ""
+ if len(msg.Params) > 2 {
+ modeStr = msg.Params[2]
+ }
+
+ ch, err := uc.getChannel(channel)
+ if err != nil {
+ return err
+ }
+
+ firstMode := ch.modes == nil
+ ch.modes = make(map[byte]string)
+ if _, err := applyChannelModes(ch, modeStr, msg.Params[3:]); err != nil {
+ return err
+ }
+
+ c := uc.network.channels.Value(channel)
+ if firstMode && (c == nil || !c.Detached) {
+ modeStr, modeParams := ch.modes.Format()
+
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ params := []string{dc.nick, dc.marshalEntity(uc.network, channel), modeStr}
+ params = append(params, modeParams...)
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_CHANNELMODEIS,
+ Params: params,
+ })
+ })
+ }
+ case rpl_creationtime:
+ var channel, creationTime string
+ if err := parseMessageParams(msg, nil, &channel, &creationTime); err != nil {
+ return err
+ }
+
+ ch, err := uc.getChannel(channel)
+ if err != nil {
+ return err
+ }
+
+ firstCreationTime := ch.creationTime == ""
+ ch.creationTime = creationTime
+
+ c := uc.network.channels.Value(channel)
+ if firstCreationTime && (c == nil || !c.Detached) {
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: rpl_creationtime,
+ Params: []string{dc.nick, dc.marshalEntity(uc.network, ch.Name), creationTime},
+ })
+ })
+ }
+ case rpl_topicwhotime:
+ var channel, who, timeStr string
+ if err := parseMessageParams(msg, nil, &channel, &who, &timeStr); err != nil {
+ return err
+ }
+
+ ch, err := uc.getChannel(channel)
+ if err != nil {
+ return err
+ }
+
+ firstTopicWhoTime := ch.TopicWho == nil
+ ch.TopicWho = irc.ParsePrefix(who)
+ sec, err := strconv.ParseInt(timeStr, 10, 64)
+ if err != nil {
+ return fmt.Errorf("failed to parse topic time: %v", err)
+ }
+ ch.TopicTime = time.Unix(sec, 0)
+
+ c := uc.network.channels.Value(channel)
+ if firstTopicWhoTime && (c == nil || !c.Detached) {
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ topicWho := dc.marshalUserPrefix(uc.network, ch.TopicWho)
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: rpl_topicwhotime,
+ Params: []string{
+ dc.nick,
+ dc.marshalEntity(uc.network, ch.Name),
+ topicWho.String(),
+ timeStr,
+ },
+ })
+ })
+ }
+ case irc.RPL_LIST:
+ var channel, clients, topic string
+ if err := parseMessageParams(msg, nil, &channel, &clients, &topic); err != nil {
+ return err
+ }
+
+ dc, cmd := uc.currentPendingCommand("LIST")
+ if cmd == nil {
+ return fmt.Errorf("unexpected RPL_LIST: no matching pending LIST")
+ } else if dc == nil {
+ return nil
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_LIST,
+ Params: []string{dc.nick, dc.marshalEntity(uc.network, channel), clients, topic},
+ })
+ case irc.RPL_LISTEND:
+ dc, cmd := uc.dequeueCommand("LIST")
+ if cmd == nil {
+ return fmt.Errorf("unexpected RPL_LISTEND: no matching pending LIST")
+ } else if dc == nil {
+ return nil
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_LISTEND,
+ Params: []string{dc.nick, "End of /LIST"},
+ })
+ case irc.RPL_NAMREPLY:
+ var name, statusStr, members string
+ if err := parseMessageParams(msg, nil, &statusStr, &name, &members); err != nil {
+ return err
+ }
+
+ ch := uc.channels.Value(name)
+ if ch == nil {
+ // NAMES on a channel we have not joined, forward to downstream
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ channel := dc.marshalEntity(uc.network, name)
+ members := splitSpace(members)
+ for i, member := range members {
+ memberships, nick := uc.parseMembershipPrefix(member)
+ members[i] = memberships.Format(dc) + dc.marshalEntity(uc.network, nick)
+ }
+ memberStr := strings.Join(members, " ")
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_NAMREPLY,
+ Params: []string{dc.nick, statusStr, channel, memberStr},
+ })
+ })
+ return nil
+ }
+
+ status, err := parseChannelStatus(statusStr)
+ if err != nil {
+ return err
+ }
+ ch.Status = status
+
+ for _, s := range splitSpace(members) {
+ memberships, nick := uc.parseMembershipPrefix(s)
+ ch.Members.SetValue(nick, memberships)
+ }
+ case irc.RPL_ENDOFNAMES:
+ var name string
+ if err := parseMessageParams(msg, nil, &name); err != nil {
+ return err
+ }
+
+ ch := uc.channels.Value(name)
+ if ch == nil {
+ // NAMES on a channel we have not joined, forward to downstream
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ channel := dc.marshalEntity(uc.network, name)
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_ENDOFNAMES,
+ Params: []string{dc.nick, channel, "End of /NAMES list"},
+ })
+ })
+ return nil
+ }
+
+ if ch.complete {
+ return fmt.Errorf("received unexpected RPL_ENDOFNAMES")
+ }
+ ch.complete = true
+
+ c := uc.network.channels.Value(name)
+ if c == nil || !c.Detached {
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ forwardChannel(ctx, dc, ch)
+ })
+ }
+ case irc.RPL_WHOREPLY:
+ var channel, username, host, server, nick, flags, trailing string
+ if err := parseMessageParams(msg, nil, &channel, &username, &host, &server, &nick, &flags, &trailing); err != nil {
+ return err
+ }
+
+ dc, cmd := uc.currentPendingCommand("WHO")
+ if cmd == nil {
+ return fmt.Errorf("unexpected RPL_WHOREPLY: no matching pending WHO")
+ } else if dc == nil {
+ return nil
+ }
+
+ if channel != "*" {
+ channel = dc.marshalEntity(uc.network, channel)
+ }
+ nick = dc.marshalEntity(uc.network, nick)
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_WHOREPLY,
+ Params: []string{dc.nick, channel, username, host, server, nick, flags, trailing},
+ })
+ case rpl_whospcrpl:
+ dc, cmd := uc.currentPendingCommand("WHO")
+ if cmd == nil {
+ return fmt.Errorf("unexpected RPL_WHOSPCRPL: no matching pending WHO")
+ } else if dc == nil {
+ return nil
+ }
+
+ // Only supported in single-upstream mode, so forward as-is
+ dc.SendMessage(msg)
+ case irc.RPL_ENDOFWHO:
+ var name string
+ if err := parseMessageParams(msg, nil, &name); err != nil {
+ return err
+ }
+
+ dc, cmd := uc.dequeueCommand("WHO")
+ if cmd == nil {
+ return fmt.Errorf("unexpected RPL_ENDOFWHO: no matching pending WHO")
+ } else if dc == nil {
+ return nil
+ }
+
+ mask := "*"
+ if len(cmd.Params) > 0 {
+ mask = cmd.Params[0]
+ }
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_ENDOFWHO,
+ Params: []string{dc.nick, mask, "End of /WHO list"},
+ })
+ case irc.RPL_WHOISUSER:
+ var nick, username, host, realname string
+ if err := parseMessageParams(msg, nil, &nick, &username, &host, nil, &realname); err != nil {
+ return err
+ }
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ nick := dc.marshalEntity(uc.network, nick)
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_WHOISUSER,
+ Params: []string{dc.nick, nick, username, host, "*", realname},
+ })
+ })
+ case irc.RPL_WHOISSERVER:
+ var nick, server, serverInfo string
+ if err := parseMessageParams(msg, nil, &nick, &server, &serverInfo); err != nil {
+ return err
+ }
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ nick := dc.marshalEntity(uc.network, nick)
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_WHOISSERVER,
+ Params: []string{dc.nick, nick, server, serverInfo},
+ })
+ })
+ case irc.RPL_WHOISOPERATOR:
+ var nick string
+ if err := parseMessageParams(msg, nil, &nick); err != nil {
+ return err
+ }
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ nick := dc.marshalEntity(uc.network, nick)
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_WHOISOPERATOR,
+ Params: []string{dc.nick, nick, "is an IRC operator"},
+ })
+ })
+ case irc.RPL_WHOISIDLE:
+ var nick string
+ if err := parseMessageParams(msg, nil, &nick, nil); err != nil {
+ return err
+ }
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ nick := dc.marshalEntity(uc.network, nick)
+ params := []string{dc.nick, nick}
+ params = append(params, msg.Params[2:]...)
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_WHOISIDLE,
+ Params: params,
+ })
+ })
+ case irc.RPL_WHOISCHANNELS:
+ var nick, channelList string
+ if err := parseMessageParams(msg, nil, &nick, &channelList); err != nil {
+ return err
+ }
+ channels := splitSpace(channelList)
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ nick := dc.marshalEntity(uc.network, nick)
+ channelList := make([]string, len(channels))
+ for i, channel := range channels {
+ prefix, channel := uc.parseMembershipPrefix(channel)
+ channel = dc.marshalEntity(uc.network, channel)
+ channelList[i] = prefix.Format(dc) + channel
+ }
+ channels := strings.Join(channelList, " ")
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_WHOISCHANNELS,
+ Params: []string{dc.nick, nick, channels},
+ })
+ })
+ case irc.RPL_ENDOFWHOIS:
+ var nick string
+ if err := parseMessageParams(msg, nil, &nick); err != nil {
+ return err
+ }
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ nick := dc.marshalEntity(uc.network, nick)
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_ENDOFWHOIS,
+ Params: []string{dc.nick, nick, "End of /WHOIS list"},
+ })
+ })
+ case "INVITE":
+ var nick, channel string
+ if err := parseMessageParams(msg, &nick, &channel); err != nil {
+ return err
+ }
+
+ weAreInvited := uc.isOurNick(nick)
+
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ if !weAreInvited && !dc.caps["invite-notify"] {
+ return
+ }
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.marshalUserPrefix(uc.network, msg.Prefix),
+ Command: "INVITE",
+ Params: []string{dc.marshalEntity(uc.network, nick), dc.marshalEntity(uc.network, channel)},
+ })
+ })
+ case irc.RPL_INVITING:
+ var nick, channel string
+ if err := parseMessageParams(msg, nil, &nick, &channel); err != nil {
+ return err
+ }
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_INVITING,
+ Params: []string{dc.nick, dc.marshalEntity(uc.network, nick), dc.marshalEntity(uc.network, channel)},
+ })
+ })
+ case irc.RPL_MONONLINE, irc.RPL_MONOFFLINE:
+ var targetsStr string
+ if err := parseMessageParams(msg, nil, &targetsStr); err != nil {
+ return err
+ }
+ targets := strings.Split(targetsStr, ",")
+
+ online := msg.Command == irc.RPL_MONONLINE
+ for _, target := range targets {
+ prefix := irc.ParsePrefix(target)
+ uc.monitored.SetValue(prefix.Name, online)
+ }
+
+ // Check if the nick we want is now free
+ wantNick := GetNick(&uc.user.User, &uc.network.Network)
+ wantNickCM := uc.network.casemap(wantNick)
+ if !online && uc.nickCM != wantNickCM {
+ found := false
+ for _, target := range targets {
+ prefix := irc.ParsePrefix(target)
+ if uc.network.casemap(prefix.Name) == wantNickCM {
+ found = true
+ break
+ }
+ }
+ if found {
+ uc.logger.Printf("desired nick %q is now available", wantNick)
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "NICK",
+ Params: []string{wantNick},
+ })
+ }
+ }
+
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ for _, target := range targets {
+ prefix := irc.ParsePrefix(target)
+ if dc.monitored.Has(prefix.Name) {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: msg.Command,
+ Params: []string{dc.nick, target},
+ })
+ }
+ }
+ })
+ case irc.ERR_MONLISTFULL:
+ var limit, targetsStr string
+ if err := parseMessageParams(msg, nil, &limit, &targetsStr); err != nil {
+ return err
+ }
+
+ targets := strings.Split(targetsStr, ",")
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ for _, target := range targets {
+ if dc.monitored.Has(target) {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: msg.Command,
+ Params: []string{dc.nick, limit, target},
+ })
+ }
+ }
+ })
+ case irc.RPL_AWAY:
+ var nick, reason string
+ if err := parseMessageParams(msg, nil, &nick, &reason); err != nil {
+ return err
+ }
+
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_AWAY,
+ Params: []string{dc.nick, dc.marshalEntity(uc.network, nick), reason},
+ })
+ })
+ case "AWAY", "ACCOUNT":
+ if msg.Prefix == nil {
+ return fmt.Errorf("expected a prefix")
+ }
+
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.marshalUserPrefix(uc.network, msg.Prefix),
+ Command: msg.Command,
+ Params: msg.Params,
+ })
+ })
+ case irc.RPL_BANLIST, irc.RPL_INVITELIST, irc.RPL_EXCEPTLIST:
+ var channel, mask string
+ if err := parseMessageParams(msg, nil, &channel, &mask); err != nil {
+ return err
+ }
+ var addNick, addTime string
+ if len(msg.Params) >= 5 {
+ addNick = msg.Params[3]
+ addTime = msg.Params[4]
+ }
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ channel := dc.marshalEntity(uc.network, channel)
+
+ var params []string
+ if addNick != "" && addTime != "" {
+ addNick := dc.marshalEntity(uc.network, addNick)
+ params = []string{dc.nick, channel, mask, addNick, addTime}
+ } else {
+ params = []string{dc.nick, channel, mask}
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: msg.Command,
+ Params: params,
+ })
+ })
+ case irc.RPL_ENDOFBANLIST, irc.RPL_ENDOFINVITELIST, irc.RPL_ENDOFEXCEPTLIST:
+ var channel, trailing string
+ if err := parseMessageParams(msg, nil, &channel, &trailing); err != nil {
+ return err
+ }
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ upstreamChannel := dc.marshalEntity(uc.network, channel)
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: msg.Command,
+ Params: []string{dc.nick, upstreamChannel, trailing},
+ })
+ })
+ case irc.ERR_UNKNOWNCOMMAND, irc.RPL_TRYAGAIN:
+ var command, reason string
+ if err := parseMessageParams(msg, nil, &command, &reason); err != nil {
+ return err
+ }
+
+ if dc, _ := uc.dequeueCommand(command); dc != nil && downstreamID == 0 {
+ downstreamID = dc.id
+ }
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ dc.SendMessage(&irc.Message{
+ Prefix: uc.srv.prefix(),
+ Command: msg.Command,
+ Params: []string{dc.nick, command, reason},
+ })
+ })
+ case "FAIL":
+ var command, code string
+ if err := parseMessageParams(msg, &command, &code); err != nil {
+ return err
+ }
+
+ if !uc.registered && command == "*" && code == "ACCOUNT_REQUIRED" {
+ return registrationError{msg}
+ }
+
+ if dc, _ := uc.dequeueCommand(command); dc != nil && downstreamID == 0 {
+ downstreamID = dc.id
+ }
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ dc.SendMessage(msg)
+ })
+ case "ACK":
+ // Ignore
+ case irc.RPL_NOWAWAY, irc.RPL_UNAWAY:
+ // Ignore
+ case irc.RPL_YOURHOST, irc.RPL_CREATED:
+ // Ignore
+ case irc.RPL_LUSERCLIENT, irc.RPL_LUSEROP, irc.RPL_LUSERUNKNOWN, irc.RPL_LUSERCHANNELS, irc.RPL_LUSERME:
+ fallthrough
+ case irc.RPL_STATSVLINE, rpl_statsping, irc.RPL_STATSBLINE, irc.RPL_STATSDLINE:
+ fallthrough
+ case rpl_localusers, rpl_globalusers:
+ fallthrough
+ case irc.RPL_MOTDSTART, irc.RPL_MOTD:
+ // Ignore these messages if they're part of the initial registration
+ // message burst. Forward them if the user explicitly asked for them.
+ if !uc.gotMotd {
+ return nil
+ }
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ dc.SendMessage(&irc.Message{
+ Prefix: uc.srv.prefix(),
+ Command: msg.Command,
+ Params: msg.Params,
+ })
+ })
+ case irc.RPL_LISTSTART:
+ // Ignore
+ case "ERROR":
+ var text string
+ if err := parseMessageParams(msg, &text); err != nil {
+ return err
+ }
+ return fmt.Errorf("fatal server error: %v", text)
+ case irc.ERR_NICKNAMEINUSE:
+ // At this point, we haven't received ISUPPORT so we don't know the
+ // maximum nickname length or whether the server supports MONITOR. Many
+ // servers have NICKLEN=30 so let's just use that.
+ if !uc.registered && len(uc.nick)+1 < 30 {
+ uc.nick = uc.nick + "_"
+ uc.nickCM = uc.network.casemap(uc.nick)
+ uc.logger.Printf("desired nick is not available, falling back to %q", uc.nick)
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "NICK",
+ Params: []string{uc.nick},
+ })
+ return nil
+ }
+ fallthrough
+ case irc.ERR_PASSWDMISMATCH, irc.ERR_ERRONEUSNICKNAME, irc.ERR_NICKCOLLISION, irc.ERR_UNAVAILRESOURCE, irc.ERR_NOPERMFORHOST, irc.ERR_YOUREBANNEDCREEP:
+ if !uc.registered {
+ return registrationError{msg}
+ }
+ fallthrough
+ default:
+ uc.logger.Printf("unhandled message: %v", msg)
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ // best effort marshaling for unknown messages, replies and errors:
+ // most numerics start with the user nick, marshal it if that's the case
+ // otherwise, conservately keep the params without marshaling
+ params := msg.Params
+ if _, err := strconv.Atoi(msg.Command); err == nil { // numeric
+ if len(msg.Params) > 0 && isOurNick(uc.network, msg.Params[0]) {
+ params[0] = dc.nick
+ }
+ }
+ dc.SendMessage(&irc.Message{
+ Prefix: uc.srv.prefix(),
+ Command: msg.Command,
+ Params: params,
+ })
+ })
+ }
+ return nil
+}
+
+func (uc *upstreamConn) handleDetachedMessage(ctx context.Context, ch *Channel, msg *irc.Message) {
+ if uc.network.detachedMessageNeedsRelay(ch, msg) {
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.relayDetachedMessage(uc.network, msg)
+ })
+ }
+ if ch.ReattachOn == FilterMessage || (ch.ReattachOn == FilterHighlight && uc.network.isHighlight(msg)) {
+ uc.network.attach(ctx, ch)
+ if err := uc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil {
+ uc.logger.Printf("failed to update channel %q: %v", ch.Name, err)
+ }
+ }
+}
+
+func (uc *upstreamConn) handleChanModes(s string) error {
+ parts := strings.SplitN(s, ",", 5)
+ if len(parts) < 4 {
+ return fmt.Errorf("malformed ISUPPORT CHANMODES value: %v", s)
+ }
+ modes := make(map[byte]channelModeType)
+ for i, mt := range []channelModeType{modeTypeA, modeTypeB, modeTypeC, modeTypeD} {
+ for j := 0; j < len(parts[i]); j++ {
+ mode := parts[i][j]
+ modes[mode] = mt
+ }
+ }
+ uc.availableChannelModes = modes
+ return nil
+}
+
+func (uc *upstreamConn) handleMemberships(s string) error {
+ if s == "" {
+ uc.availableMemberships = nil
+ return nil
+ }
+
+ if s[0] != '(' {
+ return fmt.Errorf("malformed ISUPPORT PREFIX value: %v", s)
+ }
+ sep := strings.IndexByte(s, ')')
+ if sep < 0 || len(s) != sep*2 {
+ return fmt.Errorf("malformed ISUPPORT PREFIX value: %v", s)
+ }
+ memberships := make([]membership, len(s)/2-1)
+ for i := range memberships {
+ memberships[i] = membership{
+ Mode: s[i+1],
+ Prefix: s[sep+i+1],
+ }
+ }
+ uc.availableMemberships = memberships
+ return nil
+}
+
+func (uc *upstreamConn) handleSupportedCaps(capsStr string) {
+ caps := strings.Fields(capsStr)
+ for _, s := range caps {
+ kv := strings.SplitN(s, "=", 2)
+ k := strings.ToLower(kv[0])
+ var v string
+ if len(kv) == 2 {
+ v = kv[1]
+ }
+ uc.supportedCaps[k] = v
+ }
+}
+
+func (uc *upstreamConn) requestCaps() {
+ var requestCaps []string
+ for c := range permanentUpstreamCaps {
+ if _, ok := uc.supportedCaps[c]; ok && !uc.caps[c] {
+ requestCaps = append(requestCaps, c)
+ }
+ }
+
+ if len(requestCaps) == 0 {
+ return
+ }
+
+ uc.SendMessage(context.TODO(), &irc.Message{
+ Command: "CAP",
+ Params: []string{"REQ", strings.Join(requestCaps, " ")},
+ })
+}
+
+func (uc *upstreamConn) supportsSASL(mech string) bool {
+ v, ok := uc.supportedCaps["sasl"]
+ if !ok {
+ return false
+ }
+
+ if v == "" {
+ return true
+ }
+
+ mechanisms := strings.Split(v, ",")
+ for _, mech := range mechanisms {
+ if strings.EqualFold(mech, mech) {
+ return true
+ }
+ }
+ return false
+}
+
+func (uc *upstreamConn) requestSASL() bool {
+ if uc.network.SASL.Mechanism == "" {
+ return false
+ }
+ return uc.supportsSASL(uc.network.SASL.Mechanism)
+}
+
+func (uc *upstreamConn) handleCapAck(ctx context.Context, name string, ok bool) error {
+ uc.caps[name] = ok
+
+ switch name {
+ case "sasl":
+ if !uc.requestSASL() {
+ return nil
+ }
+ if !ok {
+ uc.logger.Printf("server refused to acknowledge the SASL capability")
+ return nil
+ }
+
+ auth := &uc.network.SASL
+ switch auth.Mechanism {
+ case "PLAIN":
+ uc.logger.Printf("starting SASL PLAIN authentication with username %q", auth.Plain.Username)
+ uc.saslClient = sasl.NewPlainClient("", auth.Plain.Username, auth.Plain.Password)
+ case "EXTERNAL":
+ uc.logger.Printf("starting SASL EXTERNAL authentication")
+ uc.saslClient = sasl.NewExternalClient("")
+ default:
+ return fmt.Errorf("unsupported SASL mechanism %q", name)
+ }
+
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "AUTHENTICATE",
+ Params: []string{auth.Mechanism},
+ })
+ default:
+ if permanentUpstreamCaps[name] {
+ break
+ }
+ uc.logger.Printf("received CAP ACK/NAK for a cap we don't support: %v", name)
+ }
+ return nil
+}
+
+func splitSpace(s string) []string {
+ return strings.FieldsFunc(s, func(r rune) bool {
+ return r == ' '
+ })
+}
+
+func (uc *upstreamConn) register(ctx context.Context) {
+ uc.nick = GetNick(&uc.user.User, &uc.network.Network)
+ uc.nickCM = uc.network.casemap(uc.nick)
+ uc.username = GetUsername(&uc.user.User, &uc.network.Network)
+ uc.realname = GetRealname(&uc.user.User, &uc.network.Network)
+
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "CAP",
+ Params: []string{"LS", "302"},
+ })
+
+ if uc.network.Pass != "" {
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "PASS",
+ Params: []string{uc.network.Pass},
+ })
+ }
+
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "NICK",
+ Params: []string{uc.nick},
+ })
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "USER",
+ Params: []string{uc.username, "0", "*", uc.realname},
+ })
+}
+
+func (uc *upstreamConn) ReadMessage() (*irc.Message, error) {
+ msg, err := uc.conn.ReadMessage()
+ if err != nil {
+ return nil, err
+ }
+ return msg, nil
+}
+
+func (uc *upstreamConn) runUntilRegistered(ctx context.Context) error {
+ for !uc.registered {
+ msg, err := uc.ReadMessage()
+ if err != nil {
+ return fmt.Errorf("failed to read message: %v", err)
+ }
+
+ if err := uc.handleMessage(ctx, msg); err != nil {
+ if _, ok := err.(registrationError); ok {
+ return err
+ } else {
+ msg.Tags = nil // prevent message tags from cluttering logs
+ return fmt.Errorf("failed to handle message %q: %v", msg, err)
+ }
+ }
+ }
+
+ for _, command := range uc.network.ConnectCommands {
+ m, err := irc.ParseMessage(command)
+ if err != nil {
+ uc.logger.Printf("failed to parse connect command %q: %v", command, err)
+ } else {
+ uc.SendMessage(ctx, m)
+ }
+ }
+
+ return nil
+}
+
+func (uc *upstreamConn) readMessages(ch chan<- event) error {
+ for {
+ msg, err := uc.ReadMessage()
+ if errors.Is(err, io.EOF) {
+ break
+ } else if err != nil {
+ return fmt.Errorf("failed to read IRC command: %v", err)
+ }
+
+ ch <- eventUpstreamMessage{msg, uc}
+ }
+
+ return nil
+}
+
+func (uc *upstreamConn) SendMessage(ctx context.Context, msg *irc.Message) {
+ if !uc.caps["message-tags"] {
+ msg = msg.Copy()
+ msg.Tags = nil
+ }
+
+ uc.conn.SendMessage(ctx, msg)
+}
+
+func (uc *upstreamConn) SendMessageLabeled(ctx context.Context, downstreamID uint64, msg *irc.Message) {
+ if uc.caps["labeled-response"] {
+ if msg.Tags == nil {
+ msg.Tags = make(map[string]irc.TagValue)
+ }
+ msg.Tags["label"] = irc.TagValue(fmt.Sprintf("sd-%d-%d", downstreamID, uc.nextLabelID))
+ uc.nextLabelID++
+ }
+ uc.SendMessage(ctx, msg)
+}
+
+// appendLog appends a message to the log file.
+//
+// The internal message ID is returned. If the message isn't recorded in the
+// log file, an empty string is returned.
+func (uc *upstreamConn) appendLog(entity string, msg *irc.Message) (msgID string) {
+ if uc.user.msgStore == nil {
+ return ""
+ }
+
+ // Don't store messages with a server mask target
+ if strings.HasPrefix(entity, "$") {
+ return ""
+ }
+
+ entityCM := uc.network.casemap(entity)
+ if entityCM == "nickserv" {
+ // The messages sent/received from NickServ may contain
+ // security-related information (like passwords). Don't store these.
+ return ""
+ }
+
+ if !uc.network.delivered.HasTarget(entity) {
+ // This is the first message we receive from this target. Save the last
+ // message ID in delivery receipts, so that we can send the new message
+ // in the backlog if an offline client reconnects.
+ lastID, err := uc.user.msgStore.LastMsgID(&uc.network.Network, entityCM, time.Now())
+ if err != nil {
+ uc.logger.Printf("failed to log message: failed to get last message ID: %v", err)
+ return ""
+ }
+
+ uc.network.delivered.ForEachClient(func(clientName string) {
+ uc.network.delivered.StoreID(entity, clientName, lastID)
+ })
+ }
+
+ msgID, err := uc.user.msgStore.Append(&uc.network.Network, entityCM, msg)
+ if err != nil {
+ uc.logger.Printf("failed to append message to store: %v", err)
+ return ""
+ }
+
+ return msgID
+}
+
+// produce appends a message to the logs and forwards it to connected downstream
+// connections.
+//
+// If origin is not nil and origin doesn't support echo-message, the message is
+// forwarded to all connections except origin.
+func (uc *upstreamConn) produce(target string, msg *irc.Message, origin *downstreamConn) {
+ var msgID string
+ if target != "" {
+ msgID = uc.appendLog(target, msg)
+ }
+
+ // Don't forward messages if it's a detached channel
+ ch := uc.network.channels.Value(target)
+ detached := ch != nil && ch.Detached
+
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ if !detached && (dc != origin || dc.caps["echo-message"]) {
+ dc.sendMessageWithID(dc.marshalMessage(msg, uc.network), msgID)
+ } else {
+ dc.advanceMessageWithID(msg, msgID)
+ }
+ })
+}
+
+func (uc *upstreamConn) updateAway() {
+ ctx := context.TODO()
+
+ away := true
+ uc.forEachDownstream(func(*downstreamConn) {
+ away = false
+ })
+ if away == uc.away {
+ return
+ }
+ if away {
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "AWAY",
+ Params: []string{"Auto away"},
+ })
+ } else {
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "AWAY",
+ })
+ }
+ uc.away = away
+}
+
+func (uc *upstreamConn) updateChannelAutoDetach(name string) {
+ uch := uc.channels.Value(name)
+ if uch == nil {
+ return
+ }
+ ch := uc.network.channels.Value(name)
+ if ch == nil || ch.Detached {
+ return
+ }
+ uch.updateAutoDetach(ch.DetachAfter)
+}
+
+func (uc *upstreamConn) updateMonitor() {
+ if _, ok := uc.isupport["MONITOR"]; !ok {
+ return
+ }
+
+ ctx := context.TODO()
+
+ add := make(map[string]struct{})
+ var addList []string
+ seen := make(map[string]struct{})
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ for targetCM := range dc.monitored.innerMap {
+ if !uc.monitored.Has(targetCM) {
+ if _, ok := add[targetCM]; !ok {
+ addList = append(addList, targetCM)
+ add[targetCM] = struct{}{}
+ }
+ } else {
+ seen[targetCM] = struct{}{}
+ }
+ }
+ })
+
+ wantNick := GetNick(&uc.user.User, &uc.network.Network)
+ wantNickCM := uc.network.casemap(wantNick)
+ if _, ok := add[wantNickCM]; !ok && !uc.monitored.Has(wantNick) && !uc.isOurNick(wantNick) {
+ addList = append(addList, wantNickCM)
+ add[wantNickCM] = struct{}{}
+ }
+
+ removeAll := true
+ var removeList []string
+ for targetCM, entry := range uc.monitored.innerMap {
+ if _, ok := seen[targetCM]; ok {
+ removeAll = false
+ } else {
+ removeList = append(removeList, entry.originalKey)
+ }
+ }
+
+ // TODO: better handle the case where len(uc.monitored) + len(addList)
+ // exceeds the limit, probably by immediately sending ERR_MONLISTFULL?
+
+ if removeAll && len(addList) == 0 && len(removeList) > 0 {
+ // Optimization when the last MONITOR-aware downstream disconnects
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "MONITOR",
+ Params: []string{"C"},
+ })
+ } else {
+ msgs := generateMonitor("-", removeList)
+ msgs = append(msgs, generateMonitor("+", addList)...)
+ for _, msg := range msgs {
+ uc.SendMessage(ctx, msg)
+ }
+ }
+
+ for _, target := range removeList {
+ uc.monitored.Delete(target)
+ }
+}
--- /dev/null
+package suika
+
+import (
+ "context"
+ "crypto/sha256"
+ "encoding/binary"
+ "encoding/hex"
+ "fmt"
+ "math/big"
+ "net"
+ "sort"
+ "strings"
+ "time"
+
+ "gopkg.in/irc.v3"
+)
+
+type event interface{}
+
+type eventUpstreamMessage struct {
+ msg *irc.Message
+ uc *upstreamConn
+}
+
+type eventUpstreamConnectionError struct {
+ net *network
+ err error
+}
+
+type eventUpstreamConnected struct {
+ uc *upstreamConn
+}
+
+type eventUpstreamDisconnected struct {
+ uc *upstreamConn
+}
+
+type eventUpstreamError struct {
+ uc *upstreamConn
+ err error
+}
+
+type eventDownstreamMessage struct {
+ msg *irc.Message
+ dc *downstreamConn
+}
+
+type eventDownstreamConnected struct {
+ dc *downstreamConn
+}
+
+type eventDownstreamDisconnected struct {
+ dc *downstreamConn
+}
+
+type eventChannelDetach struct {
+ uc *upstreamConn
+ name string
+}
+
+type eventBroadcast struct {
+ msg *irc.Message
+}
+
+type eventStop struct{}
+
+type eventUserUpdate struct {
+ password *string
+ admin *bool
+ done chan error
+}
+
+type deliveredClientMap map[string]string // client name -> msg ID
+
+type deliveredStore struct {
+ m deliveredCasemapMap
+}
+
+func newDeliveredStore() deliveredStore {
+ return deliveredStore{deliveredCasemapMap{newCasemapMap(0)}}
+}
+
+func (ds deliveredStore) HasTarget(target string) bool {
+ return ds.m.Value(target) != nil
+}
+
+func (ds deliveredStore) LoadID(target, clientName string) string {
+ clients := ds.m.Value(target)
+ if clients == nil {
+ return ""
+ }
+ return clients[clientName]
+}
+
+func (ds deliveredStore) StoreID(target, clientName, msgID string) {
+ clients := ds.m.Value(target)
+ if clients == nil {
+ clients = make(deliveredClientMap)
+ ds.m.SetValue(target, clients)
+ }
+ clients[clientName] = msgID
+}
+
+func (ds deliveredStore) ForEachTarget(f func(target string)) {
+ for _, entry := range ds.m.innerMap {
+ f(entry.originalKey)
+ }
+}
+
+func (ds deliveredStore) ForEachClient(f func(clientName string)) {
+ clients := make(map[string]struct{})
+ for _, entry := range ds.m.innerMap {
+ delivered := entry.value.(deliveredClientMap)
+ for clientName := range delivered {
+ clients[clientName] = struct{}{}
+ }
+ }
+
+ for clientName := range clients {
+ f(clientName)
+ }
+}
+
+type network struct {
+ Network
+ user *user
+ logger Logger
+ stopped chan struct{}
+
+ conn *upstreamConn
+ channels channelCasemapMap
+ delivered deliveredStore
+ lastError error
+ casemap casemapping
+}
+
+func newNetwork(user *user, record *Network, channels []Channel) *network {
+ logger := &prefixLogger{user.logger, fmt.Sprintf("network %q: ", record.GetName())}
+
+ m := channelCasemapMap{newCasemapMap(0)}
+ for _, ch := range channels {
+ ch := ch
+ m.SetValue(ch.Name, &ch)
+ }
+
+ return &network{
+ Network: *record,
+ user: user,
+ logger: logger,
+ stopped: make(chan struct{}),
+ channels: m,
+ delivered: newDeliveredStore(),
+ casemap: casemapRFC1459,
+ }
+}
+
+func (net *network) forEachDownstream(f func(*downstreamConn)) {
+ net.user.forEachDownstream(func(dc *downstreamConn) {
+ if dc.network == nil && !dc.isMultiUpstream {
+ return
+ }
+ if dc.network != nil && dc.network != net {
+ return
+ }
+ f(dc)
+ })
+}
+
+func (net *network) isStopped() bool {
+ select {
+ case <-net.stopped:
+ return true
+ default:
+ return false
+ }
+}
+
+func userIdent(u *User) string {
+ // The ident is a string we will send to upstream servers in clear-text.
+ // For privacy reasons, make sure it doesn't expose any meaningful user
+ // metadata. We just use the base64-encoded hashed ID, so that people don't
+ // start relying on the string being an integer or following a pattern.
+ var b [64]byte
+ binary.LittleEndian.PutUint64(b[:], uint64(u.ID))
+ h := sha256.Sum256(b[:])
+ return hex.EncodeToString(h[:16])
+}
+
+func (net *network) run() {
+ if !net.Enabled {
+ return
+ }
+
+ var lastTry time.Time
+ backoff := newBackoffer(retryConnectMinDelay, retryConnectMaxDelay, retryConnectJitter)
+ for {
+ if net.isStopped() {
+ return
+ }
+
+ delay := backoff.Next() - time.Now().Sub(lastTry)
+ if delay > 0 {
+ net.logger.Printf("waiting %v before trying to reconnect to %q", delay.Truncate(time.Second), net.Addr)
+ time.Sleep(delay)
+ }
+ lastTry = time.Now()
+
+
+ uc, err := connectToUpstream(context.TODO(), net)
+ if err != nil {
+ net.logger.Printf("failed to connect to upstream server %q: %v", net.Addr, err)
+ net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("failed to connect: %v", err)}
+ continue
+ }
+
+ uc.register(context.TODO())
+ if err := uc.runUntilRegistered(context.TODO()); err != nil {
+ text := err.Error()
+ temp := true
+ if regErr, ok := err.(registrationError); ok {
+ text = regErr.Reason()
+ temp = regErr.Temporary()
+ }
+ uc.logger.Printf("failed to register: %v", text)
+ net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("failed to register: %v", text)}
+ uc.Close()
+ if !temp {
+ return
+ }
+ continue
+ }
+
+ // TODO: this is racy with net.stopped. If the network is stopped
+ // before the user goroutine receives eventUpstreamConnected, the
+ // connection won't be closed.
+ net.user.events <- eventUpstreamConnected{uc}
+ if err := uc.readMessages(net.user.events); err != nil {
+ uc.logger.Printf("failed to handle messages: %v", err)
+ net.user.events <- eventUpstreamError{uc, fmt.Errorf("failed to handle messages: %v", err)}
+ }
+ uc.Close()
+ net.user.events <- eventUpstreamDisconnected{uc}
+
+ backoff.Reset()
+ }
+}
+
+func (net *network) stop() {
+ if !net.isStopped() {
+ close(net.stopped)
+ }
+
+ if net.conn != nil {
+ net.conn.Close()
+ }
+}
+
+func (net *network) detach(ch *Channel) {
+ if ch.Detached {
+ return
+ }
+
+ net.logger.Printf("detaching channel %q", ch.Name)
+
+ ch.Detached = true
+
+ if net.user.msgStore != nil {
+ nameCM := net.casemap(ch.Name)
+ lastID, err := net.user.msgStore.LastMsgID(&net.Network, nameCM, time.Now())
+ if err != nil {
+ net.logger.Printf("failed to get last message ID for channel %q: %v", ch.Name, err)
+ }
+ ch.DetachedInternalMsgID = lastID
+ }
+
+ if net.conn != nil {
+ uch := net.conn.channels.Value(ch.Name)
+ if uch != nil {
+ uch.updateAutoDetach(0)
+ }
+ }
+
+ net.forEachDownstream(func(dc *downstreamConn) {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.prefix(),
+ Command: "PART",
+ Params: []string{dc.marshalEntity(net, ch.Name), "Detach"},
+ })
+ })
+}
+
+func (net *network) attach(ctx context.Context, ch *Channel) {
+ if !ch.Detached {
+ return
+ }
+
+ net.logger.Printf("attaching channel %q", ch.Name)
+
+ detachedMsgID := ch.DetachedInternalMsgID
+ ch.Detached = false
+ ch.DetachedInternalMsgID = ""
+
+ var uch *upstreamChannel
+ if net.conn != nil {
+ uch = net.conn.channels.Value(ch.Name)
+
+ net.conn.updateChannelAutoDetach(ch.Name)
+ }
+
+ net.forEachDownstream(func(dc *downstreamConn) {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.prefix(),
+ Command: "JOIN",
+ Params: []string{dc.marshalEntity(net, ch.Name)},
+ })
+
+ if uch != nil {
+ forwardChannel(ctx, dc, uch)
+ }
+
+ if detachedMsgID != "" {
+ dc.sendTargetBacklog(ctx, net, ch.Name, detachedMsgID)
+ }
+ })
+}
+
+func (net *network) deleteChannel(ctx context.Context, name string) error {
+ ch := net.channels.Value(name)
+ if ch == nil {
+ return fmt.Errorf("unknown channel %q", name)
+ }
+ if net.conn != nil {
+ uch := net.conn.channels.Value(ch.Name)
+ if uch != nil {
+ uch.updateAutoDetach(0)
+ }
+ }
+
+ if err := net.user.srv.db.DeleteChannel(ctx, ch.ID); err != nil {
+ return err
+ }
+ net.channels.Delete(name)
+ return nil
+}
+
+func (net *network) updateCasemapping(newCasemap casemapping) {
+ net.casemap = newCasemap
+ net.channels.SetCasemapping(newCasemap)
+ net.delivered.m.SetCasemapping(newCasemap)
+ if uc := net.conn; uc != nil {
+ uc.channels.SetCasemapping(newCasemap)
+ for _, entry := range uc.channels.innerMap {
+ uch := entry.value.(*upstreamChannel)
+ uch.Members.SetCasemapping(newCasemap)
+ }
+ uc.monitored.SetCasemapping(newCasemap)
+ }
+ net.forEachDownstream(func(dc *downstreamConn) {
+ dc.monitored.SetCasemapping(newCasemap)
+ })
+}
+
+func (net *network) storeClientDeliveryReceipts(ctx context.Context, clientName string) {
+ if !net.user.hasPersistentMsgStore() {
+ return
+ }
+
+ var receipts []DeliveryReceipt
+ net.delivered.ForEachTarget(func(target string) {
+ msgID := net.delivered.LoadID(target, clientName)
+ if msgID == "" {
+ return
+ }
+ receipts = append(receipts, DeliveryReceipt{
+ Target: target,
+ InternalMsgID: msgID,
+ })
+ })
+
+ if err := net.user.srv.db.StoreClientDeliveryReceipts(ctx, net.ID, clientName, receipts); err != nil {
+ net.logger.Printf("failed to store delivery receipts for client %q: %v", clientName, err)
+ }
+}
+
+func (net *network) isHighlight(msg *irc.Message) bool {
+ if msg.Command != "PRIVMSG" && msg.Command != "NOTICE" {
+ return false
+ }
+
+ text := msg.Params[1]
+
+ nick := net.Nick
+ if net.conn != nil {
+ nick = net.conn.nick
+ }
+
+ // TODO: use case-mapping aware comparison here
+ return msg.Prefix.Name != nick && isHighlight(text, nick)
+}
+
+func (net *network) detachedMessageNeedsRelay(ch *Channel, msg *irc.Message) bool {
+ highlight := net.isHighlight(msg)
+ return ch.RelayDetached == FilterMessage || ((ch.RelayDetached == FilterHighlight || ch.RelayDetached == FilterDefault) && highlight)
+}
+
+func (net *network) autoSaveSASLPlain(ctx context.Context, username, password string) {
+ // User may have e.g. EXTERNAL mechanism configured. We do not want to
+ // automatically erase the key pair or any other credentials.
+ if net.SASL.Mechanism != "" && net.SASL.Mechanism != "PLAIN" {
+ return
+ }
+
+ net.logger.Printf("auto-saving SASL PLAIN credentials with username %q", username)
+ net.SASL.Mechanism = "PLAIN"
+ net.SASL.Plain.Username = username
+ net.SASL.Plain.Password = password
+ if err := net.user.srv.db.StoreNetwork(ctx, net.user.ID, &net.Network); err != nil {
+ net.logger.Printf("failed to save SASL PLAIN credentials: %v", err)
+ }
+}
+
+type user struct {
+ User
+ srv *Server
+ logger Logger
+
+ events chan event
+ done chan struct{}
+
+ networks []*network
+ downstreamConns []*downstreamConn
+ msgStore messageStore
+}
+
+func newUser(srv *Server, record *User) *user {
+ logger := &prefixLogger{srv.Logger, fmt.Sprintf("user %q: ", record.Username)}
+
+ var msgStore messageStore
+ if logPath := srv.Config().LogPath; logPath != "" {
+ msgStore = newFSMessageStore(logPath, record)
+ } else {
+ msgStore = newMemoryMessageStore()
+ }
+
+ return &user{
+ User: *record,
+ srv: srv,
+ logger: logger,
+ events: make(chan event, 64),
+ done: make(chan struct{}),
+ msgStore: msgStore,
+ }
+}
+
+func (u *user) forEachUpstream(f func(uc *upstreamConn)) {
+ for _, network := range u.networks {
+ if network.conn == nil {
+ continue
+ }
+ f(network.conn)
+ }
+}
+
+func (u *user) forEachDownstream(f func(dc *downstreamConn)) {
+ for _, dc := range u.downstreamConns {
+ f(dc)
+ }
+}
+
+func (u *user) getNetwork(name string) *network {
+ for _, network := range u.networks {
+ if network.Addr == name {
+ return network
+ }
+ if network.Name != "" && network.Name == name {
+ return network
+ }
+ }
+ return nil
+}
+
+func (u *user) getNetworkByID(id int64) *network {
+ for _, net := range u.networks {
+ if net.ID == id {
+ return net
+ }
+ }
+ return nil
+}
+
+func (u *user) run() {
+ defer func() {
+ if u.msgStore != nil {
+ if err := u.msgStore.Close(); err != nil {
+ u.logger.Printf("failed to close message store for user %q: %v", u.Username, err)
+ }
+ }
+ close(u.done)
+ }()
+
+ networks, err := u.srv.db.ListNetworks(context.TODO(), u.ID)
+ if err != nil {
+ u.logger.Printf("failed to list networks for user %q: %v", u.Username, err)
+ return
+ }
+
+ sort.Slice(networks, func(i, j int) bool {
+ return networks[i].ID < networks[j].ID
+ })
+
+ for _, record := range networks {
+ record := record
+ channels, err := u.srv.db.ListChannels(context.TODO(), record.ID)
+ if err != nil {
+ u.logger.Printf("failed to list channels for user %q, network %q: %v", u.Username, record.GetName(), err)
+ continue
+ }
+
+ network := newNetwork(u, &record, channels)
+ u.networks = append(u.networks, network)
+
+ if u.hasPersistentMsgStore() {
+ receipts, err := u.srv.db.ListDeliveryReceipts(context.TODO(), record.ID)
+ if err != nil {
+ u.logger.Printf("failed to load delivery receipts for user %q, network %q: %v", u.Username, network.GetName(), err)
+ return
+ }
+
+ for _, rcpt := range receipts {
+ network.delivered.StoreID(rcpt.Target, rcpt.Client, rcpt.InternalMsgID)
+ }
+ }
+
+ go network.run()
+ }
+
+ for e := range u.events {
+ switch e := e.(type) {
+ case eventUpstreamConnected:
+ uc := e.uc
+
+ uc.network.conn = uc
+
+ uc.updateAway()
+ uc.updateMonitor()
+
+ netIDStr := fmt.Sprintf("%v", uc.network.ID)
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.updateSupportedCaps()
+
+ if !dc.caps["soju.im/bouncer-networks"] {
+ sendServiceNOTICE(dc, fmt.Sprintf("connected to %s", uc.network.GetName()))
+ }
+
+ dc.updateNick()
+ dc.updateRealname()
+ dc.updateAccount()
+ })
+ u.forEachDownstream(func(dc *downstreamConn) {
+ if dc.caps["soju.im/bouncer-networks-notify"] {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "BOUNCER",
+ Params: []string{"NETWORK", netIDStr, "state=connected"},
+ })
+ }
+ })
+ uc.network.lastError = nil
+ case eventUpstreamDisconnected:
+ u.handleUpstreamDisconnected(e.uc)
+ case eventUpstreamConnectionError:
+ net := e.net
+
+ stopped := false
+ select {
+ case <-net.stopped:
+ stopped = true
+ default:
+ }
+
+ if !stopped && (net.lastError == nil || net.lastError.Error() != e.err.Error()) {
+ net.forEachDownstream(func(dc *downstreamConn) {
+ sendServiceNOTICE(dc, fmt.Sprintf("failed connecting/registering to %s: %v", net.GetName(), e.err))
+ })
+ }
+ net.lastError = e.err
+ case eventUpstreamError:
+ uc := e.uc
+
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s: %v", uc.network.GetName(), e.err))
+ })
+ uc.network.lastError = e.err
+ case eventUpstreamMessage:
+ msg, uc := e.msg, e.uc
+ if uc.isClosed() {
+ uc.logger.Printf("ignoring message on closed connection: %v", msg)
+ break
+ }
+ if err := uc.handleMessage(context.TODO(), msg); err != nil {
+ uc.logger.Printf("failed to handle message %q: %v", msg, err)
+ }
+ case eventChannelDetach:
+ uc, name := e.uc, e.name
+ c := uc.network.channels.Value(name)
+ if c == nil || c.Detached {
+ continue
+ }
+ uc.network.detach(c)
+ if err := uc.srv.db.StoreChannel(context.TODO(), uc.network.ID, c); err != nil {
+ u.logger.Printf("failed to store updated detached channel %q: %v", c.Name, err)
+ }
+ case eventDownstreamConnected:
+ dc := e.dc
+
+ if dc.network != nil {
+ dc.monitored.SetCasemapping(dc.network.casemap)
+ }
+
+ if err := dc.welcome(context.TODO()); err != nil {
+ dc.logger.Printf("failed to handle new registered connection: %v", err)
+ break
+ }
+
+ u.downstreamConns = append(u.downstreamConns, dc)
+
+ dc.forEachNetwork(func(network *network) {
+ if network.lastError != nil {
+ sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s: %v", network.GetName(), network.lastError))
+ }
+ })
+
+ u.forEachUpstream(func(uc *upstreamConn) {
+ uc.updateAway()
+ })
+ case eventDownstreamDisconnected:
+ dc := e.dc
+
+ for i := range u.downstreamConns {
+ if u.downstreamConns[i] == dc {
+ u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...)
+ break
+ }
+ }
+
+ dc.forEachNetwork(func(net *network) {
+ net.storeClientDeliveryReceipts(context.TODO(), dc.clientName)
+ })
+
+ u.forEachUpstream(func(uc *upstreamConn) {
+ uc.cancelPendingCommandsByDownstreamID(dc.id)
+ uc.updateAway()
+ uc.updateMonitor()
+ })
+ case eventDownstreamMessage:
+ msg, dc := e.msg, e.dc
+ if dc.isClosed() {
+ dc.logger.Printf("ignoring message on closed connection: %v", msg)
+ break
+ }
+ err := dc.handleMessage(context.TODO(), msg)
+ if ircErr, ok := err.(ircError); ok {
+ ircErr.Message.Prefix = dc.srv.prefix()
+ dc.SendMessage(ircErr.Message)
+ } else if err != nil {
+ dc.logger.Printf("failed to handle message %q: %v", msg, err)
+ dc.Close()
+ }
+ case eventBroadcast:
+ msg := e.msg
+ u.forEachDownstream(func(dc *downstreamConn) {
+ dc.SendMessage(msg)
+ })
+ case eventUserUpdate:
+ // copy the user record because we'll mutate it
+ record := u.User
+
+ if e.password != nil {
+ record.Password = *e.password
+ }
+ if e.admin != nil {
+ record.Admin = *e.admin
+ }
+
+ e.done <- u.updateUser(context.TODO(), &record)
+
+ // If the password was updated, kill all downstream connections to
+ // force them to re-authenticate with the new credentials.
+ if e.password != nil {
+ u.forEachDownstream(func(dc *downstreamConn) {
+ dc.Close()
+ })
+ }
+ case eventStop:
+ u.forEachDownstream(func(dc *downstreamConn) {
+ dc.Close()
+ })
+ for _, n := range u.networks {
+ n.stop()
+
+ n.delivered.ForEachClient(func(clientName string) {
+ n.storeClientDeliveryReceipts(context.TODO(), clientName)
+ })
+ }
+ return
+ default:
+ panic(fmt.Sprintf("received unknown event type: %T", e))
+ }
+ }
+}
+
+func (u *user) handleUpstreamDisconnected(uc *upstreamConn) {
+ uc.network.conn = nil
+
+ uc.abortPendingCommands()
+
+ for _, entry := range uc.channels.innerMap {
+ uch := entry.value.(*upstreamChannel)
+ uch.updateAutoDetach(0)
+ }
+
+ netIDStr := fmt.Sprintf("%v", uc.network.ID)
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.updateSupportedCaps()
+ })
+
+ // If the network has been removed, don't send a state change notification
+ found := false
+ for _, net := range u.networks {
+ if net == uc.network {
+ found = true
+ break
+ }
+ }
+ if !found {
+ return
+ }
+
+ u.forEachDownstream(func(dc *downstreamConn) {
+ if dc.caps["soju.im/bouncer-networks-notify"] {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "BOUNCER",
+ Params: []string{"NETWORK", netIDStr, "state=disconnected"},
+ })
+ }
+ })
+
+ if uc.network.lastError == nil {
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ if !dc.caps["soju.im/bouncer-networks"] {
+ sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s", uc.network.GetName()))
+ }
+ })
+ }
+}
+
+func (u *user) addNetwork(network *network) {
+ u.networks = append(u.networks, network)
+
+ sort.Slice(u.networks, func(i, j int) bool {
+ return u.networks[i].ID < u.networks[j].ID
+ })
+
+ go network.run()
+}
+
+func (u *user) removeNetwork(network *network) {
+ network.stop()
+
+ u.forEachDownstream(func(dc *downstreamConn) {
+ if dc.network != nil && dc.network == network {
+ dc.Close()
+ }
+ })
+
+ for i, net := range u.networks {
+ if net == network {
+ u.networks = append(u.networks[:i], u.networks[i+1:]...)
+ return
+ }
+ }
+
+ panic("tried to remove a non-existing network")
+}
+
+func (u *user) checkNetwork(record *Network) error {
+ url, err := record.URL()
+ if err != nil {
+ return err
+ }
+ if url.User != nil {
+ return fmt.Errorf("%v:// URL must not have username and password information", url.Scheme)
+ }
+ if url.RawQuery != "" {
+ return fmt.Errorf("%v:// URL must not have query values", url.Scheme)
+ }
+ if url.Fragment != "" {
+ return fmt.Errorf("%v:// URL must not have a fragment", url.Scheme)
+ }
+ switch url.Scheme {
+ case "ircs", "irc":
+ if url.Host == "" {
+ return fmt.Errorf("%v:// URL must have a host", url.Scheme)
+ }
+ if url.Path != "" {
+ return fmt.Errorf("%v:// URL must not have a path", url.Scheme)
+ }
+ case "irc+unix", "unix":
+ if url.Host != "" {
+ return fmt.Errorf("%v:// URL must not have a host", url.Scheme)
+ }
+ if url.Path == "" {
+ return fmt.Errorf("%v:// URL must have a path", url.Scheme)
+ }
+ default:
+ return fmt.Errorf("unknown URL scheme %q", url.Scheme)
+ }
+
+ if record.GetName() == "" {
+ return fmt.Errorf("network name cannot be empty")
+ }
+ if strings.HasPrefix(record.GetName(), "-") {
+ // Can be mixed up with flags when sending commands to the service
+ return fmt.Errorf("network name cannot start with a dash character")
+ }
+
+ for _, net := range u.networks {
+ if net.GetName() == record.GetName() && net.ID != record.ID {
+ return fmt.Errorf("a network with the name %q already exists", record.GetName())
+ }
+ }
+
+ return nil
+}
+
+func (u *user) createNetwork(ctx context.Context, record *Network) (*network, error) {
+ if record.ID != 0 {
+ panic("tried creating an already-existing network")
+ }
+
+ if err := u.checkNetwork(record); err != nil {
+ return nil, err
+ }
+
+ if max := u.srv.Config().MaxUserNetworks; max >= 0 && len(u.networks) >= max {
+ return nil, fmt.Errorf("maximum number of networks reached")
+ }
+
+ network := newNetwork(u, record, nil)
+ err := u.srv.db.StoreNetwork(ctx, u.ID, &network.Network)
+ if err != nil {
+ return nil, err
+ }
+
+ u.addNetwork(network)
+
+ idStr := fmt.Sprintf("%v", network.ID)
+ attrs := getNetworkAttrs(network)
+ u.forEachDownstream(func(dc *downstreamConn) {
+ if dc.caps["soju.im/bouncer-networks-notify"] {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "BOUNCER",
+ Params: []string{"NETWORK", idStr, attrs.String()},
+ })
+ }
+ })
+
+ return network, nil
+}
+
+func (u *user) updateNetwork(ctx context.Context, record *Network) (*network, error) {
+ if record.ID == 0 {
+ panic("tried updating a new network")
+ }
+
+ // If the realname is reset to the default, just wipe the per-network
+ // setting
+ if record.Realname == u.Realname {
+ record.Realname = ""
+ }
+
+ if err := u.checkNetwork(record); err != nil {
+ return nil, err
+ }
+
+ network := u.getNetworkByID(record.ID)
+ if network == nil {
+ panic("tried updating a non-existing network")
+ }
+
+ if err := u.srv.db.StoreNetwork(ctx, u.ID, record); err != nil {
+ return nil, err
+ }
+
+ // Most network changes require us to re-connect to the upstream server
+
+ channels := make([]Channel, 0, network.channels.Len())
+ for _, entry := range network.channels.innerMap {
+ ch := entry.value.(*Channel)
+ channels = append(channels, *ch)
+ }
+
+ updatedNetwork := newNetwork(u, record, channels)
+
+ // If we're currently connected, disconnect and perform the necessary
+ // bookkeeping
+ if network.conn != nil {
+ network.stop()
+ // Note: this will set network.conn to nil
+ u.handleUpstreamDisconnected(network.conn)
+ }
+
+ // Patch downstream connections to use our fresh updated network
+ u.forEachDownstream(func(dc *downstreamConn) {
+ if dc.network != nil && dc.network == network {
+ dc.network = updatedNetwork
+ }
+ })
+
+ // We need to remove the network after patching downstream connections,
+ // otherwise they'll get closed
+ u.removeNetwork(network)
+
+ // The filesystem message store needs to be notified whenever the network
+ // is renamed
+ fsMsgStore, isFS := u.msgStore.(*fsMessageStore)
+ if isFS && updatedNetwork.GetName() != network.GetName() {
+ if err := fsMsgStore.RenameNetwork(&network.Network, &updatedNetwork.Network); err != nil {
+ network.logger.Printf("failed to update FS message store network name to %q: %v", updatedNetwork.GetName(), err)
+ }
+ }
+
+ // This will re-connect to the upstream server
+ u.addNetwork(updatedNetwork)
+
+ // TODO: only broadcast attributes that have changed
+ idStr := fmt.Sprintf("%v", updatedNetwork.ID)
+ attrs := getNetworkAttrs(updatedNetwork)
+ u.forEachDownstream(func(dc *downstreamConn) {
+ if dc.caps["soju.im/bouncer-networks-notify"] {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "BOUNCER",
+ Params: []string{"NETWORK", idStr, attrs.String()},
+ })
+ }
+ })
+
+ return updatedNetwork, nil
+}
+
+func (u *user) deleteNetwork(ctx context.Context, id int64) error {
+ network := u.getNetworkByID(id)
+ if network == nil {
+ panic("tried deleting a non-existing network")
+ }
+
+ if err := u.srv.db.DeleteNetwork(ctx, network.ID); err != nil {
+ return err
+ }
+
+ u.removeNetwork(network)
+
+ idStr := fmt.Sprintf("%v", network.ID)
+ u.forEachDownstream(func(dc *downstreamConn) {
+ if dc.caps["soju.im/bouncer-networks-notify"] {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "BOUNCER",
+ Params: []string{"NETWORK", idStr, "*"},
+ })
+ }
+ })
+
+ return nil
+}
+
+func (u *user) updateUser(ctx context.Context, record *User) error {
+ if u.ID != record.ID {
+ panic("ID mismatch when updating user")
+ }
+
+ realnameUpdated := u.Realname != record.Realname
+ if err := u.srv.db.StoreUser(ctx, record); err != nil {
+ return fmt.Errorf("failed to update user %q: %v", u.Username, err)
+ }
+ u.User = *record
+
+ if realnameUpdated {
+ // Re-connect to networks which use the default realname
+ var needUpdate []Network
+ for _, net := range u.networks {
+ if net.Realname == "" {
+ needUpdate = append(needUpdate, net.Network)
+ }
+ }
+
+ var netErr error
+ for _, net := range needUpdate {
+ if _, err := u.updateNetwork(ctx, &net); err != nil {
+ netErr = err
+ }
+ }
+ if netErr != nil {
+ return netErr
+ }
+ }
+
+ return nil
+}
+
+func (u *user) stop() {
+ u.events <- eventStop{}
+ <-u.done
+}
+
+func (u *user) hasPersistentMsgStore() bool {
+ if u.msgStore == nil {
+ return false
+ }
+ _, isMem := u.msgStore.(*memoryMessageStore)
+ return !isMem
+}
+
+// localAddrForHost returns the local address to use when connecting to host.
+// A nil address is returned when the OS should automatically pick one.
+func (u *user) localTCPAddrForHost(ctx context.Context, host string) (*net.TCPAddr, error) {
+ upstreamUserIPs := u.srv.Config().UpstreamUserIPs
+ if len(upstreamUserIPs) == 0 {
+ return nil, nil
+ }
+
+ ips, err := net.DefaultResolver.LookupIP(ctx, "ip", host)
+ if err != nil {
+ return nil, err
+ }
+
+ wantIPv6 := false
+ for _, ip := range ips {
+ if ip.To4() == nil {
+ wantIPv6 = true
+ break
+ }
+ }
+
+ var ipNet *net.IPNet
+ for _, in := range upstreamUserIPs {
+ if wantIPv6 == (in.IP.To4() == nil) {
+ ipNet = in
+ break
+ }
+ }
+ if ipNet == nil {
+ return nil, nil
+ }
+
+ var ipInt big.Int
+ ipInt.SetBytes(ipNet.IP)
+ ipInt.Add(&ipInt, big.NewInt(u.ID+1))
+ ip := net.IP(ipInt.Bytes())
+ if !ipNet.Contains(ip) {
+ return nil, fmt.Errorf("IP network %v too small", ipNet)
+ }
+
+ return &net.TCPAddr{IP: ip}, nil
+}
--- /dev/null
+package suika
+
+import (
+ "fmt"
+ "runtime/debug"
+ "strings"
+)
+
+const (
+ defaultVersion = "0.0.0"
+ defaultCommit = "HEAD"
+ defaultBuild = "0000-01-01:00:00+00:00"
+)
+
+var (
+ // Version is the tagged release version in the form <major>.<minor>.<patch>
+ // following semantic versioning and is overwritten by the build system.
+ Version = defaultVersion
+
+ // Commit is the commit sha of the build (normally from Git) and is overwritten
+ // by the build system.
+ Commit = defaultCommit
+
+ // Build is the date and time of the build as an RFC3339 formatted string
+ // and is overwritten by the build system.
+ Build = defaultBuild
+)
+
+// FullVersion display the full version and build
+func FullVersion() string {
+ var sb strings.Builder
+
+ isDefault := Version == defaultVersion && Commit == defaultCommit && Build == defaultBuild
+
+ if !isDefault {
+ sb.WriteString(fmt.Sprintf("%s@%s %s", Version, Commit, Build))
+ }
+
+ if info, ok := debug.ReadBuildInfo(); ok {
+ if isDefault {
+ sb.WriteString(fmt.Sprintf(" %s", info.Main.Version))
+ }
+ sb.WriteString(fmt.Sprintf(" %s", info.GoVersion))
+ if info.Main.Sum != "" {
+ sb.WriteString(fmt.Sprintf(" %s", info.Main.Sum))
+ }
+ }
+
+ return sb.String()
+}
--- /dev/null
+vendor
+/suika
+/suikadb
+/suika-znc-import
+/suika.db
--- /dev/null
+ GNU AFFERO GENERAL PUBLIC LICENSE
+ Version 3, 19 November 2007
+
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+ Preamble
+
+ The GNU Affero General Public License is a free, copyleft license for
+software and other kinds of works, specifically designed to ensure
+cooperation with the community in the case of network server software.
+
+ The licenses for most software and other practical works are designed
+to take away your freedom to share and change the works. By contrast,
+our General Public Licenses are intended to guarantee your freedom to
+share and change all versions of a program--to make sure it remains free
+software for all its users.
+
+ When we speak of free software, we are referring to freedom, not
+price. Our General Public Licenses are designed to make sure that you
+have the freedom to distribute copies of free software (and charge for
+them if you wish), that you receive source code or can get it if you
+want it, that you can change the software or use pieces of it in new
+free programs, and that you know you can do these things.
+
+ Developers that use our General Public Licenses protect your rights
+with two steps: (1) assert copyright on the software, and (2) offer
+you this License which gives you legal permission to copy, distribute
+and/or modify the software.
+
+ A secondary benefit of defending all users' freedom is that
+improvements made in alternate versions of the program, if they
+receive widespread use, become available for other developers to
+incorporate. Many developers of free software are heartened and
+encouraged by the resulting cooperation. However, in the case of
+software used on network servers, this result may fail to come about.
+The GNU General Public License permits making a modified version and
+letting the public access it on a server without ever releasing its
+source code to the public.
+
+ The GNU Affero General Public License is designed specifically to
+ensure that, in such cases, the modified source code becomes available
+to the community. It requires the operator of a network server to
+provide the source code of the modified version running there to the
+users of that server. Therefore, public use of a modified version, on
+a publicly accessible server, gives the public access to the source
+code of the modified version.
+
+ An older license, called the Affero General Public License and
+published by Affero, was designed to accomplish similar goals. This is
+a different license, not a version of the Affero GPL, but Affero has
+released a new version of the Affero GPL which permits relicensing under
+this license.
+
+ The precise terms and conditions for copying, distribution and
+modification follow.
+
+ TERMS AND CONDITIONS
+
+ 0. Definitions.
+
+ "This License" refers to version 3 of the GNU Affero General Public License.
+
+ "Copyright" also means copyright-like laws that apply to other kinds of
+works, such as semiconductor masks.
+
+ "The Program" refers to any copyrightable work licensed under this
+License. Each licensee is addressed as "you". "Licensees" and
+"recipients" may be individuals or organizations.
+
+ To "modify" a work means to copy from or adapt all or part of the work
+in a fashion requiring copyright permission, other than the making of an
+exact copy. The resulting work is called a "modified version" of the
+earlier work or a work "based on" the earlier work.
+
+ A "covered work" means either the unmodified Program or a work based
+on the Program.
+
+ To "propagate" a work means to do anything with it that, without
+permission, would make you directly or secondarily liable for
+infringement under applicable copyright law, except executing it on a
+computer or modifying a private copy. Propagation includes copying,
+distribution (with or without modification), making available to the
+public, and in some countries other activities as well.
+
+ To "convey" a work means any kind of propagation that enables other
+parties to make or receive copies. Mere interaction with a user through
+a computer network, with no transfer of a copy, is not conveying.
+
+ An interactive user interface displays "Appropriate Legal Notices"
+to the extent that it includes a convenient and prominently visible
+feature that (1) displays an appropriate copyright notice, and (2)
+tells the user that there is no warranty for the work (except to the
+extent that warranties are provided), that licensees may convey the
+work under this License, and how to view a copy of this License. If
+the interface presents a list of user commands or options, such as a
+menu, a prominent item in the list meets this criterion.
+
+ 1. Source Code.
+
+ The "source code" for a work means the preferred form of the work
+for making modifications to it. "Object code" means any non-source
+form of a work.
+
+ A "Standard Interface" means an interface that either is an official
+standard defined by a recognized standards body, or, in the case of
+interfaces specified for a particular programming language, one that
+is widely used among developers working in that language.
+
+ The "System Libraries" of an executable work include anything, other
+than the work as a whole, that (a) is included in the normal form of
+packaging a Major Component, but which is not part of that Major
+Component, and (b) serves only to enable use of the work with that
+Major Component, or to implement a Standard Interface for which an
+implementation is available to the public in source code form. A
+"Major Component", in this context, means a major essential component
+(kernel, window system, and so on) of the specific operating system
+(if any) on which the executable work runs, or a compiler used to
+produce the work, or an object code interpreter used to run it.
+
+ The "Corresponding Source" for a work in object code form means all
+the source code needed to generate, install, and (for an executable
+work) run the object code and to modify the work, including scripts to
+control those activities. However, it does not include the work's
+System Libraries, or general-purpose tools or generally available free
+programs which are used unmodified in performing those activities but
+which are not part of the work. For example, Corresponding Source
+includes interface definition files associated with source files for
+the work, and the source code for shared libraries and dynamically
+linked subprograms that the work is specifically designed to require,
+such as by intimate data communication or control flow between those
+subprograms and other parts of the work.
+
+ The Corresponding Source need not include anything that users
+can regenerate automatically from other parts of the Corresponding
+Source.
+
+ The Corresponding Source for a work in source code form is that
+same work.
+
+ 2. Basic Permissions.
+
+ All rights granted under this License are granted for the term of
+copyright on the Program, and are irrevocable provided the stated
+conditions are met. This License explicitly affirms your unlimited
+permission to run the unmodified Program. The output from running a
+covered work is covered by this License only if the output, given its
+content, constitutes a covered work. This License acknowledges your
+rights of fair use or other equivalent, as provided by copyright law.
+
+ You may make, run and propagate covered works that you do not
+convey, without conditions so long as your license otherwise remains
+in force. You may convey covered works to others for the sole purpose
+of having them make modifications exclusively for you, or provide you
+with facilities for running those works, provided that you comply with
+the terms of this License in conveying all material for which you do
+not control copyright. Those thus making or running the covered works
+for you must do so exclusively on your behalf, under your direction
+and control, on terms that prohibit them from making any copies of
+your copyrighted material outside their relationship with you.
+
+ Conveying under any other circumstances is permitted solely under
+the conditions stated below. Sublicensing is not allowed; section 10
+makes it unnecessary.
+
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
+
+ No covered work shall be deemed part of an effective technological
+measure under any applicable law fulfilling obligations under article
+11 of the WIPO copyright treaty adopted on 20 December 1996, or
+similar laws prohibiting or restricting circumvention of such
+measures.
+
+ When you convey a covered work, you waive any legal power to forbid
+circumvention of technological measures to the extent such circumvention
+is effected by exercising rights under this License with respect to
+the covered work, and you disclaim any intention to limit operation or
+modification of the work as a means of enforcing, against the work's
+users, your or third parties' legal rights to forbid circumvention of
+technological measures.
+
+ 4. Conveying Verbatim Copies.
+
+ You may convey verbatim copies of the Program's source code as you
+receive it, in any medium, provided that you conspicuously and
+appropriately publish on each copy an appropriate copyright notice;
+keep intact all notices stating that this License and any
+non-permissive terms added in accord with section 7 apply to the code;
+keep intact all notices of the absence of any warranty; and give all
+recipients a copy of this License along with the Program.
+
+ You may charge any price or no price for each copy that you convey,
+and you may offer support or warranty protection for a fee.
+
+ 5. Conveying Modified Source Versions.
+
+ You may convey a work based on the Program, or the modifications to
+produce it from the Program, in the form of source code under the
+terms of section 4, provided that you also meet all of these conditions:
+
+ a) The work must carry prominent notices stating that you modified
+ it, and giving a relevant date.
+
+ b) The work must carry prominent notices stating that it is
+ released under this License and any conditions added under section
+ 7. This requirement modifies the requirement in section 4 to
+ "keep intact all notices".
+
+ c) You must license the entire work, as a whole, under this
+ License to anyone who comes into possession of a copy. This
+ License will therefore apply, along with any applicable section 7
+ additional terms, to the whole of the work, and all its parts,
+ regardless of how they are packaged. This License gives no
+ permission to license the work in any other way, but it does not
+ invalidate such permission if you have separately received it.
+
+ d) If the work has interactive user interfaces, each must display
+ Appropriate Legal Notices; however, if the Program has interactive
+ interfaces that do not display Appropriate Legal Notices, your
+ work need not make them do so.
+
+ A compilation of a covered work with other separate and independent
+works, which are not by their nature extensions of the covered work,
+and which are not combined with it such as to form a larger program,
+in or on a volume of a storage or distribution medium, is called an
+"aggregate" if the compilation and its resulting copyright are not
+used to limit the access or legal rights of the compilation's users
+beyond what the individual works permit. Inclusion of a covered work
+in an aggregate does not cause this License to apply to the other
+parts of the aggregate.
+
+ 6. Conveying Non-Source Forms.
+
+ You may convey a covered work in object code form under the terms
+of sections 4 and 5, provided that you also convey the
+machine-readable Corresponding Source under the terms of this License,
+in one of these ways:
+
+ a) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by the
+ Corresponding Source fixed on a durable physical medium
+ customarily used for software interchange.
+
+ b) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by a
+ written offer, valid for at least three years and valid for as
+ long as you offer spare parts or customer support for that product
+ model, to give anyone who possesses the object code either (1) a
+ copy of the Corresponding Source for all the software in the
+ product that is covered by this License, on a durable physical
+ medium customarily used for software interchange, for a price no
+ more than your reasonable cost of physically performing this
+ conveying of source, or (2) access to copy the
+ Corresponding Source from a network server at no charge.
+
+ c) Convey individual copies of the object code with a copy of the
+ written offer to provide the Corresponding Source. This
+ alternative is allowed only occasionally and noncommercially, and
+ only if you received the object code with such an offer, in accord
+ with subsection 6b.
+
+ d) Convey the object code by offering access from a designated
+ place (gratis or for a charge), and offer equivalent access to the
+ Corresponding Source in the same way through the same place at no
+ further charge. You need not require recipients to copy the
+ Corresponding Source along with the object code. If the place to
+ copy the object code is a network server, the Corresponding Source
+ may be on a different server (operated by you or a third party)
+ that supports equivalent copying facilities, provided you maintain
+ clear directions next to the object code saying where to find the
+ Corresponding Source. Regardless of what server hosts the
+ Corresponding Source, you remain obligated to ensure that it is
+ available for as long as needed to satisfy these requirements.
+
+ e) Convey the object code using peer-to-peer transmission, provided
+ you inform other peers where the object code and Corresponding
+ Source of the work are being offered to the general public at no
+ charge under subsection 6d.
+
+ A separable portion of the object code, whose source code is excluded
+from the Corresponding Source as a System Library, need not be
+included in conveying the object code work.
+
+ A "User Product" is either (1) a "consumer product", which means any
+tangible personal property which is normally used for personal, family,
+or household purposes, or (2) anything designed or sold for incorporation
+into a dwelling. In determining whether a product is a consumer product,
+doubtful cases shall be resolved in favor of coverage. For a particular
+product received by a particular user, "normally used" refers to a
+typical or common use of that class of product, regardless of the status
+of the particular user or of the way in which the particular user
+actually uses, or expects or is expected to use, the product. A product
+is a consumer product regardless of whether the product has substantial
+commercial, industrial or non-consumer uses, unless such uses represent
+the only significant mode of use of the product.
+
+ "Installation Information" for a User Product means any methods,
+procedures, authorization keys, or other information required to install
+and execute modified versions of a covered work in that User Product from
+a modified version of its Corresponding Source. The information must
+suffice to ensure that the continued functioning of the modified object
+code is in no case prevented or interfered with solely because
+modification has been made.
+
+ If you convey an object code work under this section in, or with, or
+specifically for use in, a User Product, and the conveying occurs as
+part of a transaction in which the right of possession and use of the
+User Product is transferred to the recipient in perpetuity or for a
+fixed term (regardless of how the transaction is characterized), the
+Corresponding Source conveyed under this section must be accompanied
+by the Installation Information. But this requirement does not apply
+if neither you nor any third party retains the ability to install
+modified object code on the User Product (for example, the work has
+been installed in ROM).
+
+ The requirement to provide Installation Information does not include a
+requirement to continue to provide support service, warranty, or updates
+for a work that has been modified or installed by the recipient, or for
+the User Product in which it has been modified or installed. Access to a
+network may be denied when the modification itself materially and
+adversely affects the operation of the network or violates the rules and
+protocols for communication across the network.
+
+ Corresponding Source conveyed, and Installation Information provided,
+in accord with this section must be in a format that is publicly
+documented (and with an implementation available to the public in
+source code form), and must require no special password or key for
+unpacking, reading or copying.
+
+ 7. Additional Terms.
+
+ "Additional permissions" are terms that supplement the terms of this
+License by making exceptions from one or more of its conditions.
+Additional permissions that are applicable to the entire Program shall
+be treated as though they were included in this License, to the extent
+that they are valid under applicable law. If additional permissions
+apply only to part of the Program, that part may be used separately
+under those permissions, but the entire Program remains governed by
+this License without regard to the additional permissions.
+
+ When you convey a copy of a covered work, you may at your option
+remove any additional permissions from that copy, or from any part of
+it. (Additional permissions may be written to require their own
+removal in certain cases when you modify the work.) You may place
+additional permissions on material, added by you to a covered work,
+for which you have or can give appropriate copyright permission.
+
+ Notwithstanding any other provision of this License, for material you
+add to a covered work, you may (if authorized by the copyright holders of
+that material) supplement the terms of this License with terms:
+
+ a) Disclaiming warranty or limiting liability differently from the
+ terms of sections 15 and 16 of this License; or
+
+ b) Requiring preservation of specified reasonable legal notices or
+ author attributions in that material or in the Appropriate Legal
+ Notices displayed by works containing it; or
+
+ c) Prohibiting misrepresentation of the origin of that material, or
+ requiring that modified versions of such material be marked in
+ reasonable ways as different from the original version; or
+
+ d) Limiting the use for publicity purposes of names of licensors or
+ authors of the material; or
+
+ e) Declining to grant rights under trademark law for use of some
+ trade names, trademarks, or service marks; or
+
+ f) Requiring indemnification of licensors and authors of that
+ material by anyone who conveys the material (or modified versions of
+ it) with contractual assumptions of liability to the recipient, for
+ any liability that these contractual assumptions directly impose on
+ those licensors and authors.
+
+ All other non-permissive additional terms are considered "further
+restrictions" within the meaning of section 10. If the Program as you
+received it, or any part of it, contains a notice stating that it is
+governed by this License along with a term that is a further
+restriction, you may remove that term. If a license document contains
+a further restriction but permits relicensing or conveying under this
+License, you may add to a covered work material governed by the terms
+of that license document, provided that the further restriction does
+not survive such relicensing or conveying.
+
+ If you add terms to a covered work in accord with this section, you
+must place, in the relevant source files, a statement of the
+additional terms that apply to those files, or a notice indicating
+where to find the applicable terms.
+
+ Additional terms, permissive or non-permissive, may be stated in the
+form of a separately written license, or stated as exceptions;
+the above requirements apply either way.
+
+ 8. Termination.
+
+ You may not propagate or modify a covered work except as expressly
+provided under this License. Any attempt otherwise to propagate or
+modify it is void, and will automatically terminate your rights under
+this License (including any patent licenses granted under the third
+paragraph of section 11).
+
+ However, if you cease all violation of this License, then your
+license from a particular copyright holder is reinstated (a)
+provisionally, unless and until the copyright holder explicitly and
+finally terminates your license, and (b) permanently, if the copyright
+holder fails to notify you of the violation by some reasonable means
+prior to 60 days after the cessation.
+
+ Moreover, your license from a particular copyright holder is
+reinstated permanently if the copyright holder notifies you of the
+violation by some reasonable means, this is the first time you have
+received notice of violation of this License (for any work) from that
+copyright holder, and you cure the violation prior to 30 days after
+your receipt of the notice.
+
+ Termination of your rights under this section does not terminate the
+licenses of parties who have received copies or rights from you under
+this License. If your rights have been terminated and not permanently
+reinstated, you do not qualify to receive new licenses for the same
+material under section 10.
+
+ 9. Acceptance Not Required for Having Copies.
+
+ You are not required to accept this License in order to receive or
+run a copy of the Program. Ancillary propagation of a covered work
+occurring solely as a consequence of using peer-to-peer transmission
+to receive a copy likewise does not require acceptance. However,
+nothing other than this License grants you permission to propagate or
+modify any covered work. These actions infringe copyright if you do
+not accept this License. Therefore, by modifying or propagating a
+covered work, you indicate your acceptance of this License to do so.
+
+ 10. Automatic Licensing of Downstream Recipients.
+
+ Each time you convey a covered work, the recipient automatically
+receives a license from the original licensors, to run, modify and
+propagate that work, subject to this License. You are not responsible
+for enforcing compliance by third parties with this License.
+
+ An "entity transaction" is a transaction transferring control of an
+organization, or substantially all assets of one, or subdividing an
+organization, or merging organizations. If propagation of a covered
+work results from an entity transaction, each party to that
+transaction who receives a copy of the work also receives whatever
+licenses to the work the party's predecessor in interest had or could
+give under the previous paragraph, plus a right to possession of the
+Corresponding Source of the work from the predecessor in interest, if
+the predecessor has it or can get it with reasonable efforts.
+
+ You may not impose any further restrictions on the exercise of the
+rights granted or affirmed under this License. For example, you may
+not impose a license fee, royalty, or other charge for exercise of
+rights granted under this License, and you may not initiate litigation
+(including a cross-claim or counterclaim in a lawsuit) alleging that
+any patent claim is infringed by making, using, selling, offering for
+sale, or importing the Program or any portion of it.
+
+ 11. Patents.
+
+ A "contributor" is a copyright holder who authorizes use under this
+License of the Program or a work on which the Program is based. The
+work thus licensed is called the contributor's "contributor version".
+
+ A contributor's "essential patent claims" are all patent claims
+owned or controlled by the contributor, whether already acquired or
+hereafter acquired, that would be infringed by some manner, permitted
+by this License, of making, using, or selling its contributor version,
+but do not include claims that would be infringed only as a
+consequence of further modification of the contributor version. For
+purposes of this definition, "control" includes the right to grant
+patent sublicenses in a manner consistent with the requirements of
+this License.
+
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
+patent license under the contributor's essential patent claims, to
+make, use, sell, offer for sale, import and otherwise run, modify and
+propagate the contents of its contributor version.
+
+ In the following three paragraphs, a "patent license" is any express
+agreement or commitment, however denominated, not to enforce a patent
+(such as an express permission to practice a patent or covenant not to
+sue for patent infringement). To "grant" such a patent license to a
+party means to make such an agreement or commitment not to enforce a
+patent against the party.
+
+ If you convey a covered work, knowingly relying on a patent license,
+and the Corresponding Source of the work is not available for anyone
+to copy, free of charge and under the terms of this License, through a
+publicly available network server or other readily accessible means,
+then you must either (1) cause the Corresponding Source to be so
+available, or (2) arrange to deprive yourself of the benefit of the
+patent license for this particular work, or (3) arrange, in a manner
+consistent with the requirements of this License, to extend the patent
+license to downstream recipients. "Knowingly relying" means you have
+actual knowledge that, but for the patent license, your conveying the
+covered work in a country, or your recipient's use of the covered work
+in a country, would infringe one or more identifiable patents in that
+country that you have reason to believe are valid.
+
+ If, pursuant to or in connection with a single transaction or
+arrangement, you convey, or propagate by procuring conveyance of, a
+covered work, and grant a patent license to some of the parties
+receiving the covered work authorizing them to use, propagate, modify
+or convey a specific copy of the covered work, then the patent license
+you grant is automatically extended to all recipients of the covered
+work and works based on it.
+
+ A patent license is "discriminatory" if it does not include within
+the scope of its coverage, prohibits the exercise of, or is
+conditioned on the non-exercise of one or more of the rights that are
+specifically granted under this License. You may not convey a covered
+work if you are a party to an arrangement with a third party that is
+in the business of distributing software, under which you make payment
+to the third party based on the extent of your activity of conveying
+the work, and under which the third party grants, to any of the
+parties who would receive the covered work from you, a discriminatory
+patent license (a) in connection with copies of the covered work
+conveyed by you (or copies made from those copies), or (b) primarily
+for and in connection with specific products or compilations that
+contain the covered work, unless you entered into that arrangement,
+or that patent license was granted, prior to 28 March 2007.
+
+ Nothing in this License shall be construed as excluding or limiting
+any implied license or other defenses to infringement that may
+otherwise be available to you under applicable patent law.
+
+ 12. No Surrender of Others' Freedom.
+
+ If conditions are imposed on you (whether by court order, agreement or
+otherwise) that contradict the conditions of this License, they do not
+excuse you from the conditions of this License. If you cannot convey a
+covered work so as to satisfy simultaneously your obligations under this
+License and any other pertinent obligations, then as a consequence you may
+not convey it at all. For example, if you agree to terms that obligate you
+to collect a royalty for further conveying from those to whom you convey
+the Program, the only way you could satisfy both those terms and this
+License would be to refrain entirely from conveying the Program.
+
+ 13. Remote Network Interaction; Use with the GNU General Public License.
+
+ Notwithstanding any other provision of this License, if you modify the
+Program, your modified version must prominently offer all users
+interacting with it remotely through a computer network (if your version
+supports such interaction) an opportunity to receive the Corresponding
+Source of your version by providing access to the Corresponding Source
+from a network server at no charge, through some standard or customary
+means of facilitating copying of software. This Corresponding Source
+shall include the Corresponding Source for any work covered by version 3
+of the GNU General Public License that is incorporated pursuant to the
+following paragraph.
+
+ Notwithstanding any other provision of this License, you have
+permission to link or combine any covered work with a work licensed
+under version 3 of the GNU General Public License into a single
+combined work, and to convey the resulting work. The terms of this
+License will continue to apply to the part which is the covered work,
+but the work with which it is combined will remain governed by version
+3 of the GNU General Public License.
+
+ 14. Revised Versions of this License.
+
+ The Free Software Foundation may publish revised and/or new versions of
+the GNU Affero General Public License from time to time. Such new versions
+will be similar in spirit to the present version, but may differ in detail to
+address new problems or concerns.
+
+ Each version is given a distinguishing version number. If the
+Program specifies that a certain numbered version of the GNU Affero General
+Public License "or any later version" applies to it, you have the
+option of following the terms and conditions either of that numbered
+version or of any later version published by the Free Software
+Foundation. If the Program does not specify a version number of the
+GNU Affero General Public License, you may choose any version ever published
+by the Free Software Foundation.
+
+ If the Program specifies that a proxy can decide which future
+versions of the GNU Affero General Public License can be used, that proxy's
+public statement of acceptance of a version permanently authorizes you
+to choose that version for the Program.
+
+ Later license versions may give you additional or different
+permissions. However, no additional obligations are imposed on any
+author or copyright holder as a result of your choosing to follow a
+later version.
+
+ 15. Disclaimer of Warranty.
+
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
+APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
+HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
+OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
+THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
+IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
+ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
+
+ 16. Limitation of Liability.
+
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
+WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
+THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
+GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
+USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
+DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
+PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
+EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
+SUCH DAMAGES.
+
+ 17. Interpretation of Sections 15 and 16.
+
+ If the disclaimer of warranty and limitation of liability provided
+above cannot be given local legal effect according to their terms,
+reviewing courts shall apply local law that most closely approximates
+an absolute waiver of all civil liability in connection with the
+Program, unless a warranty or assumption of liability accompanies a
+copy of the Program in return for a fee.
+
+ END OF TERMS AND CONDITIONS
+
+ How to Apply These Terms to Your New Programs
+
+ If you develop a new program, and you want it to be of the greatest
+possible use to the public, the best way to achieve this is to make it
+free software which everyone can redistribute and change under these terms.
+
+ To do so, attach the following notices to the program. It is safest
+to attach them to the start of each source file to most effectively
+state the exclusion of warranty; and each file should have at least
+the "copyright" line and a pointer to where the full notice is found.
+
+ <one line to give the program's name and a brief idea of what it does.>
+ Copyright (C) <year> <name of author>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published
+ by the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+Also add information on how to contact you by electronic and paper mail.
+
+ If your software can interact with users remotely through a computer
+network, you should also make sure that it provides a way for users to
+get its source. For example, if your program is a web application, its
+interface could display a "Source" link that leads users to an archive
+of the code. There are many ways you could offer source, and different
+solutions will be better for different programs; see section 13 for the
+specific requirements.
+
+ You should also get your employer (if you work as a programmer) or school,
+if any, to sign a "copyright disclaimer" for the program, if necessary.
+For more information on this, and how to apply and follow the GNU AGPL, see
+<https://www.gnu.org/licenses/>.
--- /dev/null
+GO ?= go
+RM ?= rm
+GOFLAGS ?= -v -ldflags "-w -X `go list`.Version=${VERSION} -X `go list`.Commit=${COMMIT} -X `go list`.Build=${BUILD}" -mod=vendor
+PREFIX ?= /usr/local
+BINDIR ?= bin
+MANDIR ?= share/man
+MKDIR ?= mkdir
+CP ?= cp
+SYSCONFDIR ?= /etc
+ASCIIDOCTOR ?= asciidoctor
+
+VERSION = `git describe --abbrev=0 --tags 2>/dev/null || echo "$VERSION"`
+COMMIT = `git rev-parse --short HEAD || echo "$COMMIT"`
+BRANCH = `git rev-parse --abbrev-ref HEAD`
+BUILD = `git show -s --pretty=format:%cI`
+
+GOARCH ?= amd64
+GOOS ?= linux
+
+all: build
+
+build: vendor
+ ${GO} build ${GOFLAGS} ./cmd/suika
+ ${GO} build ${GOFLAGS} ./cmd/suikadb
+ ${GO} build ${GOFLAGS} ./cmd/suika-znc-import
+clean:
+ ${RM} -f suika suikadb suika-znc-import
+install:
+ ${MKDIR} -p ${DESTDIR}${PREFIX}/${BINDIR}
+ ${MKDIR} -p ${DESTDIR}${PREFIX}/${MANDIR}/man1
+ ${MKDIR} -p ${DESTDIR}${PREFIX}/${MANDIR}/man5
+ ${MKDIR} -p ${DESTDIR}${PREFIX}/${MANDIR}/man7
+ ${MKDIR} -p ${DESTDIR}${SYSCONFDIR}/suika
+ ${MKDIR} -p ${DESTDIR}/var/lib/suika
+ ${CP} -f suika suikadb suika-znc-import ${DESTDIR}${PREFIX}/${BINDIR}
+ ${CP} -f doc/suika.1 ${DESTDIR}${PREFIX}/${MANDIR}/man1
+ ${CP} -f doc/suikadb.1 ${DESTDIR}${PREFIX}/${MANDIR}/man1
+ ${CP} -f doc/suika-znc-import.1 ${DESTDIR}/${MANDIR}/man1
+ ${CP} -f doc/suika-config.5 ${DESTDIR}${PREFIX}/${MANDIR}/man5
+ [ -f ${DESTDIR}${SYSCONFDIR}/suika/config ] || ${CP} -f config.in ${DESTDIR}${SYSCONFDIR}/suika/config
+test:
+ go test
+vendor:
+ go mod vendor
+.PHONY: build clean install
--- /dev/null
+# suika
+
+[![Go Documentation](https://godocs.io/marisa.chaotic.ninja/suika?status.svg)](https://godocs.io/marisa.chaotic.ninja/suika)
+
+A user-friendly IRC bouncer. Hard-fork of the 0.3 series of [soju](https://soju.im), named after [Suika Ibuki](https://en.touhouwiki.net/wiki/Suika_Ibuki) from [Touhou 7.5: Immaterial and Missing Power](https://en.touhouwiki.net/wiki/Immaterial_and_Missing_Power)
+
+- Multi-user
+- Support multiple clients for a single user, with proper backlog
+ synchronization
+- Support connecting to multiple upstream servers via a single IRC connection
+ to the bouncer
+
+## Building and installing
+
+Dependencies:
+
+- Go
+- BSD or GNU make
+
+For end users, a `Makefile` is provided:
+
+ make
+ doas make install
+
+For development, you can use `go run ./cmd/suika` as usual.
+
+## License
+AGPLv3, see [LICENSE](LICENSE).
+
+* Copyright (C) 2020 The soju Contributors
+* Copyright (C) 2023-present Izuru Yakumo
+
+The code for `version.go` is stolen verbatim from one of [@prologic](https://git.mills.io/prologic)'s projects. It's probably under MIT
--- /dev/null
+package suika
+
+import (
+ "context"
+ "fmt"
+ "strconv"
+ "strings"
+
+ "gopkg.in/irc.v3"
+)
+
+func forwardChannel(ctx context.Context, dc *downstreamConn, ch *upstreamChannel) {
+ if !ch.complete {
+ panic("Tried to forward a partial channel")
+ }
+
+ // RPL_NOTOPIC shouldn't be sent on JOIN
+ if ch.Topic != "" {
+ sendTopic(dc, ch)
+ }
+
+ if dc.caps["soju.im/read"] {
+ channelCM := ch.conn.network.casemap(ch.Name)
+ r, err := dc.srv.db.GetReadReceipt(ctx, ch.conn.network.ID, channelCM)
+ if err != nil {
+ dc.logger.Printf("failed to get the read receipt for %q: %v", ch.Name, err)
+ } else {
+ timestampStr := "*"
+ if r != nil {
+ timestampStr = fmt.Sprintf("timestamp=%s", formatServerTime(r.Timestamp))
+ }
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.prefix(),
+ Command: "READ",
+ Params: []string{dc.marshalEntity(ch.conn.network, ch.Name), timestampStr},
+ })
+ }
+ }
+
+ sendNames(dc, ch)
+}
+
+func sendTopic(dc *downstreamConn, ch *upstreamChannel) {
+ downstreamName := dc.marshalEntity(ch.conn.network, ch.Name)
+
+ if ch.Topic != "" {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_TOPIC,
+ Params: []string{dc.nick, downstreamName, ch.Topic},
+ })
+ if ch.TopicWho != nil {
+ topicWho := dc.marshalUserPrefix(ch.conn.network, ch.TopicWho)
+ topicTime := strconv.FormatInt(ch.TopicTime.Unix(), 10)
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: rpl_topicwhotime,
+ Params: []string{dc.nick, downstreamName, topicWho.String(), topicTime},
+ })
+ }
+ } else {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_NOTOPIC,
+ Params: []string{dc.nick, downstreamName, "No topic is set"},
+ })
+ }
+}
+
+func sendNames(dc *downstreamConn, ch *upstreamChannel) {
+ downstreamName := dc.marshalEntity(ch.conn.network, ch.Name)
+
+ emptyNameReply := &irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_NAMREPLY,
+ Params: []string{dc.nick, string(ch.Status), downstreamName, ""},
+ }
+ maxLength := maxMessageLength - len(emptyNameReply.String())
+
+ var buf strings.Builder
+ for _, entry := range ch.Members.innerMap {
+ nick := entry.originalKey
+ memberships := entry.value.(*memberships)
+ s := memberships.Format(dc) + dc.marshalEntity(ch.conn.network, nick)
+
+ n := buf.Len() + 1 + len(s)
+ if buf.Len() != 0 && n > maxLength {
+ // There's not enough space for the next space + nick.
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_NAMREPLY,
+ Params: []string{dc.nick, string(ch.Status), downstreamName, buf.String()},
+ })
+ buf.Reset()
+ }
+
+ if buf.Len() != 0 {
+ buf.WriteByte(' ')
+ }
+ buf.WriteString(s)
+ }
+
+ if buf.Len() != 0 {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_NAMREPLY,
+ Params: []string{dc.nick, string(ch.Status), downstreamName, buf.String()},
+ })
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_ENDOFNAMES,
+ Params: []string{dc.nick, downstreamName, "End of /NAMES list"},
+ })
+}
--- /dev/null
+package suika
+
+import (
+ "crypto"
+ "crypto/ecdsa"
+ "crypto/ed25519"
+ "crypto/elliptic"
+ "crypto/rand"
+ "crypto/rsa"
+ "crypto/x509"
+ "crypto/x509/pkix"
+ "math/big"
+ "time"
+)
+
+func generateCertFP(keyType string, bits int) (privKeyBytes, certBytes []byte, err error) {
+ var (
+ privKey crypto.PrivateKey
+ pubKey crypto.PublicKey
+ )
+ switch keyType {
+ case "rsa":
+ key, err := rsa.GenerateKey(rand.Reader, bits)
+ if err != nil {
+ return nil, nil, err
+ }
+ privKey = key
+ pubKey = key.Public()
+ case "ecdsa":
+ key, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
+ if err != nil {
+ return nil, nil, err
+ }
+ privKey = key
+ pubKey = key.Public()
+ case "ed25519":
+ var err error
+ pubKey, privKey, err = ed25519.GenerateKey(rand.Reader)
+ if err != nil {
+ return nil, nil, err
+ }
+ }
+
+ // Using PKCS#8 allows easier extension for new key types.
+ privKeyBytes, err = x509.MarshalPKCS8PrivateKey(privKey)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ notBefore := time.Now()
+ // Lets make a fair assumption nobody will use the same cert for more than 20 years...
+ notAfter := notBefore.Add(24 * time.Hour * 365 * 20)
+ serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
+ serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
+ if err != nil {
+ return nil, nil, err
+ }
+ cert := &x509.Certificate{
+ SerialNumber: serialNumber,
+ Subject: pkix.Name{CommonName: "suika auto-generated certificate"},
+ NotBefore: notBefore,
+ NotAfter: notAfter,
+ KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
+ ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
+ }
+ certBytes, err = x509.CreateCertificate(rand.Reader, cert, cert, pubKey, privKey)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ return privKeyBytes, certBytes, nil
+}
--- /dev/null
+package main
+
+import (
+ "bufio"
+ "context"
+ "flag"
+ "fmt"
+ "io"
+ "log"
+ "net/url"
+ "os"
+ "strings"
+ "unicode"
+
+ "marisa.chaotic.ninja/suika"
+ "marisa.chaotic.ninja/suika/config"
+)
+
+const usage = `usage: suika-znc-import [options...] <znc config path>
+
+Imports configuration from a ZNC file. Users and networks are merged if they
+already exist in the suika database. ZNC settings overwrite existing suika
+settings.
+
+Options:
+
+ -help Show this help message
+ -config <path> Path to suika config file
+ -user <username> Limit import to username (may be specified multiple times)
+ -network <name> Limit import to network (may be specified multiple times)
+`
+
+func init() {
+ flag.Usage = func() {
+ fmt.Fprintf(flag.CommandLine.Output(), usage)
+ }
+}
+
+func main() {
+ var configPath string
+ users := make(map[string]bool)
+ networks := make(map[string]bool)
+ flag.StringVar(&configPath, "config", "", "path to configuration file")
+ flag.Var((*stringSetFlag)(&users), "user", "")
+ flag.Var((*stringSetFlag)(&networks), "network", "")
+ flag.Parse()
+
+ zncPath := flag.Arg(0)
+ if zncPath == "" {
+ flag.Usage()
+ os.Exit(1)
+ }
+
+ var cfg *config.Server
+ if configPath != "" {
+ var err error
+ cfg, err = config.Load(configPath)
+ if err != nil {
+ log.Fatalf("failed to load config file: %v", err)
+ }
+ } else {
+ cfg = config.Defaults()
+ }
+
+ ctx := context.Background()
+
+ db, err := suika.OpenDB(cfg.SQLDriver, cfg.SQLSource)
+ if err != nil {
+ log.Fatalf("failed to open database: %v", err)
+ }
+ defer db.Close()
+
+ f, err := os.Open(zncPath)
+ if err != nil {
+ log.Fatalf("failed to open ZNC configuration file: %v", err)
+ }
+ defer f.Close()
+
+ zp := zncParser{bufio.NewReader(f), 1}
+ root, err := zp.sectionBody("", "")
+ if err != nil {
+ log.Fatalf("failed to parse %q: line %v: %v", zncPath, zp.line, err)
+ }
+
+ l, err := db.ListUsers(ctx)
+ if err != nil {
+ log.Fatalf("failed to list users in DB: %v", err)
+ }
+ existingUsers := make(map[string]*suika.User, len(l))
+ for i, u := range l {
+ existingUsers[u.Username] = &l[i]
+ }
+
+ usersCreated := 0
+ usersImported := 0
+ networksImported := 0
+ channelsImported := 0
+ root.ForEach("User", func(section *zncSection) {
+ username := section.Name
+ if len(users) > 0 && !users[username] {
+ return
+ }
+ usersImported++
+
+ u, ok := existingUsers[username]
+ if ok {
+ log.Printf("user %q: updating existing user", username)
+ } else {
+ // "!!" is an invalid crypt format, thus disables password auth
+ u = &suika.User{Username: username, Password: "!!"}
+ usersCreated++
+ log.Printf("user %q: creating new user", username)
+ }
+
+ u.Admin = section.Values.Get("Admin") == "true"
+
+ if err := db.StoreUser(ctx, u); err != nil {
+ log.Fatalf("failed to store user %q: %v", username, err)
+ }
+ userID := u.ID
+
+ l, err := db.ListNetworks(ctx, userID)
+ if err != nil {
+ log.Fatalf("failed to list networks for user %q: %v", username, err)
+ }
+ existingNetworks := make(map[string]*suika.Network, len(l))
+ for i, n := range l {
+ existingNetworks[n.GetName()] = &l[i]
+ }
+
+ nick := section.Values.Get("Nick")
+ realname := section.Values.Get("RealName")
+ ident := section.Values.Get("Ident")
+
+ section.ForEach("Network", func(section *zncSection) {
+ netName := section.Name
+ if len(networks) > 0 && !networks[netName] {
+ return
+ }
+ networksImported++
+
+ logPrefix := fmt.Sprintf("user %q: network %q: ", username, netName)
+ logger := log.New(os.Stderr, logPrefix, log.LstdFlags|log.Lmsgprefix)
+
+ netNick := section.Values.Get("Nick")
+ if netNick == "" {
+ netNick = nick
+ }
+ netRealname := section.Values.Get("RealName")
+ if netRealname == "" {
+ netRealname = realname
+ }
+ netIdent := section.Values.Get("Ident")
+ if netIdent == "" {
+ netIdent = ident
+ }
+
+ for _, name := range section.Values["LoadModule"] {
+ switch name {
+ case "sasl":
+ logger.Printf("warning: SASL credentials not imported")
+ case "nickserv":
+ logger.Printf("warning: NickServ credentials not imported")
+ case "perform":
+ logger.Printf("warning: \"perform\" plugin commands not imported")
+ }
+ }
+
+ u, pass, err := importNetworkServer(section.Values.Get("Server"))
+ if err != nil {
+ logger.Fatalf("failed to import server %q: %v", section.Values.Get("Server"), err)
+ }
+
+ n, ok := existingNetworks[netName]
+ if ok {
+ logger.Printf("updating existing network")
+ } else {
+ n = &suika.Network{Name: netName}
+ logger.Printf("creating new network")
+ }
+
+ n.Addr = u.String()
+ n.Nick = netNick
+ n.Username = netIdent
+ n.Realname = netRealname
+ n.Pass = pass
+ n.Enabled = section.Values.Get("IRCConnectEnabled") != "false"
+
+ if err := db.StoreNetwork(ctx, userID, n); err != nil {
+ logger.Fatalf("failed to store network: %v", err)
+ }
+
+ l, err := db.ListChannels(ctx, n.ID)
+ if err != nil {
+ logger.Fatalf("failed to list channels: %v", err)
+ }
+ existingChannels := make(map[string]*suika.Channel, len(l))
+ for i, ch := range l {
+ existingChannels[ch.Name] = &l[i]
+ }
+
+ section.ForEach("Chan", func(section *zncSection) {
+ chName := section.Name
+
+ if section.Values.Get("Disabled") == "true" {
+ logger.Printf("skipping import of disabled channel %q", chName)
+ return
+ }
+
+ channelsImported++
+
+ ch, ok := existingChannels[chName]
+ if ok {
+ logger.Printf("channel %q: updating existing channel", chName)
+ } else {
+ ch = &suika.Channel{Name: chName}
+ logger.Printf("channel %q: creating new channel", chName)
+ }
+
+ ch.Key = section.Values.Get("Key")
+ ch.Detached = section.Values.Get("Detached") == "true"
+
+ if err := db.StoreChannel(ctx, n.ID, ch); err != nil {
+ logger.Printf("channel %q: failed to store channel: %v", chName, err)
+ }
+ })
+ })
+ })
+
+ if err := db.Close(); err != nil {
+ log.Printf("failed to close database: %v", err)
+ }
+
+ if usersCreated > 0 {
+ log.Printf("warning: user passwords haven't been imported, please set them with `suikactl change-password <username>`")
+ }
+
+ log.Printf("imported %v users, %v networks and %v channels", usersImported, networksImported, channelsImported)
+}
+
+func importNetworkServer(s string) (u *url.URL, pass string, err error) {
+ parts := strings.Fields(s)
+ if len(parts) < 2 {
+ return nil, "", fmt.Errorf("expected space-separated host and port")
+ }
+
+ scheme := "irc"
+ host := parts[0]
+ port := parts[1]
+ if strings.HasPrefix(port, "+") {
+ port = port[1:]
+ scheme = "ircs"
+ }
+
+ if len(parts) > 2 {
+ pass = parts[2]
+ }
+
+ u = &url.URL{
+ Scheme: scheme,
+ Host: host + ":" + port,
+ }
+ return u, pass, nil
+}
+
+type zncSection struct {
+ Type string
+ Name string
+ Values zncValues
+ Children []zncSection
+}
+
+func (s *zncSection) ForEach(typ string, f func(*zncSection)) {
+ for _, section := range s.Children {
+ if section.Type == typ {
+ f(§ion)
+ }
+ }
+}
+
+type zncValues map[string][]string
+
+func (zv zncValues) Get(k string) string {
+ if len(zv[k]) == 0 {
+ return ""
+ }
+ return zv[k][0]
+}
+
+type zncParser struct {
+ br *bufio.Reader
+ line int
+}
+
+func (zp *zncParser) readByte() (byte, error) {
+ b, err := zp.br.ReadByte()
+ if b == '\n' {
+ zp.line++
+ }
+ return b, err
+}
+
+func (zp *zncParser) readRune() (rune, int, error) {
+ r, n, err := zp.br.ReadRune()
+ if r == '\n' {
+ zp.line++
+ }
+ return r, n, err
+}
+
+func (zp *zncParser) sectionBody(typ, name string) (*zncSection, error) {
+ section := &zncSection{Type: typ, Name: name, Values: make(zncValues)}
+
+Loop:
+ for {
+ if err := zp.skipSpace(); err != nil {
+ return nil, err
+ }
+
+ b, err := zp.br.Peek(2)
+ if err == io.EOF {
+ break
+ } else if err != nil {
+ return nil, err
+ }
+
+ switch b[0] {
+ case '<':
+ if b[1] == '/' {
+ break Loop
+ } else {
+ childType, childName, err := zp.sectionHeader()
+ if err != nil {
+ return nil, err
+ }
+ child, err := zp.sectionBody(childType, childName)
+ if err != nil {
+ return nil, err
+ }
+ if footerType, err := zp.sectionFooter(); err != nil {
+ return nil, err
+ } else if footerType != childType {
+ return nil, fmt.Errorf("invalid section footer: expected type %q, got %q", childType, footerType)
+ }
+ section.Children = append(section.Children, *child)
+ }
+ case '/':
+ if b[1] == '/' {
+ if err := zp.skipComment(); err != nil {
+ return nil, err
+ }
+ break
+ }
+ fallthrough
+ default:
+ k, v, err := zp.keyValuePair()
+ if err != nil {
+ return nil, err
+ }
+ section.Values[k] = append(section.Values[k], v)
+ }
+ }
+
+ return section, nil
+}
+
+func (zp *zncParser) skipSpace() error {
+ for {
+ r, _, err := zp.readRune()
+ if err == io.EOF {
+ return nil
+ } else if err != nil {
+ return err
+ }
+
+ if !unicode.IsSpace(r) {
+ zp.br.UnreadRune()
+ return nil
+ }
+ }
+}
+
+func (zp *zncParser) skipComment() error {
+ if err := zp.expectRune('/'); err != nil {
+ return err
+ }
+ if err := zp.expectRune('/'); err != nil {
+ return err
+ }
+
+ for {
+ b, err := zp.readByte()
+ if err == io.EOF {
+ return nil
+ } else if err != nil {
+ return err
+ }
+
+ if b == '\n' {
+ return nil
+ }
+ }
+}
+
+func (zp *zncParser) sectionHeader() (string, string, error) {
+ if err := zp.expectRune('<'); err != nil {
+ return "", "", err
+ }
+ typ, err := zp.readWord(' ')
+ if err != nil {
+ return "", "", err
+ }
+ name, err := zp.readWord('>')
+ return typ, name, err
+}
+
+func (zp *zncParser) sectionFooter() (string, error) {
+ if err := zp.expectRune('<'); err != nil {
+ return "", err
+ }
+ if err := zp.expectRune('/'); err != nil {
+ return "", err
+ }
+ return zp.readWord('>')
+}
+
+func (zp *zncParser) keyValuePair() (string, string, error) {
+ k, err := zp.readWord('=')
+ if err != nil {
+ return "", "", err
+ }
+ v, err := zp.readWord('\n')
+ return strings.TrimSpace(k), strings.TrimSpace(v), err
+}
+
+func (zp *zncParser) expectRune(expected rune) error {
+ r, _, err := zp.readRune()
+ if err != nil {
+ return err
+ } else if r != expected {
+ return fmt.Errorf("expected %q, got %q", expected, r)
+ }
+ return nil
+}
+
+func (zp *zncParser) readWord(delim byte) (string, error) {
+ var sb strings.Builder
+ for {
+ b, err := zp.readByte()
+ if err != nil {
+ return "", err
+ }
+
+ if b == delim {
+ return sb.String(), nil
+ }
+ if b == '\n' {
+ return "", fmt.Errorf("expected %q before newline", delim)
+ }
+
+ sb.WriteByte(b)
+ }
+}
+
+type stringSetFlag map[string]bool
+
+func (v *stringSetFlag) String() string {
+ return fmt.Sprint(map[string]bool(*v))
+}
+
+func (v *stringSetFlag) Set(s string) error {
+ (*v)[s] = true
+ return nil
+}
--- /dev/null
+package main
+
+import (
+ "context"
+ "crypto/tls"
+ "flag"
+ "fmt"
+ "log"
+ "net"
+ "net/url"
+ "os"
+ "os/signal"
+ "strings"
+ "sync/atomic"
+ "syscall"
+ "time"
+
+ "marisa.chaotic.ninja/suika"
+ "marisa.chaotic.ninja/suika/config"
+)
+
+// TCP keep-alive interval for downstream TCP connections
+const downstreamKeepAlive = 1 * time.Hour
+
+type stringSliceFlag []string
+
+func (v *stringSliceFlag) String() string {
+ return fmt.Sprint([]string(*v))
+}
+
+func (v *stringSliceFlag) Set(s string) error {
+ *v = append(*v, s)
+ return nil
+}
+
+func bumpOpenedFileLimit() error {
+ var rlimit syscall.Rlimit
+ if err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rlimit); err != nil {
+ return fmt.Errorf("failed to get RLIMIT_NOFILE: %v", err)
+ }
+ rlimit.Cur = rlimit.Max
+ if err := syscall.Setrlimit(syscall.RLIMIT_NOFILE, &rlimit); err != nil {
+ return fmt.Errorf("failed to set RLIMIT_NOFILE: %v", err)
+ }
+ return nil
+}
+
+var (
+ configPath string
+ debug bool
+
+ tlsCert atomic.Value // *tls.Certificate
+)
+
+func loadConfig() (*config.Server, *suika.Config, error) {
+ var raw *config.Server
+ if configPath != "" {
+ var err error
+ raw, err = config.Load(configPath)
+ if err != nil {
+ return nil, nil, fmt.Errorf("failed to load config file: %v", err)
+ }
+ } else {
+ raw = config.Defaults()
+ }
+
+ var motd string
+ if raw.MOTDPath != "" {
+ b, err := os.ReadFile(raw.MOTDPath)
+ if err != nil {
+ return nil, nil, fmt.Errorf("failed to load MOTD: %v", err)
+ }
+ motd = strings.TrimSuffix(string(b), "\n")
+ }
+
+ if raw.TLS != nil {
+ cert, err := tls.LoadX509KeyPair(raw.TLS.CertPath, raw.TLS.KeyPath)
+ if err != nil {
+ return nil, nil, fmt.Errorf("failed to load TLS certificate and key: %v", err)
+ }
+ tlsCert.Store(&cert)
+ }
+
+ cfg := &suika.Config{
+ Hostname: raw.Hostname,
+ Title: raw.Title,
+ LogPath: raw.LogPath,
+ MaxUserNetworks: raw.MaxUserNetworks,
+ MultiUpstream: raw.MultiUpstream,
+ UpstreamUserIPs: raw.UpstreamUserIPs,
+ MOTD: motd,
+ }
+ return raw, cfg, nil
+}
+
+func main() {
+ var listen []string
+ flag.Var((*stringSliceFlag)(&listen), "listen", "listening address")
+ flag.StringVar(&configPath, "config", "", "path to configuration file")
+ flag.BoolVar(&debug, "debug", false, "enable debug logging")
+ flag.Parse()
+
+ cfg, serverCfg, err := loadConfig()
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ cfg.Listen = append(cfg.Listen, listen...)
+ if len(cfg.Listen) == 0 {
+ cfg.Listen = []string{":6667"}
+ }
+
+ if err := bumpOpenedFileLimit(); err != nil {
+ log.Printf("failed to bump max number of opened files: %v", err)
+ }
+
+ db, err := suika.OpenDB(cfg.SQLDriver, cfg.SQLSource)
+ if err != nil {
+ log.Fatalf("failed to open database: %v", err)
+ }
+
+ var tlsCfg *tls.Config
+ if cfg.TLS != nil {
+ tlsCfg = &tls.Config{
+ GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
+ return tlsCert.Load().(*tls.Certificate), nil
+ },
+ }
+ }
+
+ srv := suika.NewServer(db)
+ srv.SetConfig(serverCfg)
+ srv.Logger = suika.NewLogger(log.Writer(), debug)
+
+ for _, listen := range cfg.Listen {
+ listen := listen // copy
+ listenURI := listen
+ if !strings.Contains(listenURI, ":/") {
+ // This is a raw domain name, make it an URL with an empty scheme
+ listenURI = "//" + listenURI
+ }
+ u, err := url.Parse(listenURI)
+ if err != nil {
+ log.Fatalf("failed to parse listen URI %q: %v", listen, err)
+ }
+
+ switch u.Scheme {
+ case "ircs", "":
+ if tlsCfg == nil {
+ log.Fatalf("failed to listen on %q: missing TLS configuration", listen)
+ }
+ host := u.Host
+ if _, _, err := net.SplitHostPort(host); err != nil {
+ host = host + ":6697"
+ }
+ ircsTLSCfg := tlsCfg.Clone()
+ ircsTLSCfg.NextProtos = []string{"irc"}
+ lc := net.ListenConfig{
+ KeepAlive: downstreamKeepAlive,
+ }
+ l, err := lc.Listen(context.Background(), "tcp", host)
+ if err != nil {
+ log.Fatalf("failed to start TLS listener on %q: %v", listen, err)
+ }
+ ln := tls.NewListener(l, ircsTLSCfg)
+ go func() {
+ if err := srv.Serve(ln); err != nil {
+ log.Printf("serving %q: %v", listen, err)
+ }
+ }()
+ case "irc":
+ host := u.Host
+ if _, _, err := net.SplitHostPort(host); err != nil {
+ host = host + ":6667"
+ }
+ lc := net.ListenConfig{
+ KeepAlive: downstreamKeepAlive,
+ }
+ ln, err := lc.Listen(context.Background(), "tcp", host)
+ if err != nil {
+ log.Fatalf("failed to start listener on %q: %v", listen, err)
+ }
+ go func() {
+ if err := srv.Serve(ln); err != nil {
+ log.Printf("serving %q: %v", listen, err)
+ }
+ }()
+ case "unix":
+ ln, err := net.Listen("unix", u.Path)
+ if err != nil {
+ log.Fatalf("failed to start listener on %q: %v", listen, err)
+ }
+ go func() {
+ if err := srv.Serve(ln); err != nil {
+ log.Printf("serving %q: %v", listen, err)
+ }
+ }()
+ default:
+ log.Fatalf("failed to listen on %q: unsupported scheme", listen)
+ }
+
+ log.Printf("starting suika version %v\n", suika.FullVersion())
+ log.Printf("server listening on %q", listen)
+ }
+
+ sigCh := make(chan os.Signal, 1)
+ signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)
+
+ if err := srv.Start(); err != nil {
+ log.Fatal(err)
+ }
+
+ for sig := range sigCh {
+ switch sig {
+ case syscall.SIGHUP:
+ log.Print("reloading configuration")
+ _, serverCfg, err := loadConfig()
+ if err != nil {
+ log.Printf("failed to reloading configuration: %v", err)
+ } else {
+ srv.SetConfig(serverCfg)
+ }
+ case syscall.SIGINT, syscall.SIGTERM:
+ log.Print("shutting down server")
+ srv.Shutdown()
+ return
+ }
+ }
+}
--- /dev/null
+package main
+
+import (
+ "bufio"
+ "context"
+ "flag"
+ "fmt"
+ "io"
+ "log"
+ "os"
+
+ "marisa.chaotic.ninja/suika"
+ "marisa.chaotic.ninja/suika/config"
+ "golang.org/x/crypto/bcrypt"
+ "golang.org/x/term"
+)
+
+const usage = `usage: suikadb [-config path] <action> [options...]
+
+ create-user <username> [-admin] Create a new user
+ change-password <username> Change password for a user
+ help Show this help message
+`
+
+func init() {
+ flag.Usage = func() {
+ fmt.Fprintf(flag.CommandLine.Output(), usage)
+ }
+}
+
+func main() {
+ var configPath string
+ flag.StringVar(&configPath, "config", "", "path to configuration file")
+ flag.Parse()
+
+ var cfg *config.Server
+ if configPath != "" {
+ var err error
+ cfg, err = config.Load(configPath)
+ if err != nil {
+ log.Fatalf("failed to load config file: %v", err)
+ }
+ } else {
+ cfg = config.Defaults()
+ }
+
+ db, err := suika.OpenDB(cfg.SQLDriver, cfg.SQLSource)
+ if err != nil {
+ log.Fatalf("failed to open database: %v", err)
+ }
+
+ ctx := context.Background()
+
+ switch cmd := flag.Arg(0); cmd {
+ case "create-user":
+ username := flag.Arg(1)
+ if username == "" {
+ flag.Usage()
+ os.Exit(1)
+ }
+
+ fs := flag.NewFlagSet("", flag.ExitOnError)
+ admin := fs.Bool("admin", false, "make the new user admin")
+ fs.Parse(flag.Args()[2:])
+
+ password, err := readPassword()
+ if err != nil {
+ log.Fatalf("failed to read password: %v", err)
+ }
+
+ hashed, err := bcrypt.GenerateFromPassword(password, bcrypt.DefaultCost)
+ if err != nil {
+ log.Fatalf("failed to hash password: %v", err)
+ }
+
+ user := suika.User{
+ Username: username,
+ Password: string(hashed),
+ Admin: *admin,
+ }
+ if err := db.StoreUser(ctx, &user); err != nil {
+ log.Fatalf("failed to create user: %v", err)
+ }
+ case "change-password":
+ username := flag.Arg(1)
+ if username == "" {
+ flag.Usage()
+ os.Exit(1)
+ }
+
+ user, err := db.GetUser(ctx, username)
+ if err != nil {
+ log.Fatalf("failed to get user: %v", err)
+ }
+
+ password, err := readPassword()
+ if err != nil {
+ log.Fatalf("failed to read password: %v", err)
+ }
+
+ hashed, err := bcrypt.GenerateFromPassword(password, bcrypt.DefaultCost)
+ if err != nil {
+ log.Fatalf("failed to hash password: %v", err)
+ }
+
+ user.Password = string(hashed)
+ if err := db.StoreUser(ctx, user); err != nil {
+ log.Fatalf("failed to update password: %v", err)
+ }
+ case "version":
+ fmt.Printf("%v\n", suika.FullVersion())
+ default:
+ flag.Usage()
+ if cmd != "help" {
+ os.Exit(1)
+ }
+ }
+}
+
+func readPassword() ([]byte, error) {
+ var password []byte
+ var err error
+ fd := int(os.Stdin.Fd())
+
+ if term.IsTerminal(fd) {
+ fmt.Printf("Password: ")
+ password, err = term.ReadPassword(int(os.Stdin.Fd()))
+ if err != nil {
+ return nil, err
+ }
+ fmt.Printf("\n")
+ } else {
+ fmt.Fprintf(os.Stderr, "Warning: Reading password from stdin.\n")
+ // TODO: the buffering messes up repeated calls to readPassword
+ scanner := bufio.NewScanner(os.Stdin)
+ if !scanner.Scan() {
+ if err := scanner.Err(); err != nil {
+ return nil, err
+ }
+ return nil, io.ErrUnexpectedEOF
+ }
+ password = scanner.Bytes()
+
+ if len(password) == 0 {
+ return nil, fmt.Errorf("zero length password")
+ }
+ }
+
+ return password, nil
+}
--- /dev/null
+db sqlite3 /var/lib/suika/main.db
+log fs /var/lib/suika/logs/
--- /dev/null
+package config
+
+import (
+ "fmt"
+ "net"
+ "os"
+ "strconv"
+
+ "git.sr.ht/~emersion/go-scfg"
+)
+
+type TLS struct {
+ CertPath, KeyPath string
+}
+
+type Server struct {
+ Listen []string
+ TLS *TLS
+ Hostname string
+ Title string
+ MOTDPath string
+
+ SQLDriver string
+ SQLSource string
+ LogPath string
+
+ MaxUserNetworks int
+ MultiUpstream bool
+ UpstreamUserIPs []*net.IPNet
+}
+
+func Defaults() *Server {
+ hostname, err := os.Hostname()
+ if err != nil {
+ hostname = "localhost"
+ }
+ return &Server{
+ Hostname: hostname,
+ SQLDriver: "sqlite3",
+ SQLSource: "suika.db",
+ MaxUserNetworks: -1,
+ MultiUpstream: true,
+ }
+}
+
+func Load(path string) (*Server, error) {
+ cfg, err := scfg.Load(path)
+ if err != nil {
+ return nil, err
+ }
+ return parse(cfg)
+}
+
+func parse(cfg scfg.Block) (*Server, error) {
+ srv := Defaults()
+ for _, d := range cfg {
+ switch d.Name {
+ case "listen":
+ var uri string
+ if err := d.ParseParams(&uri); err != nil {
+ return nil, err
+ }
+ srv.Listen = append(srv.Listen, uri)
+ case "hostname":
+ if err := d.ParseParams(&srv.Hostname); err != nil {
+ return nil, err
+ }
+ case "title":
+ if err := d.ParseParams(&srv.Title); err != nil {
+ return nil, err
+ }
+ case "motd":
+ if err := d.ParseParams(&srv.MOTDPath); err != nil {
+ return nil, err
+ }
+ case "tls":
+ tls := &TLS{}
+ if err := d.ParseParams(&tls.CertPath, &tls.KeyPath); err != nil {
+ return nil, err
+ }
+ srv.TLS = tls
+ case "db":
+ if err := d.ParseParams(&srv.SQLDriver, &srv.SQLSource); err != nil {
+ return nil, err
+ }
+ case "log":
+ var driver string
+ if err := d.ParseParams(&driver, &srv.LogPath); err != nil {
+ return nil, err
+ }
+ if driver != "fs" {
+ return nil, fmt.Errorf("directive %q: unknown driver %q", d.Name, driver)
+ }
+ case "max-user-networks":
+ var max string
+ if err := d.ParseParams(&max); err != nil {
+ return nil, err
+ }
+ var err error
+ if srv.MaxUserNetworks, err = strconv.Atoi(max); err != nil {
+ return nil, fmt.Errorf("directive %q: %v", d.Name, err)
+ }
+ case "multi-upstream-mode":
+ var str string
+ if err := d.ParseParams(&str); err != nil {
+ return nil, err
+ }
+ v, err := strconv.ParseBool(str)
+ if err != nil {
+ return nil, fmt.Errorf("directive %q: %v", d.Name, err)
+ }
+ srv.MultiUpstream = v
+ case "upstream-user-ip":
+ if len(srv.UpstreamUserIPs) > 0 {
+ return nil, fmt.Errorf("directive %q: can only be specified once", d.Name)
+ }
+ var hasIPv4, hasIPv6 bool
+ for _, s := range d.Params {
+ _, n, err := net.ParseCIDR(s)
+ if err != nil {
+ return nil, fmt.Errorf("directive %q: failed to parse CIDR: %v", d.Name, err)
+ }
+ if n.IP.To4() == nil {
+ if hasIPv6 {
+ return nil, fmt.Errorf("directive %q: found two IPv6 CIDRs", d.Name)
+ }
+ hasIPv6 = true
+ } else {
+ if hasIPv4 {
+ return nil, fmt.Errorf("directive %q: found two IPv4 CIDRs", d.Name)
+ }
+ hasIPv4 = true
+ }
+ srv.UpstreamUserIPs = append(srv.UpstreamUserIPs, n)
+ }
+ default:
+ return nil, fmt.Errorf("unknown directive %q", d.Name)
+ }
+ }
+
+ return srv, nil
+}
--- /dev/null
+package suika
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "net"
+ "sync"
+ "time"
+
+ "golang.org/x/time/rate"
+ "gopkg.in/irc.v3"
+)
+
+// ircConn is a generic IRC connection. It's similar to net.Conn but focuses on
+// reading and writing IRC messages.
+type ircConn interface {
+ ReadMessage() (*irc.Message, error)
+ WriteMessage(*irc.Message) error
+ Close() error
+ SetReadDeadline(time.Time) error
+ SetWriteDeadline(time.Time) error
+ RemoteAddr() net.Addr
+ LocalAddr() net.Addr
+}
+
+func newNetIRCConn(c net.Conn) ircConn {
+ type netConn net.Conn
+ return struct {
+ *irc.Conn
+ netConn
+ }{irc.NewConn(c), c}
+}
+
+type connOptions struct {
+ Logger Logger
+ RateLimitDelay time.Duration
+ RateLimitBurst int
+}
+
+type conn struct {
+ conn ircConn
+ srv *Server
+ logger Logger
+
+ lock sync.Mutex
+ outgoing chan<- *irc.Message
+ closed bool
+ closedCh chan struct{}
+}
+
+func newConn(srv *Server, ic ircConn, options *connOptions) *conn {
+ outgoing := make(chan *irc.Message, 64)
+ c := &conn{
+ conn: ic,
+ srv: srv,
+ outgoing: outgoing,
+ logger: options.Logger,
+ closedCh: make(chan struct{}),
+ }
+
+ go func() {
+ ctx, cancel := c.NewContext(context.Background())
+ defer cancel()
+
+ rl := rate.NewLimiter(rate.Every(options.RateLimitDelay), options.RateLimitBurst)
+ for msg := range outgoing {
+ if err := rl.Wait(ctx); err != nil {
+ break
+ }
+
+ c.logger.Debugf("sent: %v", msg)
+ c.conn.SetWriteDeadline(time.Now().Add(writeTimeout))
+ if err := c.conn.WriteMessage(msg); err != nil {
+ c.logger.Printf("failed to write message: %v", err)
+ break
+ }
+ }
+ if err := c.conn.Close(); err != nil && !isErrClosed(err) {
+ c.logger.Printf("failed to close connection: %v", err)
+ } else {
+ c.logger.Debugf("connection closed")
+ }
+ // Drain the outgoing channel to prevent SendMessage from blocking
+ for range outgoing {
+ // This space is intentionally left blank
+ }
+ }()
+
+ c.logger.Debugf("new connection")
+ return c
+}
+
+func (c *conn) isClosed() bool {
+ c.lock.Lock()
+ defer c.lock.Unlock()
+ return c.closed
+}
+
+// Close closes the connection. It is safe to call from any goroutine.
+func (c *conn) Close() error {
+ c.lock.Lock()
+ defer c.lock.Unlock()
+
+ if c.closed {
+ return fmt.Errorf("connection already closed")
+ }
+
+ err := c.conn.Close()
+ c.closed = true
+ close(c.outgoing)
+ close(c.closedCh)
+ return err
+}
+
+func (c *conn) ReadMessage() (*irc.Message, error) {
+ msg, err := c.conn.ReadMessage()
+ if isErrClosed(err) {
+ return nil, io.EOF
+ } else if err != nil {
+ return nil, err
+ }
+
+ c.logger.Debugf("received: %v", msg)
+ return msg, nil
+}
+
+// SendMessage queues a new outgoing message. It is safe to call from any
+// goroutine.
+//
+// If the connection is closed before the message is sent, SendMessage silently
+// drops the message.
+func (c *conn) SendMessage(ctx context.Context, msg *irc.Message) {
+ c.lock.Lock()
+ defer c.lock.Unlock()
+
+ if c.closed {
+ return
+ }
+
+ select {
+ case c.outgoing <- msg:
+ // Success
+ case <-ctx.Done():
+ c.logger.Printf("failed to send message: %v", ctx.Err())
+ }
+}
+
+func (c *conn) RemoteAddr() net.Addr {
+ return c.conn.RemoteAddr()
+}
+
+func (c *conn) LocalAddr() net.Addr {
+ return c.conn.LocalAddr()
+}
+
+// NewContext returns a copy of the parent context with a new Done channel. The
+// returned context's Done channel is closed when the connection is closed,
+// when the returned cancel function is called, or when the parent context's
+// Done channel is closed, whichever happens first.
+//
+// Canceling this context releases resources associated with it, so code should
+// call cancel as soon as the operations running in this Context complete.
+func (c *conn) NewContext(parent context.Context) (context.Context, context.CancelFunc) {
+ ctx, cancel := context.WithCancel(parent)
+
+ go func() {
+ defer cancel()
+
+ select {
+ case <-ctx.Done():
+ // The parent context has been cancelled, or the caller has called
+ // cancel()
+ case <-c.closedCh:
+ // The connection has been closed
+ }
+ }()
+
+ return ctx, cancel
+}
--- /dev/null
+#!/bin/sh -eu
+
+# Converts a log dir to its case-mapped form.
+#
+# suika needs to be stopped for this script to work properly. The script may
+# re-order messages that happened within the same second interval if merging
+# two daily log files is necessary.
+#
+# usage: casemap-logs.sh <directory>
+
+root="$1"
+
+for net_dir in "$root"/*/*; do
+ for chan in $(ls "$net_dir"); do
+ cm_chan="$(echo $chan | tr '[:upper:]' '[:lower:]')"
+ if [ "$chan" = "$cm_chan" ]; then
+ continue
+ fi
+
+ if ! [ -d "$net_dir/$cm_chan" ]; then
+ echo >&2 "Moving case-mapped channel dir: '$net_dir/$chan' -> '$cm_chan'"
+ mv "$net_dir/$chan" "$net_dir/$cm_chan"
+ continue
+ fi
+
+ echo "Merging case-mapped channel dir: '$net_dir/$chan' -> '$cm_chan'"
+ for day in $(ls "$net_dir/$chan"); do
+ if ! [ -e "$net_dir/$cm_chan/$day" ]; then
+ echo >&2 " Moving log file: '$day'"
+ mv "$net_dir/$chan/$day" "$net_dir/$cm_chan/$day"
+ continue
+ fi
+
+ echo >&2 " Merging log file: '$day'"
+ sort "$net_dir/$chan/$day" "$net_dir/$cm_chan/$day" >"$net_dir/$cm_chan/$day.new"
+ mv "$net_dir/$cm_chan/$day.new" "$net_dir/$cm_chan/$day"
+ rm "$net_dir/$chan/$day"
+ done
+
+ rmdir "$net_dir/$chan"
+ done
+done
--- /dev/null
+# Clients
+
+This page describes how to configure IRC clients to better integrate with soju.
+
+Also see the [IRCv3 support tables] for a more general list of clients.
+
+# catgirl
+
+catgirl doesn't properly implement cap-3.2, so many capabilities will be
+disabled. catgirl developers have publicly stated that supporting bouncers such
+as soju is a non-goal.
+
+# [Emacs]
+
+There are two clients provided with Emacs. They require some setup to work
+properly.
+
+## Erc
+
+You need to explicitly set the username, which is the defcustom
+`erc-email-userid`.
+
+```elisp
+(setq erc-email-userid "<username>/irc.libera.chat") ;; Example with Libera.Chat
+(defun run-erc ()
+ (interactive)
+ (erc-tls :server "<server>"
+ :port 6697
+ :nick "<nick>"
+ :password "<password>"))
+```
+
+Then run `M-x run-erc`.
+
+## Rcirc
+
+The only thing needed here is the general config:
+
+```elisp
+(setq rcirc-server-alist
+ '(("<server>"
+ :port 6697
+ :encryption tls
+ :nick "<nick>"
+ :user-name "<username>/irc.libera.chat" ;; Example with Libera.Chat
+ :password "<password>")))
+```
+
+Then run `M-x irc`.
+
+# [gamja]
+
+gamja has been designed together with soju, so should have excellent
+integration. gamja supports many IRCv3 features including chat history.
+gamja also provides UI to manage soju networks via the
+`soju.im/bouncer-networks` extension.
+
+# [goguma]
+
+Much like gamja, goguma has been designed together with soju, so should have
+excellent integration. goguma supports many IRCv3 features including chat
+history. goguma should seamlessly connect to all networks configured in soju via
+the `soju.im/bouncer-networks` extension.
+
+# [Hexchat]
+
+Hexchat has support for a small set of IRCv3 capabilities. To prevent
+automatically reconnecting to channels parted from soju, and prevent buffering
+outgoing messages:
+
+ /set irc_reconnect_rejoin off
+ /set net_throttle off
+
+# [senpai]
+
+senpai is being developed with soju in mind, so should have excellent
+integration. senpai supports many IRCv3 features including chat history.
+
+# [Weechat]
+
+A [Weechat script] is available to provide better integration with soju.
+The script will automatically connect to all of your networks once a
+single connection to soju is set up in Weechat.
+
+On WeeChat 3.2-, no IRCv3 capabilities are enabled by default. To enable them:
+
+ /set irc.server_default.capabilities account-notify,away-notify,cap-notify,chghost,extended-join,invite-notify,multi-prefix,server-time,userhost-in-names
+ /save
+ /reconnect -all
+
+See `/help cap` for more information.
+
+[IRCv3 support tables]: https://ircv3.net/software/clients
+[gamja]: https://sr.ht/~emersion/gamja/
+[goguma]: https://sr.ht/~emersion/goguma/
+[senpai]: https://sr.ht/~taiite/senpai/
+[Weechat]: https://weechat.org/
+[Weechat script]: https://github.com/weechat/scripts/blob/master/python/soju.py
+[Hexchat]: https://hexchat.github.io/
+[Emacs]: https://www.gnu.org/software/emacs/
--- /dev/null
+package suika
+
+import (
+ "context"
+ "fmt"
+ "net/url"
+ "strings"
+ "time"
+)
+
+type Database interface {
+ Close() error
+ Stats(ctx context.Context) (*DatabaseStats, error)
+
+ ListUsers(ctx context.Context) ([]User, error)
+ GetUser(ctx context.Context, username string) (*User, error)
+ StoreUser(ctx context.Context, user *User) error
+ DeleteUser(ctx context.Context, id int64) error
+
+ ListNetworks(ctx context.Context, userID int64) ([]Network, error)
+ StoreNetwork(ctx context.Context, userID int64, network *Network) error
+ DeleteNetwork(ctx context.Context, id int64) error
+ ListChannels(ctx context.Context, networkID int64) ([]Channel, error)
+ StoreChannel(ctx context.Context, networKID int64, ch *Channel) error
+ DeleteChannel(ctx context.Context, id int64) error
+
+ ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error)
+ StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error
+
+ GetReadReceipt(ctx context.Context, networkID int64, name string) (*ReadReceipt, error)
+ StoreReadReceipt(ctx context.Context, networkID int64, receipt *ReadReceipt) error
+}
+
+func OpenDB(driver, source string) (Database, error) {
+ switch driver {
+ case "sqlite3":
+ return OpenSqliteDB(source)
+ case "postgres":
+ return OpenPostgresDB(source)
+ default:
+ return nil, fmt.Errorf("unsupported database driver: %q", driver)
+ }
+}
+
+type DatabaseStats struct {
+ Users int64
+ Networks int64
+ Channels int64
+}
+
+type User struct {
+ ID int64
+ Username string
+ Password string // hashed
+ Realname string
+ Admin bool
+}
+
+type SASL struct {
+ Mechanism string
+
+ Plain struct {
+ Username string
+ Password string
+ }
+
+ // TLS client certificate authentication.
+ External struct {
+ // X.509 certificate in DER form.
+ CertBlob []byte
+ // PKCS#8 private key in DER form.
+ PrivKeyBlob []byte
+ }
+}
+
+type Network struct {
+ ID int64
+ Name string
+ Addr string
+ Nick string
+ Username string
+ Realname string
+ Pass string
+ ConnectCommands []string
+ SASL SASL
+ Enabled bool
+}
+
+func (net *Network) GetName() string {
+ if net.Name != "" {
+ return net.Name
+ }
+ return net.Addr
+}
+
+func (net *Network) URL() (*url.URL, error) {
+ s := net.Addr
+ if !strings.Contains(s, "://") {
+ // This is a raw domain name, make it an URL with the default scheme
+ s = "ircs://" + s
+ }
+
+ u, err := url.Parse(s)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse upstream server URL: %v", err)
+ }
+
+ return u, nil
+}
+
+func GetNick(user *User, net *Network) string {
+ if net.Nick != "" {
+ return net.Nick
+ }
+ return user.Username
+}
+
+func GetUsername(user *User, net *Network) string {
+ if net.Username != "" {
+ return net.Username
+ }
+ return GetNick(user, net)
+}
+
+func GetRealname(user *User, net *Network) string {
+ if net.Realname != "" {
+ return net.Realname
+ }
+ if user.Realname != "" {
+ return user.Realname
+ }
+ return GetNick(user, net)
+}
+
+type MessageFilter int
+
+const (
+ // TODO: use customizable user defaults for FilterDefault
+ FilterDefault MessageFilter = iota
+ FilterNone
+ FilterHighlight
+ FilterMessage
+)
+
+func parseFilter(filter string) (MessageFilter, error) {
+ switch filter {
+ case "default":
+ return FilterDefault, nil
+ case "none":
+ return FilterNone, nil
+ case "highlight":
+ return FilterHighlight, nil
+ case "message":
+ return FilterMessage, nil
+ }
+ return 0, fmt.Errorf("unknown filter: %q", filter)
+}
+
+type Channel struct {
+ ID int64
+ Name string
+ Key string
+
+ Detached bool
+ DetachedInternalMsgID string
+
+ RelayDetached MessageFilter
+ ReattachOn MessageFilter
+ DetachAfter time.Duration
+ DetachOn MessageFilter
+}
+
+type DeliveryReceipt struct {
+ ID int64
+ Target string // channel or nick
+ Client string
+ InternalMsgID string
+}
+
+type ReadReceipt struct {
+ ID int64
+ Target string // channel or nick
+ Timestamp time.Time
+}
--- /dev/null
+package suika
+
+import (
+ "context"
+ "database/sql"
+ _ "embed"
+ "errors"
+ "fmt"
+ "math"
+ "strings"
+ "time"
+
+ _ "github.com/lib/pq"
+)
+
+const postgresQueryTimeout = 5 * time.Second
+
+const postgresConfigSchema = `
+CREATE TABLE IF NOT EXISTS "Config" (
+ id SMALLINT PRIMARY KEY,
+ version INTEGER NOT NULL,
+ CHECK(id = 1)
+);
+`
+//go:embed suika_psql_schema.sql
+var postgresSchema string
+
+var postgresMigrations = []string{
+ "", // migration #0 is reserved for schema initialization
+ `ALTER TABLE "Network" ALTER COLUMN nick DROP NOT NULL`,
+ `
+ CREATE TYPE sasl_mechanism AS ENUM ('PLAIN', 'EXTERNAL');
+ ALTER TABLE "Network"
+ ALTER COLUMN sasl_mechanism
+ TYPE sasl_mechanism
+ USING sasl_mechanism::sasl_mechanism;
+ `,
+ `
+ CREATE TABLE IF NOT EXISTS "ReadReceipt" (
+ id SERIAL PRIMARY KEY,
+ network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
+ target VARCHAR(255) NOT NULL,
+ timestamp TIMESTAMP WITH TIME ZONE NOT NULL,
+ UNIQUE(network, target)
+ );
+ `,
+}
+
+type PostgresDB struct {
+ db *sql.DB
+}
+
+func OpenPostgresDB(source string) (Database, error) {
+ sqlPostgresDB, err := sql.Open("postgres", source)
+ if err != nil {
+ return nil, err
+ }
+
+ db := &PostgresDB{db: sqlPostgresDB}
+ if err := db.upgrade(); err != nil {
+ sqlPostgresDB.Close()
+ return nil, err
+ }
+
+ return db, nil
+}
+
+func (db *PostgresDB) upgrade() error {
+ tx, err := db.db.Begin()
+ if err != nil {
+ return err
+ }
+ defer tx.Rollback()
+
+ if _, err := tx.Exec(postgresConfigSchema); err != nil {
+ return fmt.Errorf("failed to create Config table: %s", err)
+ }
+
+ var version int
+ err = tx.QueryRow(`SELECT version FROM "Config"`).Scan(&version)
+ if err != nil && !errors.Is(err, sql.ErrNoRows) {
+ return fmt.Errorf("failed to query schema version: %s", err)
+ }
+
+ if version == len(postgresMigrations) {
+ return nil
+ }
+ if version > len(postgresMigrations) {
+ return fmt.Errorf("suika (version %d) older than schema (version %d)", len(postgresMigrations), version)
+ }
+
+ if version == 0 {
+ if _, err := tx.Exec(postgresSchema); err != nil {
+ return fmt.Errorf("failed to initialize schema: %s", err)
+ }
+ } else {
+ for i := version; i < len(postgresMigrations); i++ {
+ if _, err := tx.Exec(postgresMigrations[i]); err != nil {
+ return fmt.Errorf("failed to execute migration #%v: %v", i, err)
+ }
+ }
+ }
+
+ _, err = tx.Exec(`INSERT INTO "Config" (id, version) VALUES (1, $1)
+ ON CONFLICT (id) DO UPDATE SET version = $1`, len(postgresMigrations))
+ if err != nil {
+ return fmt.Errorf("failed to bump schema version: %v", err)
+ }
+
+ return tx.Commit()
+}
+
+func (db *PostgresDB) Close() error {
+ return db.db.Close()
+}
+
+func (db *PostgresDB) Stats(ctx context.Context) (*DatabaseStats, error) {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ var stats DatabaseStats
+ row := db.db.QueryRowContext(ctx, `SELECT
+ (SELECT COUNT(*) FROM "User") AS users,
+ (SELECT COUNT(*) FROM "Network") AS networks,
+ (SELECT COUNT(*) FROM "Channel") AS channels`)
+ if err := row.Scan(&stats.Users, &stats.Networks, &stats.Channels); err != nil {
+ return nil, err
+ }
+
+ return &stats, nil
+}
+
+func (db *PostgresDB) ListUsers(ctx context.Context) ([]User, error) {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ rows, err := db.db.QueryContext(ctx,
+ `SELECT id, username, password, admin, realname FROM "User"`)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ var users []User
+ for rows.Next() {
+ var user User
+ var password, realname sql.NullString
+ if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &realname); err != nil {
+ return nil, err
+ }
+ user.Password = password.String
+ user.Realname = realname.String
+ users = append(users, user)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+
+ return users, nil
+}
+
+func (db *PostgresDB) GetUser(ctx context.Context, username string) (*User, error) {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ user := &User{Username: username}
+
+ var password, realname sql.NullString
+ row := db.db.QueryRowContext(ctx,
+ `SELECT id, password, admin, realname FROM "User" WHERE username = $1`,
+ username)
+ if err := row.Scan(&user.ID, &password, &user.Admin, &realname); err != nil {
+ return nil, err
+ }
+ user.Password = password.String
+ user.Realname = realname.String
+ return user, nil
+}
+
+func (db *PostgresDB) StoreUser(ctx context.Context, user *User) error {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ password := toNullString(user.Password)
+ realname := toNullString(user.Realname)
+
+ var err error
+ if user.ID == 0 {
+ err = db.db.QueryRowContext(ctx, `
+ INSERT INTO "User" (username, password, admin, realname)
+ VALUES ($1, $2, $3, $4)
+ RETURNING id`,
+ user.Username, password, user.Admin, realname).Scan(&user.ID)
+ } else {
+ _, err = db.db.ExecContext(ctx, `
+ UPDATE "User"
+ SET password = $1, admin = $2, realname = $3
+ WHERE id = $4`,
+ password, user.Admin, realname, user.ID)
+ }
+ return err
+}
+
+func (db *PostgresDB) DeleteUser(ctx context.Context, id int64) error {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ _, err := db.db.ExecContext(ctx, `DELETE FROM "User" WHERE id = $1`, id)
+ return err
+}
+
+func (db *PostgresDB) ListNetworks(ctx context.Context, userID int64) ([]Network, error) {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ rows, err := db.db.QueryContext(ctx, `
+ SELECT id, name, addr, nick, username, realname, pass, connect_commands, sasl_mechanism,
+ sasl_plain_username, sasl_plain_password, sasl_external_cert, sasl_external_key, enabled
+ FROM "Network"
+ WHERE "user" = $1`, userID)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ var networks []Network
+ for rows.Next() {
+ var net Network
+ var name, nick, username, realname, pass, connectCommands sql.NullString
+ var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
+ err := rows.Scan(&net.ID, &name, &net.Addr, &nick, &username, &realname,
+ &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword,
+ &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob, &net.Enabled)
+ if err != nil {
+ return nil, err
+ }
+ net.Name = name.String
+ net.Nick = nick.String
+ net.Username = username.String
+ net.Realname = realname.String
+ net.Pass = pass.String
+ if connectCommands.Valid {
+ net.ConnectCommands = strings.Split(connectCommands.String, "\r\n")
+ }
+ net.SASL.Mechanism = saslMechanism.String
+ net.SASL.Plain.Username = saslPlainUsername.String
+ net.SASL.Plain.Password = saslPlainPassword.String
+ networks = append(networks, net)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+
+ return networks, nil
+}
+
+func (db *PostgresDB) StoreNetwork(ctx context.Context, userID int64, network *Network) error {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ netName := toNullString(network.Name)
+ nick := toNullString(network.Nick)
+ netUsername := toNullString(network.Username)
+ realname := toNullString(network.Realname)
+ pass := toNullString(network.Pass)
+ connectCommands := toNullString(strings.Join(network.ConnectCommands, "\r\n"))
+
+ var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
+ if network.SASL.Mechanism != "" {
+ saslMechanism = toNullString(network.SASL.Mechanism)
+ switch network.SASL.Mechanism {
+ case "PLAIN":
+ saslPlainUsername = toNullString(network.SASL.Plain.Username)
+ saslPlainPassword = toNullString(network.SASL.Plain.Password)
+ network.SASL.External.CertBlob = nil
+ network.SASL.External.PrivKeyBlob = nil
+ case "EXTERNAL":
+ // keep saslPlain* nil
+ default:
+ return fmt.Errorf("suika: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
+ }
+ }
+
+ var err error
+ if network.ID == 0 {
+ err = db.db.QueryRowContext(ctx, `
+ INSERT INTO "Network" ("user", name, addr, nick, username, realname, pass, connect_commands,
+ sasl_mechanism, sasl_plain_username, sasl_plain_password, sasl_external_cert,
+ sasl_external_key, enabled)
+ VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
+ RETURNING id`,
+ userID, netName, network.Addr, nick, netUsername, realname, pass, connectCommands,
+ saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
+ network.SASL.External.PrivKeyBlob, network.Enabled).Scan(&network.ID)
+ } else {
+ _, err = db.db.ExecContext(ctx, `
+ UPDATE "Network"
+ SET name = $2, addr = $3, nick = $4, username = $5, realname = $6, pass = $7,
+ connect_commands = $8, sasl_mechanism = $9, sasl_plain_username = $10,
+ sasl_plain_password = $11, sasl_external_cert = $12, sasl_external_key = $13,
+ enabled = $14
+ WHERE id = $1`,
+ network.ID, netName, network.Addr, nick, netUsername, realname, pass, connectCommands,
+ saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
+ network.SASL.External.PrivKeyBlob, network.Enabled)
+ }
+ return err
+}
+
+func (db *PostgresDB) DeleteNetwork(ctx context.Context, id int64) error {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ _, err := db.db.ExecContext(ctx, `DELETE FROM "Network" WHERE id = $1`, id)
+ return err
+}
+
+func (db *PostgresDB) ListChannels(ctx context.Context, networkID int64) ([]Channel, error) {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ rows, err := db.db.QueryContext(ctx, `
+ SELECT id, name, key, detached, detached_internal_msgid, relay_detached, reattach_on, detach_after,
+ detach_on
+ FROM "Channel"
+ WHERE network = $1`, networkID)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ var channels []Channel
+ for rows.Next() {
+ var ch Channel
+ var key, detachedInternalMsgID sql.NullString
+ var detachAfter int64
+ if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached, &detachedInternalMsgID, &ch.RelayDetached, &ch.ReattachOn, &detachAfter, &ch.DetachOn); err != nil {
+ return nil, err
+ }
+ ch.Key = key.String
+ ch.DetachedInternalMsgID = detachedInternalMsgID.String
+ ch.DetachAfter = time.Duration(detachAfter) * time.Second
+ channels = append(channels, ch)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+
+ return channels, nil
+}
+
+func (db *PostgresDB) StoreChannel(ctx context.Context, networkID int64, ch *Channel) error {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ key := toNullString(ch.Key)
+ detachAfter := int64(math.Ceil(ch.DetachAfter.Seconds()))
+
+ var err error
+ if ch.ID == 0 {
+ err = db.db.QueryRowContext(ctx, `
+ INSERT INTO "Channel" (network, name, key, detached, detached_internal_msgid, relay_detached, reattach_on,
+ detach_after, detach_on)
+ VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
+ RETURNING id`,
+ networkID, ch.Name, key, ch.Detached, toNullString(ch.DetachedInternalMsgID),
+ ch.RelayDetached, ch.ReattachOn, detachAfter, ch.DetachOn).Scan(&ch.ID)
+ } else {
+ _, err = db.db.ExecContext(ctx, `
+ UPDATE "Channel"
+ SET name = $2, key = $3, detached = $4, detached_internal_msgid = $5,
+ relay_detached = $6, reattach_on = $7, detach_after = $8, detach_on = $9
+ WHERE id = $1`,
+ ch.ID, ch.Name, key, ch.Detached, toNullString(ch.DetachedInternalMsgID),
+ ch.RelayDetached, ch.ReattachOn, detachAfter, ch.DetachOn)
+ }
+ return err
+}
+
+func (db *PostgresDB) DeleteChannel(ctx context.Context, id int64) error {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ _, err := db.db.ExecContext(ctx, `DELETE FROM "Channel" WHERE id = $1`, id)
+ return err
+}
+
+func (db *PostgresDB) ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error) {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ rows, err := db.db.QueryContext(ctx, `
+ SELECT id, target, client, internal_msgid
+ FROM "DeliveryReceipt"
+ WHERE network = $1`, networkID)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ var receipts []DeliveryReceipt
+ for rows.Next() {
+ var rcpt DeliveryReceipt
+ if err := rows.Scan(&rcpt.ID, &rcpt.Target, &rcpt.Client, &rcpt.InternalMsgID); err != nil {
+ return nil, err
+ }
+ receipts = append(receipts, rcpt)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+
+ return receipts, nil
+}
+
+func (db *PostgresDB) StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ tx, err := db.db.Begin()
+ if err != nil {
+ return err
+ }
+ defer tx.Rollback()
+
+ _, err = tx.ExecContext(ctx,
+ `DELETE FROM "DeliveryReceipt" WHERE network = $1 AND client = $2`,
+ networkID, client)
+ if err != nil {
+ return err
+ }
+
+ stmt, err := tx.PrepareContext(ctx, `
+ INSERT INTO "DeliveryReceipt" (network, target, client, internal_msgid)
+ VALUES ($1, $2, $3, $4)
+ RETURNING id`)
+ if err != nil {
+ return err
+ }
+ defer stmt.Close()
+
+ for i := range receipts {
+ rcpt := &receipts[i]
+ err := stmt.
+ QueryRowContext(ctx, networkID, rcpt.Target, client, rcpt.InternalMsgID).
+ Scan(&rcpt.ID)
+ if err != nil {
+ return err
+ }
+ }
+
+ return tx.Commit()
+}
+
+func (db *PostgresDB) GetReadReceipt(ctx context.Context, networkID int64, name string) (*ReadReceipt, error) {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ receipt := &ReadReceipt{
+ Target: name,
+ }
+
+ row := db.db.QueryRowContext(ctx,
+ `SELECT id, timestamp FROM "ReadReceipt" WHERE network = $1 AND target = $2`,
+ networkID, name)
+ if err := row.Scan(&receipt.ID, &receipt.Timestamp); err != nil {
+ if err == sql.ErrNoRows {
+ return nil, nil
+ }
+ return nil, err
+ }
+ return receipt, nil
+}
+
+func (db *PostgresDB) StoreReadReceipt(ctx context.Context, networkID int64, receipt *ReadReceipt) error {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ var err error
+ if receipt.ID != 0 {
+ _, err = db.db.ExecContext(ctx, `
+ UPDATE "ReadReceipt"
+ SET timestamp = $1
+ WHERE id = $2`,
+ receipt.Timestamp, receipt.ID)
+ } else {
+ err = db.db.QueryRowContext(ctx, `
+ INSERT INTO "ReadReceipt" (network, target, timestamp)
+ VALUES ($1, $2, $3)
+ RETURNING id`,
+ networkID, receipt.Target, receipt.Timestamp).Scan(&receipt.ID)
+ }
+ return err
+}
--- /dev/null
+package suika
+
+import (
+ "database/sql"
+ "os"
+ "testing"
+)
+
+// PostgreSQL version 0 schema. DO NOT EDIT.
+const postgresV0Schema = `
+CREATE TABLE "Config" (
+ id SMALLINT PRIMARY KEY,
+ version INTEGER NOT NULL,
+ CHECK(id = 1)
+);
+
+INSERT INTO "Config" (id, version) VALUES (1, 1);
+
+CREATE TABLE "User" (
+ id SERIAL PRIMARY KEY,
+ username VARCHAR(255) NOT NULL UNIQUE,
+ password VARCHAR(255),
+ admin BOOLEAN NOT NULL DEFAULT FALSE,
+ realname VARCHAR(255)
+);
+
+CREATE TABLE "Network" (
+ id SERIAL PRIMARY KEY,
+ name VARCHAR(255),
+ "user" INTEGER NOT NULL REFERENCES "User"(id) ON DELETE CASCADE,
+ addr VARCHAR(255) NOT NULL,
+ nick VARCHAR(255) NOT NULL,
+ username VARCHAR(255),
+ realname VARCHAR(255),
+ pass VARCHAR(255),
+ connect_commands VARCHAR(1023),
+ sasl_mechanism VARCHAR(255),
+ sasl_plain_username VARCHAR(255),
+ sasl_plain_password VARCHAR(255),
+ sasl_external_cert BYTEA DEFAULT NULL,
+ sasl_external_key BYTEA DEFAULT NULL,
+ enabled BOOLEAN NOT NULL DEFAULT TRUE,
+ UNIQUE("user", addr, nick),
+ UNIQUE("user", name)
+);
+
+CREATE TABLE "Channel" (
+ id SERIAL PRIMARY KEY,
+ network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
+ name VARCHAR(255) NOT NULL,
+ key VARCHAR(255),
+ detached BOOLEAN NOT NULL DEFAULT FALSE,
+ detached_internal_msgid VARCHAR(255),
+ relay_detached INTEGER NOT NULL DEFAULT 0,
+ reattach_on INTEGER NOT NULL DEFAULT 0,
+ detach_after INTEGER NOT NULL DEFAULT 0,
+ detach_on INTEGER NOT NULL DEFAULT 0,
+ UNIQUE(network, name)
+);
+
+CREATE TABLE "DeliveryReceipt" (
+ id SERIAL PRIMARY KEY,
+ network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
+ target VARCHAR(255) NOT NULL,
+ client VARCHAR(255) NOT NULL DEFAULT '',
+ internal_msgid VARCHAR(255) NOT NULL,
+ UNIQUE(network, target, client)
+);
+`
+
+func openTempPostgresDB(t *testing.T) *sql.DB {
+ source, ok := os.LookupEnv("SOJU_TEST_POSTGRES")
+ if !ok {
+ t.Skip("set SOJU_TEST_POSTGRES to a connection string to execute PostgreSQL tests")
+ }
+
+ db, err := sql.Open("postgres", source)
+ if err != nil {
+ t.Fatalf("failed to connect to PostgreSQL: %v", err)
+ }
+
+ // Store all tables in a temporary schema which will be dropped when the
+ // connection to PostgreSQL is closed.
+ db.SetMaxOpenConns(1)
+ if _, err := db.Exec("SET search_path TO pg_temp"); err != nil {
+ t.Fatalf("failed to set PostgreSQL search_path: %v", err)
+ }
+
+ return db
+}
+
+func TestPostgresMigrations(t *testing.T) {
+ sqlDB := openTempPostgresDB(t)
+ if _, err := sqlDB.Exec(postgresV0Schema); err != nil {
+ t.Fatalf("DB.Exec() failed for v0 schema: %v", err)
+ }
+
+ db := &PostgresDB{db: sqlDB}
+ defer db.Close()
+
+ if err := db.upgrade(); err != nil {
+ t.Fatalf("PostgresDB.Upgrade() failed: %v", err)
+ }
+}
--- /dev/null
+package suika
+
+import (
+ "context"
+ "database/sql"
+ _ "embed"
+ "fmt"
+ "math"
+ "strings"
+ "sync"
+ "time"
+
+ _ "modernc.org/sqlite"
+)
+
+const sqliteQueryTimeout = 5 * time.Second
+
+//go:embed suika_sqlite_schema.sql
+var sqliteSchema string
+
+var sqliteMigrations = []string{
+ "", // migration #0 is reserved for schema initialization
+ "ALTER TABLE Network ADD COLUMN connect_commands VARCHAR(1023)",
+ "ALTER TABLE Channel ADD COLUMN detached INTEGER NOT NULL DEFAULT 0",
+ "ALTER TABLE Network ADD COLUMN sasl_external_cert BLOB DEFAULT NULL",
+ "ALTER TABLE Network ADD COLUMN sasl_external_key BLOB DEFAULT NULL",
+ "ALTER TABLE User ADD COLUMN admin INTEGER NOT NULL DEFAULT 0",
+ `
+ CREATE TABLE IF NOT EXISTS UserNew (
+ id INTEGER PRIMARY KEY,
+ username VARCHAR(255) NOT NULL UNIQUE,
+ password VARCHAR(255),
+ admin INTEGER NOT NULL DEFAULT 0
+ );
+ INSERT INTO UserNew SELECT rowid, username, password, admin FROM User;
+ DROP TABLE User;
+ ALTER TABLE UserNew RENAME TO User;
+ `,
+ `
+ CREATE TABLE IF NOT EXISTS NetworkNew (
+ id INTEGER PRIMARY KEY,
+ name VARCHAR(255),
+ user INTEGER NOT NULL,
+ addr VARCHAR(255) NOT NULL,
+ nick VARCHAR(255) NOT NULL,
+ username VARCHAR(255),
+ realname VARCHAR(255),
+ pass VARCHAR(255),
+ connect_commands VARCHAR(1023),
+ sasl_mechanism VARCHAR(255),
+ sasl_plain_username VARCHAR(255),
+ sasl_plain_password VARCHAR(255),
+ sasl_external_cert BLOB DEFAULT NULL,
+ sasl_external_key BLOB DEFAULT NULL,
+ FOREIGN KEY(user) REFERENCES User(id),
+ UNIQUE(user, addr, nick),
+ UNIQUE(user, name)
+ );
+ INSERT INTO NetworkNew
+ SELECT Network.id, name, User.id as user, addr, nick,
+ Network.username, realname, pass, connect_commands,
+ sasl_mechanism, sasl_plain_username, sasl_plain_password,
+ sasl_external_cert, sasl_external_key
+ FROM Network
+ JOIN User ON Network.user = User.username;
+ DROP TABLE Network;
+ ALTER TABLE NetworkNew RENAME TO Network;
+ `,
+ `
+ ALTER TABLE Channel ADD COLUMN relay_detached INTEGER NOT NULL DEFAULT 0;
+ ALTER TABLE Channel ADD COLUMN reattach_on INTEGER NOT NULL DEFAULT 0;
+ ALTER TABLE Channel ADD COLUMN detach_after INTEGER NOT NULL DEFAULT 0;
+ ALTER TABLE Channel ADD COLUMN detach_on INTEGER NOT NULL DEFAULT 0;
+ `,
+ `
+ CREATE TABLE IF NOT EXISTS DeliveryReceipt (
+ id INTEGER PRIMARY KEY,
+ network INTEGER NOT NULL,
+ target VARCHAR(255) NOT NULL,
+ client VARCHAR(255),
+ internal_msgid VARCHAR(255) NOT NULL,
+ FOREIGN KEY(network) REFERENCES Network(id),
+ UNIQUE(network, target, client)
+ );
+ `,
+ "ALTER TABLE Channel ADD COLUMN detached_internal_msgid VARCHAR(255)",
+ "ALTER TABLE Network ADD COLUMN enabled INTEGER NOT NULL DEFAULT 1",
+ "ALTER TABLE User ADD COLUMN realname VARCHAR(255)",
+ `
+ CREATE TABLE IF NOT EXISTS NetworkNew (
+ id INTEGER PRIMARY KEY,
+ name TEXT,
+ user INTEGER NOT NULL,
+ addr TEXT NOT NULL,
+ nick TEXT,
+ username TEXT,
+ realname TEXT,
+ pass TEXT,
+ connect_commands TEXT,
+ sasl_mechanism TEXT,
+ sasl_plain_username TEXT,
+ sasl_plain_password TEXT,
+ sasl_external_cert BLOB,
+ sasl_external_key BLOB,
+ enabled INTEGER NOT NULL DEFAULT 1,
+ FOREIGN KEY(user) REFERENCES User(id),
+ UNIQUE(user, addr, nick),
+ UNIQUE(user, name)
+ );
+ INSERT INTO NetworkNew
+ SELECT id, name, user, addr, nick, username, realname, pass,
+ connect_commands, sasl_mechanism, sasl_plain_username,
+ sasl_plain_password, sasl_external_cert, sasl_external_key,
+ enabled
+ FROM Network;
+ DROP TABLE Network;
+ ALTER TABLE NetworkNew RENAME TO Network;
+ `,
+ `
+ CREATE TABLE IF NOT EXISTS ReadReceipt (
+ id INTEGER PRIMARY KEY,
+ network INTEGER NOT NULL,
+ target TEXT NOT NULL,
+ timestamp TEXT NOT NULL,
+ FOREIGN KEY(network) REFERENCES Network(id),
+ UNIQUE(network, target)
+ );
+ `,
+}
+
+type SqliteDB struct {
+ lock sync.RWMutex
+ db *sql.DB
+}
+
+func OpenSqliteDB(source string) (Database, error) {
+ sqlSqliteDB, err := sql.Open("sqlite", source)
+ if err != nil {
+ return nil, err
+ }
+
+ db := &SqliteDB{db: sqlSqliteDB}
+ if err := db.upgrade(); err != nil {
+ sqlSqliteDB.Close()
+ return nil, err
+ }
+
+ return db, nil
+}
+
+func (db *SqliteDB) Close() error {
+ db.lock.Lock()
+ defer db.lock.Unlock()
+ return db.db.Close()
+}
+
+func (db *SqliteDB) upgrade() error {
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ var version int
+ if err := db.db.QueryRow("PRAGMA user_version").Scan(&version); err != nil {
+ return fmt.Errorf("failed to query schema version: %v", err)
+ }
+
+ if version == len(sqliteMigrations) {
+ return nil
+ } else if version > len(sqliteMigrations) {
+ return fmt.Errorf("suika (version %d) older than schema (version %d)", len(sqliteMigrations), version)
+ }
+
+ tx, err := db.db.Begin()
+ if err != nil {
+ return err
+ }
+ defer tx.Rollback()
+
+ if version == 0 {
+ if _, err := tx.Exec(sqliteSchema); err != nil {
+ return fmt.Errorf("failed to initialize schema: %v", err)
+ }
+ } else {
+ for i := version; i < len(sqliteMigrations); i++ {
+ if _, err := tx.Exec(sqliteMigrations[i]); err != nil {
+ return fmt.Errorf("failed to execute migration #%v: %v", i, err)
+ }
+ }
+ }
+
+ // For some reason prepared statements don't work here
+ _, err = tx.Exec(fmt.Sprintf("PRAGMA user_version = %d", len(sqliteMigrations)))
+ if err != nil {
+ return fmt.Errorf("failed to bump schema version: %v", err)
+ }
+
+ return tx.Commit()
+}
+
+func (db *SqliteDB) Stats(ctx context.Context) (*DatabaseStats, error) {
+ db.lock.RLock()
+ defer db.lock.RUnlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ var stats DatabaseStats
+ row := db.db.QueryRowContext(ctx, `SELECT
+ (SELECT COUNT(*) FROM User) AS users,
+ (SELECT COUNT(*) FROM Network) AS networks,
+ (SELECT COUNT(*) FROM Channel) AS channels`)
+ if err := row.Scan(&stats.Users, &stats.Networks, &stats.Channels); err != nil {
+ return nil, err
+ }
+
+ return &stats, nil
+}
+
+func toNullString(s string) sql.NullString {
+ return sql.NullString{
+ String: s,
+ Valid: s != "",
+ }
+}
+
+func (db *SqliteDB) ListUsers(ctx context.Context) ([]User, error) {
+ db.lock.RLock()
+ defer db.lock.RUnlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ rows, err := db.db.QueryContext(ctx,
+ "SELECT id, username, password, admin, realname FROM User")
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ var users []User
+ for rows.Next() {
+ var user User
+ var password, realname sql.NullString
+ if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &realname); err != nil {
+ return nil, err
+ }
+ user.Password = password.String
+ user.Realname = realname.String
+ users = append(users, user)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+
+ return users, nil
+}
+
+func (db *SqliteDB) GetUser(ctx context.Context, username string) (*User, error) {
+ db.lock.RLock()
+ defer db.lock.RUnlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ user := &User{Username: username}
+
+ var password, realname sql.NullString
+ row := db.db.QueryRowContext(ctx,
+ "SELECT id, password, admin, realname FROM User WHERE username = ?",
+ username)
+ if err := row.Scan(&user.ID, &password, &user.Admin, &realname); err != nil {
+ return nil, err
+ }
+ user.Password = password.String
+ user.Realname = realname.String
+ return user, nil
+}
+
+func (db *SqliteDB) StoreUser(ctx context.Context, user *User) error {
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ args := []interface{}{
+ sql.Named("username", user.Username),
+ sql.Named("password", toNullString(user.Password)),
+ sql.Named("admin", user.Admin),
+ sql.Named("realname", toNullString(user.Realname)),
+ }
+
+ var err error
+ if user.ID != 0 {
+ _, err = db.db.ExecContext(ctx, `
+ UPDATE User SET password = :password, admin = :admin,
+ realname = :realname WHERE username = :username`,
+ args...)
+ } else {
+ var res sql.Result
+ res, err = db.db.ExecContext(ctx, `
+ INSERT INTO
+ User(username, password, admin, realname)
+ VALUES (:username, :password, :admin, :realname)`,
+ args...)
+ if err != nil {
+ return err
+ }
+ user.ID, err = res.LastInsertId()
+ }
+
+ return err
+}
+
+func (db *SqliteDB) DeleteUser(ctx context.Context, id int64) error {
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ tx, err := db.db.Begin()
+ if err != nil {
+ return err
+ }
+ defer tx.Rollback()
+
+ _, err = tx.ExecContext(ctx, `DELETE FROM DeliveryReceipt
+ WHERE id IN (
+ SELECT DeliveryReceipt.id
+ FROM DeliveryReceipt
+ JOIN Network ON DeliveryReceipt.network = Network.id
+ WHERE Network.user = ?
+ )`, id)
+ if err != nil {
+ return err
+ }
+
+ _, err = tx.ExecContext(ctx, `DELETE FROM ReadReceipt
+ WHERE id IN (
+ SELECT ReadReceipt.id
+ FROM ReadReceipt
+ JOIN Network ON ReadReceipt.network = Network.id
+ WHERE Network.user = ?
+ )`, id)
+ if err != nil {
+ return err
+ }
+
+ _, err = tx.ExecContext(ctx, `DELETE FROM Channel
+ WHERE id IN (
+ SELECT Channel.id
+ FROM Channel
+ JOIN Network ON Channel.network = Network.id
+ WHERE Network.user = ?
+ )`, id)
+ if err != nil {
+ return err
+ }
+
+ _, err = tx.ExecContext(ctx, "DELETE FROM Network WHERE user = ?", id)
+ if err != nil {
+ return err
+ }
+
+ _, err = tx.ExecContext(ctx, "DELETE FROM User WHERE id = ?", id)
+ if err != nil {
+ return err
+ }
+
+ return tx.Commit()
+}
+
+func (db *SqliteDB) ListNetworks(ctx context.Context, userID int64) ([]Network, error) {
+ db.lock.RLock()
+ defer db.lock.RUnlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ rows, err := db.db.QueryContext(ctx, `
+ SELECT id, name, addr, nick, username, realname, pass,
+ connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password,
+ sasl_external_cert, sasl_external_key, enabled
+ FROM Network
+ WHERE user = ?`,
+ userID)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ var networks []Network
+ for rows.Next() {
+ var net Network
+ var name, nick, username, realname, pass, connectCommands sql.NullString
+ var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
+ err := rows.Scan(&net.ID, &name, &net.Addr, &nick, &username, &realname,
+ &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword,
+ &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob, &net.Enabled)
+ if err != nil {
+ return nil, err
+ }
+ net.Name = name.String
+ net.Nick = nick.String
+ net.Username = username.String
+ net.Realname = realname.String
+ net.Pass = pass.String
+ if connectCommands.Valid {
+ net.ConnectCommands = strings.Split(connectCommands.String, "\r\n")
+ }
+ net.SASL.Mechanism = saslMechanism.String
+ net.SASL.Plain.Username = saslPlainUsername.String
+ net.SASL.Plain.Password = saslPlainPassword.String
+ networks = append(networks, net)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+
+ return networks, nil
+}
+
+func (db *SqliteDB) StoreNetwork(ctx context.Context, userID int64, network *Network) error {
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
+ if network.SASL.Mechanism != "" {
+ saslMechanism = toNullString(network.SASL.Mechanism)
+ switch network.SASL.Mechanism {
+ case "PLAIN":
+ saslPlainUsername = toNullString(network.SASL.Plain.Username)
+ saslPlainPassword = toNullString(network.SASL.Plain.Password)
+ network.SASL.External.CertBlob = nil
+ network.SASL.External.PrivKeyBlob = nil
+ case "EXTERNAL":
+ // keep saslPlain* nil
+ default:
+ return fmt.Errorf("suika: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
+ }
+ }
+
+ args := []interface{}{
+ sql.Named("name", toNullString(network.Name)),
+ sql.Named("addr", network.Addr),
+ sql.Named("nick", toNullString(network.Nick)),
+ sql.Named("username", toNullString(network.Username)),
+ sql.Named("realname", toNullString(network.Realname)),
+ sql.Named("pass", toNullString(network.Pass)),
+ sql.Named("connect_commands", toNullString(strings.Join(network.ConnectCommands, "\r\n"))),
+ sql.Named("sasl_mechanism", saslMechanism),
+ sql.Named("sasl_plain_username", saslPlainUsername),
+ sql.Named("sasl_plain_password", saslPlainPassword),
+ sql.Named("sasl_external_cert", network.SASL.External.CertBlob),
+ sql.Named("sasl_external_key", network.SASL.External.PrivKeyBlob),
+ sql.Named("enabled", network.Enabled),
+
+ sql.Named("id", network.ID), // only for UPDATE
+ sql.Named("user", userID), // only for INSERT
+ }
+
+ var err error
+ if network.ID != 0 {
+ _, err = db.db.ExecContext(ctx, `
+ UPDATE Network
+ SET name = :name, addr = :addr, nick = :nick, username = :username,
+ realname = :realname, pass = :pass, connect_commands = :connect_commands,
+ sasl_mechanism = :sasl_mechanism, sasl_plain_username = :sasl_plain_username, sasl_plain_password = :sasl_plain_password,
+ sasl_external_cert = :sasl_external_cert, sasl_external_key = :sasl_external_key,
+ enabled = :enabled
+ WHERE id = :id`, args...)
+ } else {
+ var res sql.Result
+ res, err = db.db.ExecContext(ctx, `
+ INSERT INTO Network(user, name, addr, nick, username, realname, pass,
+ connect_commands, sasl_mechanism, sasl_plain_username,
+ sasl_plain_password, sasl_external_cert, sasl_external_key, enabled)
+ VALUES (:user, :name, :addr, :nick, :username, :realname, :pass,
+ :connect_commands, :sasl_mechanism, :sasl_plain_username,
+ :sasl_plain_password, :sasl_external_cert, :sasl_external_key, :enabled)`,
+ args...)
+ if err != nil {
+ return err
+ }
+ network.ID, err = res.LastInsertId()
+ }
+ return err
+}
+
+func (db *SqliteDB) DeleteNetwork(ctx context.Context, id int64) error {
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ tx, err := db.db.Begin()
+ if err != nil {
+ return err
+ }
+ defer tx.Rollback()
+
+ _, err = tx.ExecContext(ctx, "DELETE FROM DeliveryReceipt WHERE network = ?", id)
+ if err != nil {
+ return err
+ }
+
+ _, err = tx.ExecContext(ctx, "DELETE FROM ReadReceipt WHERE network = ?", id)
+ if err != nil {
+ return err
+ }
+
+ _, err = tx.ExecContext(ctx, "DELETE FROM Channel WHERE network = ?", id)
+ if err != nil {
+ return err
+ }
+
+ _, err = tx.ExecContext(ctx, "DELETE FROM Network WHERE id = ?", id)
+ if err != nil {
+ return err
+ }
+
+ return tx.Commit()
+}
+
+func (db *SqliteDB) ListChannels(ctx context.Context, networkID int64) ([]Channel, error) {
+ db.lock.RLock()
+ defer db.lock.RUnlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ rows, err := db.db.QueryContext(ctx, `SELECT
+ id, name, key, detached, detached_internal_msgid,
+ relay_detached, reattach_on, detach_after, detach_on
+ FROM Channel
+ WHERE network = ?`, networkID)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ var channels []Channel
+ for rows.Next() {
+ var ch Channel
+ var key, detachedInternalMsgID sql.NullString
+ var detachAfter int64
+ if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached, &detachedInternalMsgID, &ch.RelayDetached, &ch.ReattachOn, &detachAfter, &ch.DetachOn); err != nil {
+ return nil, err
+ }
+ ch.Key = key.String
+ ch.DetachedInternalMsgID = detachedInternalMsgID.String
+ ch.DetachAfter = time.Duration(detachAfter) * time.Second
+ channels = append(channels, ch)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+
+ return channels, nil
+}
+
+func (db *SqliteDB) StoreChannel(ctx context.Context, networkID int64, ch *Channel) error {
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ args := []interface{}{
+ sql.Named("network", networkID),
+ sql.Named("name", ch.Name),
+ sql.Named("key", toNullString(ch.Key)),
+ sql.Named("detached", ch.Detached),
+ sql.Named("detached_internal_msgid", toNullString(ch.DetachedInternalMsgID)),
+ sql.Named("relay_detached", ch.RelayDetached),
+ sql.Named("reattach_on", ch.ReattachOn),
+ sql.Named("detach_after", int64(math.Ceil(ch.DetachAfter.Seconds()))),
+ sql.Named("detach_on", ch.DetachOn),
+
+ sql.Named("id", ch.ID), // only for UPDATE
+ }
+
+ var err error
+ if ch.ID != 0 {
+ _, err = db.db.ExecContext(ctx, `UPDATE Channel
+ SET network = :network, name = :name, key = :key, detached = :detached,
+ detached_internal_msgid = :detached_internal_msgid, relay_detached = :relay_detached,
+ reattach_on = :reattach_on, detach_after = :detach_after, detach_on = :detach_on
+ WHERE id = :id`, args...)
+ } else {
+ var res sql.Result
+ res, err = db.db.ExecContext(ctx, `INSERT INTO Channel(network, name, key, detached, detached_internal_msgid, relay_detached, reattach_on, detach_after, detach_on)
+ VALUES (:network, :name, :key, :detached, :detached_internal_msgid, :relay_detached, :reattach_on, :detach_after, :detach_on)`, args...)
+ if err != nil {
+ return err
+ }
+ ch.ID, err = res.LastInsertId()
+ }
+ return err
+}
+
+func (db *SqliteDB) DeleteChannel(ctx context.Context, id int64) error {
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ _, err := db.db.ExecContext(ctx, "DELETE FROM Channel WHERE id = ?", id)
+ return err
+}
+
+func (db *SqliteDB) ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error) {
+ db.lock.RLock()
+ defer db.lock.RUnlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ rows, err := db.db.QueryContext(ctx, `
+ SELECT id, target, client, internal_msgid
+ FROM DeliveryReceipt
+ WHERE network = ?`, networkID)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ var receipts []DeliveryReceipt
+ for rows.Next() {
+ var rcpt DeliveryReceipt
+ var client sql.NullString
+ if err := rows.Scan(&rcpt.ID, &rcpt.Target, &client, &rcpt.InternalMsgID); err != nil {
+ return nil, err
+ }
+ rcpt.Client = client.String
+ receipts = append(receipts, rcpt)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+
+ return receipts, nil
+}
+
+func (db *SqliteDB) StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error {
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ tx, err := db.db.Begin()
+ if err != nil {
+ return err
+ }
+ defer tx.Rollback()
+
+ _, err = tx.ExecContext(ctx, "DELETE FROM DeliveryReceipt WHERE network = ? AND client IS ?",
+ networkID, toNullString(client))
+ if err != nil {
+ return err
+ }
+
+ for i := range receipts {
+ rcpt := &receipts[i]
+
+ res, err := tx.ExecContext(ctx, `
+ INSERT INTO DeliveryReceipt(network, target, client, internal_msgid)
+ VALUES (:network, :target, :client, :internal_msgid)`,
+ sql.Named("network", networkID),
+ sql.Named("target", rcpt.Target),
+ sql.Named("client", toNullString(client)),
+ sql.Named("internal_msgid", rcpt.InternalMsgID))
+ if err != nil {
+ return err
+ }
+ rcpt.ID, err = res.LastInsertId()
+ if err != nil {
+ return err
+ }
+ }
+
+ return tx.Commit()
+}
+
+func (db *SqliteDB) GetReadReceipt(ctx context.Context, networkID int64, name string) (*ReadReceipt, error) {
+ db.lock.RLock()
+ defer db.lock.RUnlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ receipt := &ReadReceipt{
+ Target: name,
+ }
+
+ row := db.db.QueryRowContext(ctx, `
+ SELECT id, timestamp FROM ReadReceipt WHERE network = :network AND target = :target`,
+ sql.Named("network", networkID),
+ sql.Named("target", name),
+ )
+ var timestamp string
+ if err := row.Scan(&receipt.ID, ×tamp); err != nil {
+ if err == sql.ErrNoRows {
+ return nil, nil
+ }
+ return nil, err
+ }
+ if t, err := time.Parse(serverTimeLayout, timestamp); err != nil {
+ return nil, err
+ } else {
+ receipt.Timestamp = t
+ }
+ return receipt, nil
+}
+
+func (db *SqliteDB) StoreReadReceipt(ctx context.Context, networkID int64, receipt *ReadReceipt) error {
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ args := []interface{}{
+ sql.Named("id", receipt.ID),
+ sql.Named("timestamp", formatServerTime(receipt.Timestamp)),
+ sql.Named("network", networkID),
+ sql.Named("target", receipt.Target),
+ }
+
+ var err error
+ if receipt.ID != 0 {
+ _, err = db.db.ExecContext(ctx, `
+ UPDATE ReadReceipt SET timestamp = :timestamp WHERE id = :id`,
+ args...)
+ } else {
+ var res sql.Result
+ res, err = db.db.ExecContext(ctx, `
+ INSERT INTO
+ ReadReceipt(network, target, timestamp)
+ VALUES (:network, :target, :timestamp)`,
+ args...)
+ if err != nil {
+ return err
+ }
+ receipt.ID, err = res.LastInsertId()
+ }
+
+ return err
+}
--- /dev/null
+package suika
+
+import (
+ "database/sql"
+ "testing"
+)
+
+// SQLite version 0 schema. DO NOT EDIT.
+const sqliteV0Schema = `
+CREATE TABLE User (
+ username VARCHAR(255) NOT NULL UNIQUE,
+ password VARCHAR(255)
+);
+
+CREATE TABLE Network (
+ id INTEGER PRIMARY KEY,
+ name VARCHAR(255),
+ user VARCHAR(255) NOT NULL,
+ addr VARCHAR(255) NOT NULL,
+ nick VARCHAR(255) NOT NULL,
+ username VARCHAR(255),
+ realname VARCHAR(255),
+ pass VARCHAR(255),
+ sasl_mechanism VARCHAR(255),
+ sasl_plain_username VARCHAR(255),
+ sasl_plain_password VARCHAR(255),
+ UNIQUE(user, addr, nick),
+ UNIQUE(user, name)
+);
+
+CREATE TABLE Channel (
+ id INTEGER PRIMARY KEY,
+ network INTEGER NOT NULL,
+ name VARCHAR(255) NOT NULL,
+ key VARCHAR(255),
+ FOREIGN KEY(network) REFERENCES Network(id),
+ UNIQUE(network, name)
+);
+
+PRAGMA user_version = 1;
+`
+
+func TestSqliteMigrations(t *testing.T) {
+ sqlDB, err := sql.Open("sqlite", ":memory:")
+ if err != nil {
+ t.Fatalf("failed to create temporary SQLite database: %v", err)
+ }
+
+ if _, err := sqlDB.Exec(sqliteV0Schema); err != nil {
+ t.Fatalf("DB.Exec() failed for v0 schema: %v", err)
+ }
+
+ db := &SqliteDB{db: sqlDB}
+ defer db.Close()
+
+ if err := db.upgrade(); err != nil {
+ t.Fatalf("SqliteDB.Upgrade() failed: %v", err)
+ }
+}
--- /dev/null
+// Package suika is a hard-fork of the 0.3 series of soju, an user-friendly IRC bouncer in Go.
+//
+// # Copyright (C) 2020 The soju Contributors
+// # Copyright (C) 2023-present Izuru Yakumo et al.
+//
+// suika is covered by the AGPLv3 license:
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published
+// by the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <https://www.gnu.org/licenses/>.
+package suika
--- /dev/null
+.Dd $Mdocdate$
+.Dt SUIKA-CONFIG 5
+.Os
+.Sh NAME
+.Nm suika-config
+.Nd Configuration file for the IRC bouncer
+.Sh SYNOPSIS
+.Bk -words
+listen ircs://
+.Pp
+tls cert.pem key.pem
+.Pp
+hostname example.org
+.Ek
+.Sh DESCRIPTION
+This document describes the format of the configuration
+file used by
+.Xr suika 1
+.Sh OPTIONS
+.Bl -tag -width Ds
+.It listen Ar uri
+With this you can control on what
+ports/protocols
+.Xr suika 1
+listens on, it supports
+irc (cleartext IRC), ircs (IRC with TLS), and unix
+(IRC over Unix domain sockets)
+.It hostname Ar hostname
+Server hostname, if unset, the system one is used.
+.It title Ar title
+Server title, this will be sent as the ISUPPORT NETWORK value when
+clients don't select a specific network.
+.It tls Ar cert Ar key
+Enable TLS support, the certificate and key files must be
+PEM-encoded.
+.It db Ar driver Ar path
+Set the database driver for user, network and channel storage.
+By default a SQLite 3 database is opened in
+.Pa ./suika.db
+Supported drivers are sqlite and postgres, the former
+expects a path to the database file, and the latter
+a space-separated list of key=value parameters,
+e.g. host=localhost dbname=suika
+.It log fs Ar path
+Path to the bouncer logs directory, or empty to disable
+logging.
+By default, logging is disabled.
+.It max-user-networks Ar limit
+Maximum number of networks per user, by default
+there is no limit.
+.It motd Ar path
+Path to the MOTD file, its contents are sent to clients
+which aren't bound to a particular network.
+By default, no MOTD is sent.
+.It multi-upstream-mode Ar bool
+Globally enable or disable multi-upstream mode.
+By default, it is enabled.
+.It upstream-user-ip Ar cidr
+Enable per-user-IP addresses.
+One IPv4 and/or one IPv6 range can be specified in CIDR notation.
+One IP address per range will be assigned to each user as the
+source address when connecting to an upstream network.
+This can be useful to avoid having the whole bouncer banned from
+an upstream network because of one malicious user.
+.El
+.Sh AUTHORS
+.An Simon Ser Aq Mt contact@emersion.fr
+.An The soju Contributors
+.Sh MAINTAINERS
+.An Izuru Yakumo Aq Mt yakumo.izuru@chaotic.ninja
--- /dev/null
+.Dd $Mdocdate$
+.Dt SUIKA-ZNC-IMPORT 1
+.Os
+.Sh NAME
+.Nm suika-znc-import
+.Nd Migration utility for moving from ZNC
+.Sh SYNOPSIS
+.Nm
+.Op Fl config Ar suika config file
+.Op Fl user Ar username
+.Op Fl network Ar name
+.Sh DESCRIPTION
+Imports configuration from a ZNC file.
+Users and networks are merged if they already exist in the
+.Xr suika 1
+database.
+ZNC settings overwrite existing
+.Xr suika 1
+settings
+.Sh OPTIONS
+.Bl -tag -width Ds
+.It config Ar suika config file
+Path to
+.Xr suika-config 5
+.It user Ar username
+Limit import to username, may be specified multiple times.
+It network Ar name
+Limit import to network, may be specified multiple times.
+.El
+.Sh AUTHORS
+.An Simon Ser Aq Mt contact@emersion.fr
+.An The soju Contributors
+.Sh MAINTAINERS
+.An Izuru Yakumo Aq Mt yakumo.izuru@chaotic.ninja
--- /dev/null
+.Dd $Mdocdate$
+.Dt SUIKA 1
+.Os
+.Sh NAME
+.Nm suika
+.Nd A drunk as hell IRC bouncer, named after Suika Ibuki from Touhou Project
+.Sh SYNOPSIS
+.Nm
+.Op Fl config Ar path
+.Op Fl debug
+.Op Fl listen Ar uri
+.Sh DESCRIPTION
+.Nm
+is an user-friendly IRC bouncer, it connects to upstream
+IRC servers on behalf of the user to provide extra features.
+.Bl -tag -width 6n
+.It Multiple separate users sharing the same bouncer
+.It Clients connecting to multiple upstream servers (via a single connection)
+.It Sending the backlog with per-client buffers
+.El
+.Pp
+When joining a channel, the channel will be saved
+and automatically joined on the next connection.
+When registering or authenticating with NickServ, the credentials will be saved
+and automatically used on the next connection if the server supports SASL.
+When parting a channel with the reason "detach", the channel will be
+detached instead of being left.
+When all clients are disconnected from the bouncer,
+the user is automatically marked as away.
+.Pp
+.Nm
+supports two connection modes:
+.Bl -tag -width 6n
+.It Single upstream mode
+One downstream connection maps to one upstream connection
+.Pp
+To enable this mode, connect to the bouncer
+with the username "<username>/<network>".
+.Pp
+If the bouncer isn't connected to the upstream server,
+it will get automatically added.
+.Pp
+Then channels can be joined and parted as if
+you were directly connected to the upstream server.
+.It Multiple upstream mode
+One downstream connection maps to multiple upstream connections.
+Channels and nicks are suffixed with the network name.
+To join a channel, you need to use the suffix too: /join #channel/network.
+Same applies to messages sent to users.
+.El
+.Pp
+For per-client history to work, clients need to indicate their name.
+This can be done by adding a "@<client>" suffix to the username.
+.Pp
+.Nm
+will reload the configuration file, the TLS certificate/key and
+the MOTD file when it receives the HUP signal.
+The configuration options listen, db and log cannot be reloaded.
+.Pp
+Administrators can broadcast a message to all bouncer users via
+/notice $<hostname> <text>, or via /notice $<text> in multi-upstream mode.
+All currently connected bouncer users will receive the message
+from the special BouncerServ service.
+.Sh AUTHORS
+.An Simon Ser Aq Mt contact@emersion.fr
+.An The soju Contributors
+.Sh MAINTAINERS
+.An Izuru Yakumo Aq Mt yakumo.izuru@chaotic.ninja
--- /dev/null
+.Dd $Mdocdate$
+.Dt SUIKADB 1
+.Os
+.Sh NAME
+.Nm suikadb
+.Nd Basic user manipulation for
+.Xr suika 1
+.Sh SYNOPSIS
+.Nm
+.Op create-user
+.Op change-password
+.Sh AUTHORS
+.An Simon Ser Aq Mt contact@emersion.fr
+.An The soju Contributors
+.Sh MAINTAINERS
+.An Izuru Yakumo Aq Mt yakumo.izuru@chaotic.ninja
--- /dev/null
+package suika
+
+import (
+ "bytes"
+ "context"
+ "crypto/tls"
+ "encoding/base64"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/emersion/go-sasl"
+ "golang.org/x/crypto/bcrypt"
+ "gopkg.in/irc.v3"
+)
+
+type ircError struct {
+ Message *irc.Message
+}
+
+func (err ircError) Error() string {
+ return err.Message.String()
+}
+
+func newUnknownCommandError(cmd string) ircError {
+ return ircError{&irc.Message{
+ Command: irc.ERR_UNKNOWNCOMMAND,
+ Params: []string{
+ "*",
+ cmd,
+ "Unknown command",
+ },
+ }}
+}
+
+func newNeedMoreParamsError(cmd string) ircError {
+ return ircError{&irc.Message{
+ Command: irc.ERR_NEEDMOREPARAMS,
+ Params: []string{
+ "*",
+ cmd,
+ "Not enough parameters",
+ },
+ }}
+}
+
+func newChatHistoryError(subcommand string, target string) ircError {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"CHATHISTORY", "MESSAGE_ERROR", subcommand, target, "Messages could not be retrieved"},
+ }}
+}
+
+// authError is an authentication error.
+type authError struct {
+ // Internal error cause. This will not be revealed to the user.
+ err error
+ // Error cause which can safely be sent to the user without compromising
+ // security.
+ reason string
+}
+
+func (err *authError) Error() string {
+ return err.err.Error()
+}
+
+func (err *authError) Unwrap() error {
+ return err.err
+}
+
+// authErrorReason returns the user-friendly reason of an authentication
+// failure.
+func authErrorReason(err error) string {
+ if authErr, ok := err.(*authError); ok {
+ return authErr.reason
+ } else {
+ return "Authentication failed"
+ }
+}
+
+func newInvalidUsernameOrPasswordError(err error) error {
+ return &authError{
+ err: err,
+ reason: "Invalid username or password",
+ }
+}
+
+func parseBouncerNetID(subcommand, s string) (int64, error) {
+ id, err := strconv.ParseInt(s, 10, 64)
+ if err != nil {
+ return 0, ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"BOUNCER", "INVALID_NETID", subcommand, s, "Invalid network ID"},
+ }}
+ }
+ return id, nil
+}
+
+func fillNetworkAddrAttrs(attrs irc.Tags, network *Network) {
+ u, err := network.URL()
+ if err != nil {
+ return
+ }
+
+ hasHostPort := true
+ switch u.Scheme {
+ case "ircs":
+ attrs["tls"] = irc.TagValue("1")
+ case "irc":
+ attrs["tls"] = irc.TagValue("0")
+ default: // e.g. unix://
+ hasHostPort = false
+ }
+ if host, port, err := net.SplitHostPort(u.Host); err == nil && hasHostPort {
+ attrs["host"] = irc.TagValue(host)
+ attrs["port"] = irc.TagValue(port)
+ } else if hasHostPort {
+ attrs["host"] = irc.TagValue(u.Host)
+ }
+}
+
+func getNetworkAttrs(network *network) irc.Tags {
+ state := "disconnected"
+ if uc := network.conn; uc != nil {
+ state = "connected"
+ }
+
+ attrs := irc.Tags{
+ "name": irc.TagValue(network.GetName()),
+ "state": irc.TagValue(state),
+ "nickname": irc.TagValue(GetNick(&network.user.User, &network.Network)),
+ }
+
+ if network.Username != "" {
+ attrs["username"] = irc.TagValue(network.Username)
+ }
+ if realname := GetRealname(&network.user.User, &network.Network); realname != "" {
+ attrs["realname"] = irc.TagValue(realname)
+ }
+
+ fillNetworkAddrAttrs(attrs, &network.Network)
+
+ return attrs
+}
+
+func networkAddrFromAttrs(attrs irc.Tags) string {
+ host, ok := attrs.GetTag("host")
+ if !ok {
+ return ""
+ }
+
+ addr := host
+ if port, ok := attrs.GetTag("port"); ok {
+ addr += ":" + port
+ }
+
+ if tlsStr, ok := attrs.GetTag("tls"); ok && tlsStr == "0" {
+ addr = "irc://" + tlsStr
+ }
+
+ return addr
+}
+
+func updateNetworkAttrs(record *Network, attrs irc.Tags, subcommand string) error {
+ addrAttrs := irc.Tags{}
+ fillNetworkAddrAttrs(addrAttrs, record)
+
+ updateAddr := false
+ for k, v := range attrs {
+ s := string(v)
+ switch k {
+ case "host", "port", "tls":
+ updateAddr = true
+ addrAttrs[k] = v
+ case "name":
+ record.Name = s
+ case "nickname":
+ record.Nick = s
+ case "username":
+ record.Username = s
+ case "realname":
+ record.Realname = s
+ case "pass":
+ record.Pass = s
+ default:
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"BOUNCER", "UNKNOWN_ATTRIBUTE", subcommand, k, "Unknown attribute"},
+ }}
+ }
+ }
+
+ if updateAddr {
+ record.Addr = networkAddrFromAttrs(addrAttrs)
+ if record.Addr == "" {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"BOUNCER", "NEED_ATTRIBUTE", subcommand, "host", "Missing required host attribute"},
+ }}
+ }
+ }
+
+ return nil
+}
+
+// illegalNickChars is the list of characters forbidden in a nickname.
+//
+// ' ' and ':' break the IRC message wire format
+// '@' and '!' break prefixes
+// '*' breaks masks and is the reserved nickname for registration
+// '?' breaks masks
+// '$' breaks server masks in PRIVMSG/NOTICE
+// ',' breaks lists
+// '.' is reserved for server names
+const illegalNickChars = " :@!*?$,."
+
+// permanentDownstreamCaps is the list of always-supported downstream
+// capabilities.
+var permanentDownstreamCaps = map[string]string{
+ "batch": "",
+ "cap-notify": "",
+ "echo-message": "",
+ "invite-notify": "",
+ "message-tags": "",
+ "server-time": "",
+ "setname": "",
+
+ "soju.im/bouncer-networks": "",
+ "soju.im/bouncer-networks-notify": "",
+ "soju.im/read": "",
+}
+
+// needAllDownstreamCaps is the list of downstream capabilities that
+// require support from all upstreams to be enabled
+var needAllDownstreamCaps = map[string]string{
+ "account-notify": "",
+ "account-tag": "",
+ "away-notify": "",
+ "extended-join": "",
+ "multi-prefix": "",
+
+ "draft/extended-monitor": "",
+}
+
+// passthroughIsupport is the set of ISUPPORT tokens that are directly passed
+// through from the upstream server to downstream clients.
+//
+// This is only effective in single-upstream mode.
+var passthroughIsupport = map[string]bool{
+ "AWAYLEN": true,
+ "BOT": true,
+ "CHANLIMIT": true,
+ "CHANMODES": true,
+ "CHANNELLEN": true,
+ "CHANTYPES": true,
+ "CLIENTTAGDENY": true,
+ "ELIST": true,
+ "EXCEPTS": true,
+ "EXTBAN": true,
+ "HOSTLEN": true,
+ "INVEX": true,
+ "KICKLEN": true,
+ "MAXLIST": true,
+ "MAXTARGETS": true,
+ "MODES": true,
+ "MONITOR": true,
+ "NAMELEN": true,
+ "NETWORK": true,
+ "NICKLEN": true,
+ "PREFIX": true,
+ "SAFELIST": true,
+ "TARGMAX": true,
+ "TOPICLEN": true,
+ "USERLEN": true,
+ "UTF8ONLY": true,
+ "WHOX": true,
+}
+
+type downstreamSASL struct {
+ server sasl.Server
+ plainUsername, plainPassword string
+ pendingResp bytes.Buffer
+}
+
+type downstreamConn struct {
+ conn
+
+ id uint64
+
+ registered bool
+ user *user
+ nick string
+ nickCM string
+ rawUsername string
+ networkName string
+ clientName string
+ realname string
+ hostname string
+ account string // RPL_LOGGEDIN/OUT state
+ password string // empty after authentication
+ network *network // can be nil
+ isMultiUpstream bool
+
+ negotiatingCaps bool
+ capVersion int
+ supportedCaps map[string]string
+ caps map[string]bool
+ sasl *downstreamSASL
+
+ lastBatchRef uint64
+
+ monitored casemapMap
+}
+
+func newDownstreamConn(srv *Server, ic ircConn, id uint64) *downstreamConn {
+ remoteAddr := ic.RemoteAddr().String()
+ logger := &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", remoteAddr)}
+ options := connOptions{Logger: logger}
+ dc := &downstreamConn{
+ conn: *newConn(srv, ic, &options),
+ id: id,
+ nick: "*",
+ nickCM: "*",
+ supportedCaps: make(map[string]string),
+ caps: make(map[string]bool),
+ monitored: newCasemapMap(0),
+ }
+ dc.hostname = remoteAddr
+ if host, _, err := net.SplitHostPort(dc.hostname); err == nil {
+ dc.hostname = host
+ }
+ for k, v := range permanentDownstreamCaps {
+ dc.supportedCaps[k] = v
+ }
+ dc.supportedCaps["sasl"] = "PLAIN"
+ // TODO: this is racy, we should only enable chathistory after
+ // authentication and then check that user.msgStore implements
+ // chatHistoryMessageStore
+ if srv.Config().LogPath != "" {
+ dc.supportedCaps["draft/chathistory"] = ""
+ }
+ return dc
+}
+
+func (dc *downstreamConn) prefix() *irc.Prefix {
+ return &irc.Prefix{
+ Name: dc.nick,
+ User: dc.user.Username,
+ Host: dc.hostname,
+ }
+}
+
+func (dc *downstreamConn) forEachNetwork(f func(*network)) {
+ if dc.network != nil {
+ f(dc.network)
+ } else if dc.isMultiUpstream {
+ for _, network := range dc.user.networks {
+ f(network)
+ }
+ }
+}
+
+func (dc *downstreamConn) forEachUpstream(f func(*upstreamConn)) {
+ if dc.network == nil && !dc.isMultiUpstream {
+ return
+ }
+ dc.user.forEachUpstream(func(uc *upstreamConn) {
+ if dc.network != nil && uc.network != dc.network {
+ return
+ }
+ f(uc)
+ })
+}
+
+// upstream returns the upstream connection, if any. If there are zero or if
+// there are multiple upstream connections, it returns nil.
+func (dc *downstreamConn) upstream() *upstreamConn {
+ if dc.network == nil {
+ return nil
+ }
+ return dc.network.conn
+}
+
+func isOurNick(net *network, nick string) bool {
+ // TODO: this doesn't account for nick changes
+ if net.conn != nil {
+ return net.casemap(nick) == net.conn.nickCM
+ }
+ // We're not currently connected to the upstream connection, so we don't
+ // know whether this name is our nickname. Best-effort: use the network's
+ // configured nickname and hope it was the one being used when we were
+ // connected.
+ return net.casemap(nick) == net.casemap(GetNick(&net.user.User, &net.Network))
+}
+
+// marshalEntity converts an upstream entity name (ie. channel or nick) into a
+// downstream entity name.
+//
+// This involves adding a "/<network>" suffix if the entity isn't the current
+// user.
+func (dc *downstreamConn) marshalEntity(net *network, name string) string {
+ if isOurNick(net, name) {
+ return dc.nick
+ }
+ name = partialCasemap(net.casemap, name)
+ if dc.network != nil {
+ if dc.network != net {
+ panic("suika: tried to marshal an entity for another network")
+ }
+ return name
+ }
+ return name + "/" + net.GetName()
+}
+
+func (dc *downstreamConn) marshalUserPrefix(net *network, prefix *irc.Prefix) *irc.Prefix {
+ if isOurNick(net, prefix.Name) {
+ return dc.prefix()
+ }
+ prefix.Name = partialCasemap(net.casemap, prefix.Name)
+ if dc.network != nil {
+ if dc.network != net {
+ panic("suika: tried to marshal a user prefix for another network")
+ }
+ return prefix
+ }
+ return &irc.Prefix{
+ Name: prefix.Name + "/" + net.GetName(),
+ User: prefix.User,
+ Host: prefix.Host,
+ }
+}
+
+// unmarshalEntityNetwork converts a downstream entity name (ie. channel or
+// nick) into an upstream entity name.
+//
+// This involves removing the "/<network>" suffix.
+func (dc *downstreamConn) unmarshalEntityNetwork(name string) (*network, string, error) {
+ if dc.network != nil {
+ return dc.network, name, nil
+ }
+ if !dc.isMultiUpstream {
+ return nil, "", ircError{&irc.Message{
+ Command: irc.ERR_NOSUCHCHANNEL,
+ Params: []string{dc.nick, name, "Cannot interact with channels and users on the bouncer connection. Did you mean to use a specific network?"},
+ }}
+ }
+
+ var net *network
+ if i := strings.LastIndexByte(name, '/'); i >= 0 {
+ network := name[i+1:]
+ name = name[:i]
+
+ for _, n := range dc.user.networks {
+ if network == n.GetName() {
+ net = n
+ break
+ }
+ }
+ }
+
+ if net == nil {
+ return nil, "", ircError{&irc.Message{
+ Command: irc.ERR_NOSUCHCHANNEL,
+ Params: []string{dc.nick, name, "Missing network suffix in name"},
+ }}
+ }
+
+ return net, name, nil
+}
+
+// unmarshalEntity is the same as unmarshalEntityNetwork, but returns the
+// upstream connection and fails if the upstream is disconnected.
+func (dc *downstreamConn) unmarshalEntity(name string) (*upstreamConn, string, error) {
+ net, name, err := dc.unmarshalEntityNetwork(name)
+ if err != nil {
+ return nil, "", err
+ }
+
+ if net.conn == nil {
+ return nil, "", ircError{&irc.Message{
+ Command: irc.ERR_NOSUCHCHANNEL,
+ Params: []string{dc.nick, name, "Disconnected from upstream network"},
+ }}
+ }
+
+ return net.conn, name, nil
+}
+
+func (dc *downstreamConn) unmarshalText(uc *upstreamConn, text string) string {
+ if dc.upstream() != nil {
+ return text
+ }
+ // TODO: smarter parsing that ignores URLs
+ return strings.ReplaceAll(text, "/"+uc.network.GetName(), "")
+}
+
+func (dc *downstreamConn) ReadMessage() (*irc.Message, error) {
+ msg, err := dc.conn.ReadMessage()
+ if err != nil {
+ return nil, err
+ }
+ return msg, nil
+}
+
+func (dc *downstreamConn) readMessages(ch chan<- event) error {
+ for {
+ msg, err := dc.ReadMessage()
+ if errors.Is(err, io.EOF) {
+ break
+ } else if err != nil {
+ return fmt.Errorf("failed to read IRC command: %v", err)
+ }
+
+ ch <- eventDownstreamMessage{msg, dc}
+ }
+
+ return nil
+}
+
+// SendMessage sends an outgoing message.
+//
+// This can only called from the user goroutine.
+func (dc *downstreamConn) SendMessage(msg *irc.Message) {
+ if !dc.caps["message-tags"] {
+ if msg.Command == "TAGMSG" {
+ return
+ }
+ msg = msg.Copy()
+ for name := range msg.Tags {
+ supported := false
+ switch name {
+ case "time":
+ supported = dc.caps["server-time"]
+ case "account":
+ supported = dc.caps["account"]
+ }
+ if !supported {
+ delete(msg.Tags, name)
+ }
+ }
+ }
+ if !dc.caps["batch"] && msg.Tags["batch"] != "" {
+ msg = msg.Copy()
+ delete(msg.Tags, "batch")
+ }
+ if msg.Command == "JOIN" && !dc.caps["extended-join"] {
+ msg.Params = msg.Params[:1]
+ }
+ if msg.Command == "SETNAME" && !dc.caps["setname"] {
+ return
+ }
+ if msg.Command == "AWAY" && !dc.caps["away-notify"] {
+ return
+ }
+ if msg.Command == "ACCOUNT" && !dc.caps["account-notify"] {
+ return
+ }
+ if msg.Command == "READ" && !dc.caps["soju.im/read"] {
+ return
+ }
+
+ dc.conn.SendMessage(context.TODO(), msg)
+}
+
+func (dc *downstreamConn) SendBatch(typ string, params []string, tags irc.Tags, f func(batchRef irc.TagValue)) {
+ dc.lastBatchRef++
+ ref := fmt.Sprintf("%v", dc.lastBatchRef)
+
+ if dc.caps["batch"] {
+ dc.SendMessage(&irc.Message{
+ Tags: tags,
+ Prefix: dc.srv.prefix(),
+ Command: "BATCH",
+ Params: append([]string{"+" + ref, typ}, params...),
+ })
+ }
+
+ f(irc.TagValue(ref))
+
+ if dc.caps["batch"] {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "BATCH",
+ Params: []string{"-" + ref},
+ })
+ }
+}
+
+// sendMessageWithID sends an outgoing message with the specified internal ID.
+func (dc *downstreamConn) sendMessageWithID(msg *irc.Message, id string) {
+ dc.SendMessage(msg)
+
+ if id == "" || !dc.messageSupportsBacklog(msg) {
+ return
+ }
+
+ dc.sendPing(id)
+}
+
+// advanceMessageWithID advances history to the specified message ID without
+// sending a message. This is useful e.g. for self-messages when echo-message
+// isn't enabled.
+func (dc *downstreamConn) advanceMessageWithID(msg *irc.Message, id string) {
+ if id == "" || !dc.messageSupportsBacklog(msg) {
+ return
+ }
+
+ dc.sendPing(id)
+}
+
+// ackMsgID acknowledges that a message has been received.
+func (dc *downstreamConn) ackMsgID(id string) {
+ netID, entity, err := parseMsgID(id, nil)
+ if err != nil {
+ dc.logger.Printf("failed to ACK message ID %q: %v", id, err)
+ return
+ }
+
+ network := dc.user.getNetworkByID(netID)
+ if network == nil {
+ return
+ }
+
+ network.delivered.StoreID(entity, dc.clientName, id)
+}
+
+func (dc *downstreamConn) sendPing(msgID string) {
+ token := "suika-msgid-" + msgID
+ dc.SendMessage(&irc.Message{
+ Command: "PING",
+ Params: []string{token},
+ })
+}
+
+func (dc *downstreamConn) handlePong(token string) {
+ if !strings.HasPrefix(token, "suika-msgid-") {
+ dc.logger.Printf("received unrecognized PONG token %q", token)
+ return
+ }
+ msgID := strings.TrimPrefix(token, "suika-msgid-")
+ dc.ackMsgID(msgID)
+}
+
+// marshalMessage re-formats a message coming from an upstream connection so
+// that it's suitable for being sent on this downstream connection. Only
+// messages that may appear in logs are supported, except MODE messages which
+// may only appear in single-upstream mode.
+func (dc *downstreamConn) marshalMessage(msg *irc.Message, net *network) *irc.Message {
+ msg = msg.Copy()
+ msg.Prefix = dc.marshalUserPrefix(net, msg.Prefix)
+
+ if dc.network != nil {
+ return msg
+ }
+
+ switch msg.Command {
+ case "PRIVMSG", "NOTICE", "TAGMSG":
+ msg.Params[0] = dc.marshalEntity(net, msg.Params[0])
+ case "NICK":
+ // Nick change for another user
+ msg.Params[0] = dc.marshalEntity(net, msg.Params[0])
+ case "JOIN", "PART":
+ msg.Params[0] = dc.marshalEntity(net, msg.Params[0])
+ case "KICK":
+ msg.Params[0] = dc.marshalEntity(net, msg.Params[0])
+ msg.Params[1] = dc.marshalEntity(net, msg.Params[1])
+ case "TOPIC":
+ msg.Params[0] = dc.marshalEntity(net, msg.Params[0])
+ case "QUIT", "SETNAME":
+ // This space is intentionally left blank
+ default:
+ panic(fmt.Sprintf("unexpected %q message", msg.Command))
+ }
+
+ return msg
+}
+
+func (dc *downstreamConn) handleMessage(ctx context.Context, msg *irc.Message) error {
+ ctx, cancel := dc.conn.NewContext(ctx)
+ defer cancel()
+
+ ctx, cancel = context.WithTimeout(ctx, handleDownstreamMessageTimeout)
+ defer cancel()
+
+ switch msg.Command {
+ case "QUIT":
+ return dc.Close()
+ default:
+ if dc.registered {
+ return dc.handleMessageRegistered(ctx, msg)
+ } else {
+ return dc.handleMessageUnregistered(ctx, msg)
+ }
+ }
+}
+
+func (dc *downstreamConn) handleMessageUnregistered(ctx context.Context, msg *irc.Message) error {
+ switch msg.Command {
+ case "NICK":
+ var nick string
+ if err := parseMessageParams(msg, &nick); err != nil {
+ return err
+ }
+ if nick == "" || strings.ContainsAny(nick, illegalNickChars) {
+ return ircError{&irc.Message{
+ Command: irc.ERR_ERRONEUSNICKNAME,
+ Params: []string{dc.nick, nick, "contains illegal characters"},
+ }}
+ }
+ nickCM := casemapASCII(nick)
+ if nickCM == serviceNickCM {
+ return ircError{&irc.Message{
+ Command: irc.ERR_NICKNAMEINUSE,
+ Params: []string{dc.nick, nick, "Nickname reserved for bouncer service"},
+ }}
+ }
+ dc.nick = nick
+ dc.nickCM = nickCM
+ case "USER":
+ if err := parseMessageParams(msg, &dc.rawUsername, nil, nil, &dc.realname); err != nil {
+ return err
+ }
+ case "PASS":
+ if err := parseMessageParams(msg, &dc.password); err != nil {
+ return err
+ }
+ case "CAP":
+ var subCmd string
+ if err := parseMessageParams(msg, &subCmd); err != nil {
+ return err
+ }
+ if err := dc.handleCapCommand(subCmd, msg.Params[1:]); err != nil {
+ return err
+ }
+ case "AUTHENTICATE":
+ credentials, err := dc.handleAuthenticateCommand(msg)
+ if err != nil {
+ return err
+ } else if credentials == nil {
+ break
+ }
+
+ if err := dc.authenticate(ctx, credentials.plainUsername, credentials.plainPassword); err != nil {
+ dc.logger.Printf("SASL authentication error for user %q: %v", credentials.plainUsername, err)
+ dc.endSASL(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_SASLFAIL,
+ Params: []string{dc.nick, authErrorReason(err)},
+ })
+ break
+ }
+
+ // Technically we should send RPL_LOGGEDIN here. However we use
+ // RPL_LOGGEDIN to mirror the upstream connection status. Let's
+ // see how many clients that breaks. See:
+ // https://github.com/ircv3/ircv3-specifications/pull/476
+ dc.endSASL(nil)
+ case "BOUNCER":
+ var subcommand string
+ if err := parseMessageParams(msg, &subcommand); err != nil {
+ return err
+ }
+
+ switch strings.ToUpper(subcommand) {
+ case "BIND":
+ var idStr string
+ if err := parseMessageParams(msg, nil, &idStr); err != nil {
+ return err
+ }
+
+ if dc.user == nil {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"BOUNCER", "ACCOUNT_REQUIRED", "BIND", "Authentication needed to bind to bouncer network"},
+ }}
+ }
+
+ id, err := parseBouncerNetID(subcommand, idStr)
+ if err != nil {
+ return err
+ }
+
+ var match *network
+ for _, net := range dc.user.networks {
+ if net.ID == id {
+ match = net
+ break
+ }
+ }
+ if match == nil {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"BOUNCER", "INVALID_NETID", idStr, "Unknown network ID"},
+ }}
+ }
+
+ dc.networkName = match.GetName()
+ }
+ default:
+ dc.logger.Printf("unhandled message: %v", msg)
+ return newUnknownCommandError(msg.Command)
+ }
+ if dc.rawUsername != "" && dc.nick != "*" && !dc.negotiatingCaps {
+ return dc.register(ctx)
+ }
+ return nil
+}
+
+func (dc *downstreamConn) handleCapCommand(cmd string, args []string) error {
+ cmd = strings.ToUpper(cmd)
+
+ switch cmd {
+ case "LS":
+ if len(args) > 0 {
+ var err error
+ if dc.capVersion, err = strconv.Atoi(args[0]); err != nil {
+ return err
+ }
+ }
+ if !dc.registered && dc.capVersion >= 302 {
+ // Let downstream show everything it supports, and trim
+ // down the available capabilities when upstreams are
+ // known.
+ for k, v := range needAllDownstreamCaps {
+ dc.supportedCaps[k] = v
+ }
+ }
+
+ caps := make([]string, 0, len(dc.supportedCaps))
+ for k, v := range dc.supportedCaps {
+ if dc.capVersion >= 302 && v != "" {
+ caps = append(caps, k+"="+v)
+ } else {
+ caps = append(caps, k)
+ }
+ }
+
+ // TODO: multi-line replies
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "CAP",
+ Params: []string{dc.nick, "LS", strings.Join(caps, " ")},
+ })
+
+ if dc.capVersion >= 302 {
+ // CAP version 302 implicitly enables cap-notify
+ dc.caps["cap-notify"] = true
+ }
+
+ if !dc.registered {
+ dc.negotiatingCaps = true
+ }
+ case "LIST":
+ var caps []string
+ for name, enabled := range dc.caps {
+ if enabled {
+ caps = append(caps, name)
+ }
+ }
+
+ // TODO: multi-line replies
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "CAP",
+ Params: []string{dc.nick, "LIST", strings.Join(caps, " ")},
+ })
+ case "REQ":
+ if len(args) == 0 {
+ return ircError{&irc.Message{
+ Command: err_invalidcapcmd,
+ Params: []string{dc.nick, cmd, "Missing argument in CAP REQ command"},
+ }}
+ }
+
+ // TODO: atomically ack/nak the whole capability set
+ caps := strings.Fields(args[0])
+ ack := true
+ for _, name := range caps {
+ name = strings.ToLower(name)
+ enable := !strings.HasPrefix(name, "-")
+ if !enable {
+ name = strings.TrimPrefix(name, "-")
+ }
+
+ if enable == dc.caps[name] {
+ continue
+ }
+
+ _, ok := dc.supportedCaps[name]
+ if !ok {
+ ack = false
+ break
+ }
+
+ if name == "cap-notify" && dc.capVersion >= 302 && !enable {
+ // cap-notify cannot be disabled with CAP version 302
+ ack = false
+ break
+ }
+
+ dc.caps[name] = enable
+ }
+
+ reply := "NAK"
+ if ack {
+ reply = "ACK"
+ }
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "CAP",
+ Params: []string{dc.nick, reply, args[0]},
+ })
+
+ if !dc.registered {
+ dc.negotiatingCaps = true
+ }
+ case "END":
+ dc.negotiatingCaps = false
+ default:
+ return ircError{&irc.Message{
+ Command: err_invalidcapcmd,
+ Params: []string{dc.nick, cmd, "Unknown CAP command"},
+ }}
+ }
+ return nil
+}
+
+func (dc *downstreamConn) handleAuthenticateCommand(msg *irc.Message) (result *downstreamSASL, err error) {
+ defer func() {
+ if err != nil {
+ dc.sasl = nil
+ }
+ }()
+
+ if !dc.caps["sasl"] {
+ return nil, ircError{&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_SASLFAIL,
+ Params: []string{dc.nick, "AUTHENTICATE requires the \"sasl\" capability to be enabled"},
+ }}
+ }
+ if len(msg.Params) == 0 {
+ return nil, ircError{&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_SASLFAIL,
+ Params: []string{dc.nick, "Missing AUTHENTICATE argument"},
+ }}
+ }
+ if msg.Params[0] == "*" {
+ return nil, ircError{&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_SASLABORTED,
+ Params: []string{dc.nick, "SASL authentication aborted"},
+ }}
+ }
+
+ var resp []byte
+ if dc.sasl == nil {
+ mech := strings.ToUpper(msg.Params[0])
+ var server sasl.Server
+ switch mech {
+ case "PLAIN":
+ server = sasl.NewPlainServer(sasl.PlainAuthenticator(func(identity, username, password string) error {
+ dc.sasl.plainUsername = username
+ dc.sasl.plainPassword = password
+ return nil
+ }))
+ default:
+ return nil, ircError{&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_SASLFAIL,
+ Params: []string{dc.nick, fmt.Sprintf("Unsupported SASL mechanism %q", mech)},
+ }}
+ }
+
+ dc.sasl = &downstreamSASL{server: server}
+ } else {
+ chunk := msg.Params[0]
+ if chunk == "+" {
+ chunk = ""
+ }
+
+ if dc.sasl.pendingResp.Len()+len(chunk) > 10*1024 {
+ return nil, ircError{&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_SASLFAIL,
+ Params: []string{dc.nick, "Response too long"},
+ }}
+ }
+
+ dc.sasl.pendingResp.WriteString(chunk)
+
+ if len(chunk) == maxSASLLength {
+ return nil, nil // Multi-line response, wait for the next command
+ }
+
+ resp, err = base64.StdEncoding.DecodeString(dc.sasl.pendingResp.String())
+ if err != nil {
+ return nil, ircError{&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_SASLFAIL,
+ Params: []string{dc.nick, "Invalid base64-encoded response"},
+ }}
+ }
+
+ dc.sasl.pendingResp.Reset()
+ }
+
+ challenge, done, err := dc.sasl.server.Next(resp)
+ if err != nil {
+ return nil, err
+ } else if done {
+ return dc.sasl, nil
+ } else {
+ challengeStr := "+"
+ if len(challenge) > 0 {
+ challengeStr = base64.StdEncoding.EncodeToString(challenge)
+ }
+
+ // TODO: multi-line messages
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "AUTHENTICATE",
+ Params: []string{challengeStr},
+ })
+ return nil, nil
+ }
+}
+
+func (dc *downstreamConn) endSASL(msg *irc.Message) {
+ if dc.sasl == nil {
+ return
+ }
+
+ dc.sasl = nil
+
+ if msg != nil {
+ dc.SendMessage(msg)
+ } else {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_SASLSUCCESS,
+ Params: []string{dc.nick, "SASL authentication successful"},
+ })
+ }
+}
+
+func (dc *downstreamConn) setSupportedCap(name, value string) {
+ prevValue, hasPrev := dc.supportedCaps[name]
+ changed := !hasPrev || prevValue != value
+ dc.supportedCaps[name] = value
+
+ if !dc.caps["cap-notify"] || !changed {
+ return
+ }
+
+ cap := name
+ if value != "" && dc.capVersion >= 302 {
+ cap = name + "=" + value
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "CAP",
+ Params: []string{dc.nick, "NEW", cap},
+ })
+}
+
+func (dc *downstreamConn) unsetSupportedCap(name string) {
+ _, hasPrev := dc.supportedCaps[name]
+ delete(dc.supportedCaps, name)
+ delete(dc.caps, name)
+
+ if !dc.caps["cap-notify"] || !hasPrev {
+ return
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "CAP",
+ Params: []string{dc.nick, "DEL", name},
+ })
+}
+
+func (dc *downstreamConn) updateSupportedCaps() {
+ supportedCaps := make(map[string]bool)
+ for cap := range needAllDownstreamCaps {
+ supportedCaps[cap] = true
+ }
+ dc.forEachUpstream(func(uc *upstreamConn) {
+ for cap, supported := range supportedCaps {
+ supportedCaps[cap] = supported && uc.caps[cap]
+ }
+ })
+
+ for cap, supported := range supportedCaps {
+ if supported {
+ dc.setSupportedCap(cap, needAllDownstreamCaps[cap])
+ } else {
+ dc.unsetSupportedCap(cap)
+ }
+ }
+
+ if uc := dc.upstream(); uc != nil && uc.supportsSASL("PLAIN") {
+ dc.setSupportedCap("sasl", "PLAIN")
+ } else if dc.network != nil {
+ dc.unsetSupportedCap("sasl")
+ }
+
+ if uc := dc.upstream(); uc != nil && uc.caps["draft/account-registration"] {
+ // Strip "before-connect", because we require downstreams to be fully
+ // connected before attempting account registration.
+ values := strings.Split(uc.supportedCaps["draft/account-registration"], ",")
+ for i, v := range values {
+ if v == "before-connect" {
+ values = append(values[:i], values[i+1:]...)
+ break
+ }
+ }
+ dc.setSupportedCap("draft/account-registration", strings.Join(values, ","))
+ } else {
+ dc.unsetSupportedCap("draft/account-registration")
+ }
+
+ if _, ok := dc.user.msgStore.(chatHistoryMessageStore); ok && dc.network != nil {
+ dc.setSupportedCap("draft/event-playback", "")
+ } else {
+ dc.unsetSupportedCap("draft/event-playback")
+ }
+}
+
+func (dc *downstreamConn) updateNick() {
+ if uc := dc.upstream(); uc != nil && uc.nick != dc.nick {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.prefix(),
+ Command: "NICK",
+ Params: []string{uc.nick},
+ })
+ dc.nick = uc.nick
+ dc.nickCM = casemapASCII(dc.nick)
+ }
+}
+
+func (dc *downstreamConn) updateRealname() {
+ if uc := dc.upstream(); uc != nil && uc.realname != dc.realname && dc.caps["setname"] {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.prefix(),
+ Command: "SETNAME",
+ Params: []string{uc.realname},
+ })
+ dc.realname = uc.realname
+ }
+}
+
+func (dc *downstreamConn) updateAccount() {
+ var account string
+ if dc.network == nil {
+ account = dc.user.Username
+ } else if uc := dc.upstream(); uc != nil {
+ account = uc.account
+ } else {
+ return
+ }
+
+ if dc.account == account || !dc.caps["sasl"] {
+ return
+ }
+
+ if account != "" {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_LOGGEDIN,
+ Params: []string{dc.nick, dc.prefix().String(), account, "You are logged in as " + account},
+ })
+ } else {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_LOGGEDOUT,
+ Params: []string{dc.nick, dc.prefix().String(), "You are logged out"},
+ })
+ }
+
+ dc.account = account
+}
+
+func sanityCheckServer(ctx context.Context, addr string) error {
+ ctx, cancel := context.WithTimeout(ctx, 15*time.Second)
+ defer cancel()
+
+ conn, err := new(tls.Dialer).DialContext(ctx, "tcp", addr)
+ if err != nil {
+ return err
+ }
+
+ return conn.Close()
+}
+
+func unmarshalUsername(rawUsername string) (username, client, network string) {
+ username = rawUsername
+
+ i := strings.IndexAny(username, "/@")
+ j := strings.LastIndexAny(username, "/@")
+ if i >= 0 {
+ username = rawUsername[:i]
+ }
+ if j >= 0 {
+ if rawUsername[j] == '@' {
+ client = rawUsername[j+1:]
+ } else {
+ network = rawUsername[j+1:]
+ }
+ }
+ if i >= 0 && j >= 0 && i < j {
+ if rawUsername[i] == '@' {
+ client = rawUsername[i+1 : j]
+ } else {
+ network = rawUsername[i+1 : j]
+ }
+ }
+
+ return username, client, network
+}
+
+func (dc *downstreamConn) authenticate(ctx context.Context, username, password string) error {
+ username, clientName, networkName := unmarshalUsername(username)
+
+ u, err := dc.srv.db.GetUser(ctx, username)
+ if err != nil {
+ return newInvalidUsernameOrPasswordError(fmt.Errorf("user not found: %w", err))
+ }
+
+ // Password auth disabled
+ if u.Password == "" {
+ return newInvalidUsernameOrPasswordError(fmt.Errorf("password auth disabled"))
+ }
+
+ err = bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(password))
+ if err != nil {
+ return newInvalidUsernameOrPasswordError(fmt.Errorf("wrong password"))
+ }
+
+ dc.user = dc.srv.getUser(username)
+ if dc.user == nil {
+ return fmt.Errorf("user not active")
+ }
+ dc.clientName = clientName
+ dc.networkName = networkName
+ return nil
+}
+
+func (dc *downstreamConn) register(ctx context.Context) error {
+ if dc.registered {
+ panic("tried to register twice")
+ }
+
+ if dc.sasl != nil {
+ dc.endSASL(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_SASLABORTED,
+ Params: []string{dc.nick, "SASL authentication aborted"},
+ })
+ }
+
+ password := dc.password
+ dc.password = ""
+ if dc.user == nil {
+ if password == "" {
+ if dc.caps["sasl"] {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"*", "ACCOUNT_REQUIRED", "Authentication required"},
+ }}
+ } else {
+ return ircError{&irc.Message{
+ Command: irc.ERR_PASSWDMISMATCH,
+ Params: []string{dc.nick, "Authentication required"},
+ }}
+ }
+ }
+
+ if err := dc.authenticate(ctx, dc.rawUsername, password); err != nil {
+ dc.logger.Printf("PASS authentication error for user %q: %v", dc.rawUsername, err)
+ return ircError{&irc.Message{
+ Command: irc.ERR_PASSWDMISMATCH,
+ Params: []string{dc.nick, authErrorReason(err)},
+ }}
+ }
+ }
+
+ _, fallbackClientName, fallbackNetworkName := unmarshalUsername(dc.rawUsername)
+ if dc.clientName == "" {
+ dc.clientName = fallbackClientName
+ } else if fallbackClientName != "" && dc.clientName != fallbackClientName {
+ return ircError{&irc.Message{
+ Command: irc.ERR_ERRONEUSNICKNAME,
+ Params: []string{dc.nick, "Client name mismatch in usernames"},
+ }}
+ }
+ if dc.networkName == "" {
+ dc.networkName = fallbackNetworkName
+ } else if fallbackNetworkName != "" && dc.networkName != fallbackNetworkName {
+ return ircError{&irc.Message{
+ Command: irc.ERR_ERRONEUSNICKNAME,
+ Params: []string{dc.nick, "Network name mismatch in usernames"},
+ }}
+ }
+
+ dc.registered = true
+ dc.logger.Printf("registration complete for user %q", dc.user.Username)
+ return nil
+}
+
+func (dc *downstreamConn) loadNetwork(ctx context.Context) error {
+ if dc.networkName == "" {
+ return nil
+ }
+
+ network := dc.user.getNetwork(dc.networkName)
+ if network == nil {
+ addr := dc.networkName
+ if !strings.ContainsRune(addr, ':') {
+ addr = addr + ":6697"
+ }
+
+ dc.logger.Printf("trying to connect to new network %q", addr)
+ if err := sanityCheckServer(ctx, addr); err != nil {
+ dc.logger.Printf("failed to connect to %q: %v", addr, err)
+ return ircError{&irc.Message{
+ Command: irc.ERR_PASSWDMISMATCH,
+ Params: []string{dc.nick, fmt.Sprintf("Failed to connect to %q", dc.networkName)},
+ }}
+ }
+
+ // Some clients only allow specifying the nickname (and use the
+ // nickname as a username too). Strip the network name from the
+ // nickname when auto-saving networks.
+ nick, _, _ := unmarshalUsername(dc.nick)
+
+ dc.logger.Printf("auto-saving network %q", dc.networkName)
+ var err error
+ network, err = dc.user.createNetwork(ctx, &Network{
+ Addr: dc.networkName,
+ Nick: nick,
+ Enabled: true,
+ })
+ if err != nil {
+ return err
+ }
+ }
+
+ dc.network = network
+ return nil
+}
+
+func (dc *downstreamConn) welcome(ctx context.Context) error {
+ if dc.user == nil || !dc.registered {
+ panic("tried to welcome an unregistered connection")
+ }
+
+ remoteAddr := dc.conn.RemoteAddr().String()
+ dc.logger = &prefixLogger{dc.srv.Logger, fmt.Sprintf("user %q: downstream %q: ", dc.user.Username, remoteAddr)}
+
+ // TODO: doing this might take some time. We should do it in dc.register
+ // instead, but we'll potentially be adding a new network and this must be
+ // done in the user goroutine.
+ if err := dc.loadNetwork(ctx); err != nil {
+ return err
+ }
+
+ if dc.network == nil && !dc.caps["soju.im/bouncer-networks"] && dc.srv.Config().MultiUpstream {
+ dc.isMultiUpstream = true
+ }
+
+ dc.updateSupportedCaps()
+
+ isupport := []string{
+ fmt.Sprintf("CHATHISTORY=%v", chatHistoryLimit),
+ "CASEMAPPING=ascii",
+ }
+
+ if dc.network != nil {
+ isupport = append(isupport, fmt.Sprintf("BOUNCER_NETID=%v", dc.network.ID))
+ }
+ if title := dc.srv.Config().Title; dc.network == nil && title != "" {
+ isupport = append(isupport, "NETWORK="+encodeISUPPORT(title))
+ }
+ if dc.network == nil && !dc.isMultiUpstream {
+ isupport = append(isupport, "WHOX")
+ }
+
+ if uc := dc.upstream(); uc != nil {
+ for k := range passthroughIsupport {
+ v, ok := uc.isupport[k]
+ if !ok {
+ continue
+ }
+ if v != nil {
+ isupport = append(isupport, fmt.Sprintf("%v=%v", k, *v))
+ } else {
+ isupport = append(isupport, k)
+ }
+ }
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_WELCOME,
+ Params: []string{dc.nick, "Welcome to suika, " + dc.nick},
+ })
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_YOURHOST,
+ Params: []string{dc.nick, "Your host is " + dc.srv.Config().Hostname},
+ })
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_MYINFO,
+ Params: []string{dc.nick, dc.srv.Config().Hostname, "suika", "aiwroO", "OovaimnqpsrtklbeI"},
+ })
+ for _, msg := range generateIsupport(dc.srv.prefix(), dc.nick, isupport) {
+ dc.SendMessage(msg)
+ }
+ if uc := dc.upstream(); uc != nil {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_UMODEIS,
+ Params: []string{dc.nick, "+" + string(uc.modes)},
+ })
+ }
+ if dc.network == nil && !dc.isMultiUpstream && dc.user.Admin {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_UMODEIS,
+ Params: []string{dc.nick, "+o"},
+ })
+ }
+
+ dc.updateNick()
+ dc.updateRealname()
+ dc.updateAccount()
+
+ if motd := dc.user.srv.Config().MOTD; motd != "" && dc.network == nil {
+ for _, msg := range generateMOTD(dc.srv.prefix(), dc.nick, motd) {
+ dc.SendMessage(msg)
+ }
+ } else {
+ motdHint := "No MOTD"
+ if dc.network != nil {
+ motdHint = "Use /motd to read the message of the day"
+ }
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_NOMOTD,
+ Params: []string{dc.nick, motdHint},
+ })
+ }
+
+ if dc.caps["soju.im/bouncer-networks-notify"] {
+ dc.SendBatch("soju.im/bouncer-networks", nil, nil, func(batchRef irc.TagValue) {
+ for _, network := range dc.user.networks {
+ idStr := fmt.Sprintf("%v", network.ID)
+ attrs := getNetworkAttrs(network)
+ dc.SendMessage(&irc.Message{
+ Tags: irc.Tags{"batch": batchRef},
+ Prefix: dc.srv.prefix(),
+ Command: "BOUNCER",
+ Params: []string{"NETWORK", idStr, attrs.String()},
+ })
+ }
+ })
+ }
+
+ dc.forEachUpstream(func(uc *upstreamConn) {
+ for _, entry := range uc.channels.innerMap {
+ ch := entry.value.(*upstreamChannel)
+ if !ch.complete {
+ continue
+ }
+ record := uc.network.channels.Value(ch.Name)
+ if record != nil && record.Detached {
+ continue
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.prefix(),
+ Command: "JOIN",
+ Params: []string{dc.marshalEntity(ch.conn.network, ch.Name)},
+ })
+
+ forwardChannel(ctx, dc, ch)
+ }
+ })
+
+ dc.forEachNetwork(func(net *network) {
+ if dc.caps["draft/chathistory"] || dc.user.msgStore == nil {
+ return
+ }
+
+ // Only send history if we're the first connected client with that name
+ // for the network
+ firstClient := true
+ dc.user.forEachDownstream(func(c *downstreamConn) {
+ if c != dc && c.clientName == dc.clientName && c.network == dc.network {
+ firstClient = false
+ }
+ })
+ if firstClient {
+ net.delivered.ForEachTarget(func(target string) {
+ lastDelivered := net.delivered.LoadID(target, dc.clientName)
+ if lastDelivered == "" {
+ return
+ }
+
+ dc.sendTargetBacklog(ctx, net, target, lastDelivered)
+
+ // Fast-forward history to last message
+ targetCM := net.casemap(target)
+ lastID, err := dc.user.msgStore.LastMsgID(&net.Network, targetCM, time.Now())
+ if err != nil {
+ dc.logger.Printf("failed to get last message ID: %v", err)
+ return
+ }
+ net.delivered.StoreID(target, dc.clientName, lastID)
+ })
+ }
+ })
+
+ return nil
+}
+
+// messageSupportsBacklog checks whether the provided message can be sent as
+// part of an history batch.
+func (dc *downstreamConn) messageSupportsBacklog(msg *irc.Message) bool {
+ // Don't replay all messages, because that would mess up client
+ // state. For instance we just sent the list of users, sending
+ // PART messages for one of these users would be incorrect.
+ switch msg.Command {
+ case "PRIVMSG", "NOTICE":
+ return true
+ }
+ return false
+}
+
+func (dc *downstreamConn) sendTargetBacklog(ctx context.Context, net *network, target, msgID string) {
+ if dc.caps["draft/chathistory"] || dc.user.msgStore == nil {
+ return
+ }
+
+ ch := net.channels.Value(target)
+
+ ctx, cancel := context.WithTimeout(ctx, backlogTimeout)
+ defer cancel()
+
+ targetCM := net.casemap(target)
+ history, err := dc.user.msgStore.LoadLatestID(ctx, &net.Network, targetCM, msgID, backlogLimit)
+ if err != nil {
+ dc.logger.Printf("failed to send backlog for %q: %v", target, err)
+ return
+ }
+
+ dc.SendBatch("chathistory", []string{dc.marshalEntity(net, target)}, nil, func(batchRef irc.TagValue) {
+ for _, msg := range history {
+ if ch != nil && ch.Detached {
+ if net.detachedMessageNeedsRelay(ch, msg) {
+ dc.relayDetachedMessage(net, msg)
+ }
+ } else {
+ msg.Tags["batch"] = batchRef
+ dc.SendMessage(dc.marshalMessage(msg, net))
+ }
+ }
+ })
+}
+
+func (dc *downstreamConn) relayDetachedMessage(net *network, msg *irc.Message) {
+ if msg.Command != "PRIVMSG" && msg.Command != "NOTICE" {
+ return
+ }
+
+ sender := msg.Prefix.Name
+ target, text := msg.Params[0], msg.Params[1]
+ if net.isHighlight(msg) {
+ sendServiceNOTICE(dc, fmt.Sprintf("highlight in %v: <%v> %v", dc.marshalEntity(net, target), sender, text))
+ } else {
+ sendServiceNOTICE(dc, fmt.Sprintf("message in %v: <%v> %v", dc.marshalEntity(net, target), sender, text))
+ }
+}
+
+func (dc *downstreamConn) runUntilRegistered() error {
+ ctx, cancel := context.WithTimeout(context.TODO(), downstreamRegisterTimeout)
+ defer cancel()
+
+ // Close the connection with an error if the deadline is exceeded
+ go func() {
+ <-ctx.Done()
+ if err := ctx.Err(); err == context.DeadlineExceeded {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "ERROR",
+ Params: []string{"Connection registration timed out"},
+ })
+ dc.Close()
+ }
+ }()
+
+ for !dc.registered {
+ msg, err := dc.ReadMessage()
+ if err != nil {
+ return fmt.Errorf("failed to read IRC command: %w", err)
+ }
+
+ err = dc.handleMessage(ctx, msg)
+ if ircErr, ok := err.(ircError); ok {
+ ircErr.Message.Prefix = dc.srv.prefix()
+ dc.SendMessage(ircErr.Message)
+ } else if err != nil {
+ return fmt.Errorf("failed to handle IRC command %q: %v", msg, err)
+ }
+ }
+
+ return nil
+}
+
+func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.Message) error {
+ switch msg.Command {
+ case "CAP":
+ var subCmd string
+ if err := parseMessageParams(msg, &subCmd); err != nil {
+ return err
+ }
+ if err := dc.handleCapCommand(subCmd, msg.Params[1:]); err != nil {
+ return err
+ }
+ case "PING":
+ var source, destination string
+ if err := parseMessageParams(msg, &source); err != nil {
+ return err
+ }
+ if len(msg.Params) > 1 {
+ destination = msg.Params[1]
+ }
+ hostname := dc.srv.Config().Hostname
+ if destination != "" && destination != hostname {
+ return ircError{&irc.Message{
+ Command: irc.ERR_NOSUCHSERVER,
+ Params: []string{dc.nick, destination, "No such server"},
+ }}
+ }
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "PONG",
+ Params: []string{hostname, source},
+ })
+ return nil
+ case "PONG":
+ if len(msg.Params) == 0 {
+ return newNeedMoreParamsError(msg.Command)
+ }
+ token := msg.Params[len(msg.Params)-1]
+ dc.handlePong(token)
+ case "USER":
+ return ircError{&irc.Message{
+ Command: irc.ERR_ALREADYREGISTERED,
+ Params: []string{dc.nick, "You may not reregister"},
+ }}
+ case "NICK":
+ var rawNick string
+ if err := parseMessageParams(msg, &rawNick); err != nil {
+ return err
+ }
+
+ nick := rawNick
+ var upstream *upstreamConn
+ if dc.upstream() == nil {
+ uc, unmarshaledNick, err := dc.unmarshalEntity(nick)
+ if err == nil { // NICK nick/network: NICK only on a specific upstream
+ upstream = uc
+ nick = unmarshaledNick
+ }
+ }
+
+ if nick == "" || strings.ContainsAny(nick, illegalNickChars) {
+ return ircError{&irc.Message{
+ Command: irc.ERR_ERRONEUSNICKNAME,
+ Params: []string{dc.nick, rawNick, "contains illegal characters"},
+ }}
+ }
+ if casemapASCII(nick) == serviceNickCM {
+ return ircError{&irc.Message{
+ Command: irc.ERR_NICKNAMEINUSE,
+ Params: []string{dc.nick, rawNick, "Nickname reserved for bouncer service"},
+ }}
+ }
+
+ var err error
+ dc.forEachNetwork(func(n *network) {
+ if err != nil || (upstream != nil && upstream.network != n) {
+ return
+ }
+ n.Nick = nick
+ err = dc.srv.db.StoreNetwork(ctx, dc.user.ID, &n.Network)
+ })
+ if err != nil {
+ return err
+ }
+
+ dc.forEachUpstream(func(uc *upstreamConn) {
+ if upstream != nil && upstream != uc {
+ return
+ }
+ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
+ Command: "NICK",
+ Params: []string{nick},
+ })
+ })
+
+ if dc.upstream() == nil && upstream == nil && dc.nick != nick {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.prefix(),
+ Command: "NICK",
+ Params: []string{nick},
+ })
+ dc.nick = nick
+ dc.nickCM = casemapASCII(dc.nick)
+ }
+ case "SETNAME":
+ var realname string
+ if err := parseMessageParams(msg, &realname); err != nil {
+ return err
+ }
+
+ // If the client just resets to the default, just wipe the per-network
+ // preference
+ storeRealname := realname
+ if realname == dc.user.Realname {
+ storeRealname = ""
+ }
+
+ var storeErr error
+ var needUpdate []Network
+ dc.forEachNetwork(func(n *network) {
+ // We only need to call updateNetwork for upstreams that don't
+ // support setname
+ if uc := n.conn; uc != nil && uc.caps["setname"] {
+ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
+ Command: "SETNAME",
+ Params: []string{realname},
+ })
+
+ n.Realname = storeRealname
+ if err := dc.srv.db.StoreNetwork(ctx, dc.user.ID, &n.Network); err != nil {
+ dc.logger.Printf("failed to store network realname: %v", err)
+ storeErr = err
+ }
+ return
+ }
+
+ record := n.Network // copy network record because we'll mutate it
+ record.Realname = storeRealname
+ needUpdate = append(needUpdate, record)
+ })
+
+ // Walk the network list as a second step, because updateNetwork
+ // mutates the original list
+ for _, record := range needUpdate {
+ if _, err := dc.user.updateNetwork(ctx, &record); err != nil {
+ dc.logger.Printf("failed to update network realname: %v", err)
+ storeErr = err
+ }
+ }
+ if storeErr != nil {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"SETNAME", "CANNOT_CHANGE_REALNAME", "Failed to update realname"},
+ }}
+ }
+
+ if dc.upstream() == nil {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.prefix(),
+ Command: "SETNAME",
+ Params: []string{realname},
+ })
+ }
+ case "JOIN":
+ var namesStr string
+ if err := parseMessageParams(msg, &namesStr); err != nil {
+ return err
+ }
+
+ var keys []string
+ if len(msg.Params) > 1 {
+ keys = strings.Split(msg.Params[1], ",")
+ }
+
+ for i, name := range strings.Split(namesStr, ",") {
+ uc, upstreamName, err := dc.unmarshalEntity(name)
+ if err != nil {
+ return err
+ }
+
+ var key string
+ if len(keys) > i {
+ key = keys[i]
+ }
+
+ if !uc.isChannel(upstreamName) {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_NOSUCHCHANNEL,
+ Params: []string{name, "Not a channel name"},
+ })
+ continue
+ }
+
+ // Most servers ignore duplicate JOIN messages. We ignore them here
+ // because some clients automatically send JOIN messages in bulk
+ // when reconnecting to the bouncer. We don't want to flood the
+ // upstream connection with these.
+ if !uc.channels.Has(upstreamName) {
+ params := []string{upstreamName}
+ if key != "" {
+ params = append(params, key)
+ }
+ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
+ Command: "JOIN",
+ Params: params,
+ })
+ }
+
+ ch := uc.network.channels.Value(upstreamName)
+ if ch != nil {
+ // Don't clear the channel key if there's one set
+ // TODO: add a way to unset the channel key
+ if key != "" {
+ ch.Key = key
+ }
+ uc.network.attach(ctx, ch)
+ } else {
+ ch = &Channel{
+ Name: upstreamName,
+ Key: key,
+ }
+ uc.network.channels.SetValue(upstreamName, ch)
+ }
+ if err := dc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil {
+ dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err)
+ }
+ }
+ case "PART":
+ var namesStr string
+ if err := parseMessageParams(msg, &namesStr); err != nil {
+ return err
+ }
+
+ var reason string
+ if len(msg.Params) > 1 {
+ reason = msg.Params[1]
+ }
+
+ for _, name := range strings.Split(namesStr, ",") {
+ uc, upstreamName, err := dc.unmarshalEntity(name)
+ if err != nil {
+ return err
+ }
+
+ if strings.EqualFold(reason, "detach") {
+ ch := uc.network.channels.Value(upstreamName)
+ if ch != nil {
+ uc.network.detach(ch)
+ } else {
+ ch = &Channel{
+ Name: name,
+ Detached: true,
+ }
+ uc.network.channels.SetValue(upstreamName, ch)
+ }
+ if err := dc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil {
+ dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err)
+ }
+ } else {
+ params := []string{upstreamName}
+ if reason != "" {
+ params = append(params, reason)
+ }
+ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
+ Command: "PART",
+ Params: params,
+ })
+
+ if err := uc.network.deleteChannel(ctx, upstreamName); err != nil {
+ dc.logger.Printf("failed to delete channel %q: %v", upstreamName, err)
+ }
+ }
+ }
+ case "KICK":
+ var channelStr, userStr string
+ if err := parseMessageParams(msg, &channelStr, &userStr); err != nil {
+ return err
+ }
+
+ channels := strings.Split(channelStr, ",")
+ users := strings.Split(userStr, ",")
+
+ var reason string
+ if len(msg.Params) > 2 {
+ reason = msg.Params[2]
+ }
+
+ if len(channels) != 1 && len(channels) != len(users) {
+ return ircError{&irc.Message{
+ Command: irc.ERR_BADCHANMASK,
+ Params: []string{dc.nick, channelStr, "Bad channel mask"},
+ }}
+ }
+
+ for i, user := range users {
+ var channel string
+ if len(channels) == 1 {
+ channel = channels[0]
+ } else {
+ channel = channels[i]
+ }
+
+ ucChannel, upstreamChannel, err := dc.unmarshalEntity(channel)
+ if err != nil {
+ return err
+ }
+
+ ucUser, upstreamUser, err := dc.unmarshalEntity(user)
+ if err != nil {
+ return err
+ }
+
+ if ucChannel != ucUser {
+ return ircError{&irc.Message{
+ Command: irc.ERR_USERNOTINCHANNEL,
+ Params: []string{dc.nick, user, channel, "They are on another network"},
+ }}
+ }
+ uc := ucChannel
+
+ params := []string{upstreamChannel, upstreamUser}
+ if reason != "" {
+ params = append(params, reason)
+ }
+ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
+ Command: "KICK",
+ Params: params,
+ })
+ }
+ case "MODE":
+ var name string
+ if err := parseMessageParams(msg, &name); err != nil {
+ return err
+ }
+
+ var modeStr string
+ if len(msg.Params) > 1 {
+ modeStr = msg.Params[1]
+ }
+
+ if casemapASCII(name) == dc.nickCM {
+ if modeStr != "" {
+ if uc := dc.upstream(); uc != nil {
+ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
+ Command: "MODE",
+ Params: []string{uc.nick, modeStr},
+ })
+ } else {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_UMODEUNKNOWNFLAG,
+ Params: []string{dc.nick, "Cannot change user mode in multi-upstream mode"},
+ })
+ }
+ } else {
+ var userMode string
+ if uc := dc.upstream(); uc != nil {
+ userMode = string(uc.modes)
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_UMODEIS,
+ Params: []string{dc.nick, "+" + userMode},
+ })
+ }
+ return nil
+ }
+
+ uc, upstreamName, err := dc.unmarshalEntity(name)
+ if err != nil {
+ return err
+ }
+
+ if !uc.isChannel(upstreamName) {
+ return ircError{&irc.Message{
+ Command: irc.ERR_USERSDONTMATCH,
+ Params: []string{dc.nick, "Cannot change mode for other users"},
+ }}
+ }
+
+ if modeStr != "" {
+ params := []string{upstreamName, modeStr}
+ params = append(params, msg.Params[2:]...)
+ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
+ Command: "MODE",
+ Params: params,
+ })
+ } else {
+ ch := uc.channels.Value(upstreamName)
+ if ch == nil {
+ return ircError{&irc.Message{
+ Command: irc.ERR_NOSUCHCHANNEL,
+ Params: []string{dc.nick, name, "No such channel"},
+ }}
+ }
+
+ if ch.modes == nil {
+ // we haven't received the initial RPL_CHANNELMODEIS yet
+ // ignore the request, we will broadcast the modes later when we receive RPL_CHANNELMODEIS
+ return nil
+ }
+
+ modeStr, modeParams := ch.modes.Format()
+ params := []string{dc.nick, name, modeStr}
+ params = append(params, modeParams...)
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_CHANNELMODEIS,
+ Params: params,
+ })
+ if ch.creationTime != "" {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: rpl_creationtime,
+ Params: []string{dc.nick, name, ch.creationTime},
+ })
+ }
+ }
+ case "TOPIC":
+ var channel string
+ if err := parseMessageParams(msg, &channel); err != nil {
+ return err
+ }
+
+ uc, upstreamName, err := dc.unmarshalEntity(channel)
+ if err != nil {
+ return err
+ }
+
+ if len(msg.Params) > 1 { // setting topic
+ topic := msg.Params[1]
+ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
+ Command: "TOPIC",
+ Params: []string{upstreamName, topic},
+ })
+ } else { // getting topic
+ ch := uc.channels.Value(upstreamName)
+ if ch == nil {
+ return ircError{&irc.Message{
+ Command: irc.ERR_NOSUCHCHANNEL,
+ Params: []string{dc.nick, upstreamName, "No such channel"},
+ }}
+ }
+ sendTopic(dc, ch)
+ }
+ case "LIST":
+ network := dc.network
+ if network == nil && len(msg.Params) > 0 {
+ var err error
+ network, msg.Params[0], err = dc.unmarshalEntityNetwork(msg.Params[0])
+ if err != nil {
+ return err
+ }
+ }
+ if network == nil {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_LISTEND,
+ Params: []string{dc.nick, "LIST without a network suffix is not supported in multi-upstream mode"},
+ })
+ return nil
+ }
+
+ uc := network.conn
+ if uc == nil {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_LISTEND,
+ Params: []string{dc.nick, "Disconnected from upstream server"},
+ })
+ return nil
+ }
+
+ uc.enqueueCommand(dc, msg)
+ case "NAMES":
+ if len(msg.Params) == 0 {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_ENDOFNAMES,
+ Params: []string{dc.nick, "*", "End of /NAMES list"},
+ })
+ return nil
+ }
+
+ channels := strings.Split(msg.Params[0], ",")
+ for _, channel := range channels {
+ uc, upstreamName, err := dc.unmarshalEntity(channel)
+ if err != nil {
+ return err
+ }
+
+ ch := uc.channels.Value(upstreamName)
+ if ch != nil {
+ sendNames(dc, ch)
+ } else {
+ // NAMES on a channel we have not joined, ask upstream
+ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
+ Command: "NAMES",
+ Params: []string{upstreamName},
+ })
+ }
+ }
+ // For WHOX docs, see:
+ // - http://faerion.sourceforge.net/doc/irc/whox.var
+ // - https://github.com/quakenet/snircd/blob/master/doc/readme.who
+ // Note, many features aren't widely implemented, such as flags and mask2
+ case "WHO":
+ if len(msg.Params) == 0 {
+ // TODO: support WHO without parameters
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_ENDOFWHO,
+ Params: []string{dc.nick, "*", "End of /WHO list"},
+ })
+ return nil
+ }
+
+ // Clients will use the first mask to match RPL_ENDOFWHO
+ endOfWhoToken := msg.Params[0]
+
+ // TODO: add support for WHOX mask2
+ mask := msg.Params[0]
+ var options string
+ if len(msg.Params) > 1 {
+ options = msg.Params[1]
+ }
+
+ optionsParts := strings.SplitN(options, "%", 2)
+ // TODO: add support for WHOX flags in optionsParts[0]
+ var fields, whoxToken string
+ if len(optionsParts) == 2 {
+ optionsParts := strings.SplitN(optionsParts[1], ",", 2)
+ fields = strings.ToLower(optionsParts[0])
+ if len(optionsParts) == 2 && strings.Contains(fields, "t") {
+ whoxToken = optionsParts[1]
+ }
+ }
+
+ // TODO: support mixed bouncer/upstream WHO queries
+ maskCM := casemapASCII(mask)
+ if dc.network == nil && maskCM == dc.nickCM {
+ // TODO: support AWAY (H/G) in self WHO reply
+ flags := "H"
+ if dc.user.Admin {
+ flags += "*"
+ }
+ info := whoxInfo{
+ Token: whoxToken,
+ Username: dc.user.Username,
+ Hostname: dc.hostname,
+ Server: dc.srv.Config().Hostname,
+ Nickname: dc.nick,
+ Flags: flags,
+ Account: dc.user.Username,
+ Realname: dc.realname,
+ }
+ dc.SendMessage(generateWHOXReply(dc.srv.prefix(), dc.nick, fields, &info))
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_ENDOFWHO,
+ Params: []string{dc.nick, endOfWhoToken, "End of /WHO list"},
+ })
+ return nil
+ }
+ if maskCM == serviceNickCM {
+ info := whoxInfo{
+ Token: whoxToken,
+ Username: servicePrefix.User,
+ Hostname: servicePrefix.Host,
+ Server: dc.srv.Config().Hostname,
+ Nickname: serviceNick,
+ Flags: "H*",
+ Account: serviceNick,
+ Realname: serviceRealname,
+ }
+ dc.SendMessage(generateWHOXReply(dc.srv.prefix(), dc.nick, fields, &info))
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_ENDOFWHO,
+ Params: []string{dc.nick, endOfWhoToken, "End of /WHO list"},
+ })
+ return nil
+ }
+
+ // TODO: properly support WHO masks
+ uc, upstreamMask, err := dc.unmarshalEntity(mask)
+ if err != nil {
+ return err
+ }
+
+ params := []string{upstreamMask}
+ if options != "" {
+ params = append(params, options)
+ }
+
+ uc.enqueueCommand(dc, &irc.Message{
+ Command: "WHO",
+ Params: params,
+ })
+ case "WHOIS":
+ if len(msg.Params) == 0 {
+ return ircError{&irc.Message{
+ Command: irc.ERR_NONICKNAMEGIVEN,
+ Params: []string{dc.nick, "No nickname given"},
+ }}
+ }
+
+ var target, mask string
+ if len(msg.Params) == 1 {
+ target = ""
+ mask = msg.Params[0]
+ } else {
+ target = msg.Params[0]
+ mask = msg.Params[1]
+ }
+ // TODO: support multiple WHOIS users
+ if i := strings.IndexByte(mask, ','); i >= 0 {
+ mask = mask[:i]
+ }
+
+ if dc.network == nil && casemapASCII(mask) == dc.nickCM {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_WHOISUSER,
+ Params: []string{dc.nick, dc.nick, dc.user.Username, dc.hostname, "*", dc.realname},
+ })
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_WHOISSERVER,
+ Params: []string{dc.nick, dc.nick, dc.srv.Config().Hostname, "suika"},
+ })
+ if dc.user.Admin {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_WHOISOPERATOR,
+ Params: []string{dc.nick, dc.nick, "is a bouncer administrator"},
+ })
+ }
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: rpl_whoisaccount,
+ Params: []string{dc.nick, dc.nick, dc.user.Username, "is logged in as"},
+ })
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_ENDOFWHOIS,
+ Params: []string{dc.nick, dc.nick, "End of /WHOIS list"},
+ })
+ return nil
+ }
+ if casemapASCII(mask) == serviceNickCM {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_WHOISUSER,
+ Params: []string{dc.nick, serviceNick, servicePrefix.User, servicePrefix.Host, "*", serviceRealname},
+ })
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_WHOISSERVER,
+ Params: []string{dc.nick, serviceNick, dc.srv.Config().Hostname, "suika"},
+ })
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_WHOISOPERATOR,
+ Params: []string{dc.nick, serviceNick, "is the bouncer service"},
+ })
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: rpl_whoisaccount,
+ Params: []string{dc.nick, serviceNick, serviceNick, "is logged in as"},
+ })
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_ENDOFWHOIS,
+ Params: []string{dc.nick, serviceNick, "End of /WHOIS list"},
+ })
+ return nil
+ }
+
+ // TODO: support WHOIS masks
+ uc, upstreamNick, err := dc.unmarshalEntity(mask)
+ if err != nil {
+ return err
+ }
+
+ var params []string
+ if target != "" {
+ if target == mask { // WHOIS nick nick
+ params = []string{upstreamNick, upstreamNick}
+ } else {
+ params = []string{target, upstreamNick}
+ }
+ } else {
+ params = []string{upstreamNick}
+ }
+
+ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
+ Command: "WHOIS",
+ Params: params,
+ })
+ case "PRIVMSG", "NOTICE":
+ var targetsStr, text string
+ if err := parseMessageParams(msg, &targetsStr, &text); err != nil {
+ return err
+ }
+ tags := copyClientTags(msg.Tags)
+
+ for _, name := range strings.Split(targetsStr, ",") {
+ if name == "$"+dc.srv.Config().Hostname || (name == "$*" && dc.network == nil) {
+ // "$" means a server mask follows. If it's the bouncer's
+ // hostname, broadcast the message to all bouncer users.
+ if !dc.user.Admin {
+ return ircError{&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_BADMASK,
+ Params: []string{dc.nick, name, "Permission denied to broadcast message to all bouncer users"},
+ }}
+ }
+
+ dc.logger.Printf("broadcasting bouncer-wide %v: %v", msg.Command, text)
+
+ broadcastTags := tags.Copy()
+ broadcastTags["time"] = irc.TagValue(formatServerTime(time.Now()))
+ broadcastMsg := &irc.Message{
+ Tags: broadcastTags,
+ Prefix: servicePrefix,
+ Command: msg.Command,
+ Params: []string{name, text},
+ }
+ dc.srv.forEachUser(func(u *user) {
+ u.events <- eventBroadcast{broadcastMsg}
+ })
+ continue
+ }
+
+ if dc.network == nil && casemapASCII(name) == dc.nickCM {
+ dc.SendMessage(&irc.Message{
+ Tags: msg.Tags.Copy(),
+ Prefix: dc.prefix(),
+ Command: msg.Command,
+ Params: []string{name, text},
+ })
+ continue
+ }
+
+ if msg.Command == "PRIVMSG" && casemapASCII(name) == serviceNickCM {
+ if dc.caps["echo-message"] {
+ echoTags := tags.Copy()
+ echoTags["time"] = irc.TagValue(formatServerTime(time.Now()))
+ dc.SendMessage(&irc.Message{
+ Tags: echoTags,
+ Prefix: dc.prefix(),
+ Command: msg.Command,
+ Params: []string{name, text},
+ })
+ }
+ handleServicePRIVMSG(ctx, dc, text)
+ continue
+ }
+
+ uc, upstreamName, err := dc.unmarshalEntity(name)
+ if err != nil {
+ return err
+ }
+
+ if msg.Command == "PRIVMSG" && uc.network.casemap(upstreamName) == "nickserv" {
+ dc.handleNickServPRIVMSG(ctx, uc, text)
+ }
+
+ unmarshaledText := text
+ if uc.isChannel(upstreamName) {
+ unmarshaledText = dc.unmarshalText(uc, text)
+ }
+ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
+ Tags: tags,
+ Command: msg.Command,
+ Params: []string{upstreamName, unmarshaledText},
+ })
+
+ echoTags := tags.Copy()
+ echoTags["time"] = irc.TagValue(formatServerTime(time.Now()))
+ if uc.account != "" {
+ echoTags["account"] = irc.TagValue(uc.account)
+ }
+ echoMsg := &irc.Message{
+ Tags: echoTags,
+ Prefix: &irc.Prefix{Name: uc.nick},
+ Command: msg.Command,
+ Params: []string{upstreamName, text},
+ }
+ uc.produce(upstreamName, echoMsg, dc)
+
+ uc.updateChannelAutoDetach(upstreamName)
+ }
+ case "TAGMSG":
+ var targetsStr string
+ if err := parseMessageParams(msg, &targetsStr); err != nil {
+ return err
+ }
+ tags := copyClientTags(msg.Tags)
+
+ for _, name := range strings.Split(targetsStr, ",") {
+ if dc.network == nil && casemapASCII(name) == dc.nickCM {
+ dc.SendMessage(&irc.Message{
+ Tags: msg.Tags.Copy(),
+ Prefix: dc.prefix(),
+ Command: "TAGMSG",
+ Params: []string{name},
+ })
+ continue
+ }
+
+ if casemapASCII(name) == serviceNickCM {
+ continue
+ }
+
+ uc, upstreamName, err := dc.unmarshalEntity(name)
+ if err != nil {
+ return err
+ }
+ if _, ok := uc.caps["message-tags"]; !ok {
+ continue
+ }
+
+ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
+ Tags: tags,
+ Command: "TAGMSG",
+ Params: []string{upstreamName},
+ })
+
+ echoTags := tags.Copy()
+ echoTags["time"] = irc.TagValue(formatServerTime(time.Now()))
+ if uc.account != "" {
+ echoTags["account"] = irc.TagValue(uc.account)
+ }
+ echoMsg := &irc.Message{
+ Tags: echoTags,
+ Prefix: &irc.Prefix{Name: uc.nick},
+ Command: "TAGMSG",
+ Params: []string{upstreamName},
+ }
+ uc.produce(upstreamName, echoMsg, dc)
+
+ uc.updateChannelAutoDetach(upstreamName)
+ }
+ case "INVITE":
+ var user, channel string
+ if err := parseMessageParams(msg, &user, &channel); err != nil {
+ return err
+ }
+
+ ucChannel, upstreamChannel, err := dc.unmarshalEntity(channel)
+ if err != nil {
+ return err
+ }
+
+ ucUser, upstreamUser, err := dc.unmarshalEntity(user)
+ if err != nil {
+ return err
+ }
+
+ if ucChannel != ucUser {
+ return ircError{&irc.Message{
+ Command: irc.ERR_USERNOTINCHANNEL,
+ Params: []string{dc.nick, user, channel, "They are on another network"},
+ }}
+ }
+ uc := ucChannel
+
+ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
+ Command: "INVITE",
+ Params: []string{upstreamUser, upstreamChannel},
+ })
+ case "AUTHENTICATE":
+ // Post-connection-registration AUTHENTICATE is unsupported in
+ // multi-upstream mode, or if the upstream doesn't support SASL
+ uc := dc.upstream()
+ if uc == nil || !uc.caps["sasl"] {
+ return ircError{&irc.Message{
+ Command: irc.ERR_SASLFAIL,
+ Params: []string{dc.nick, "Upstream network authentication not supported"},
+ }}
+ }
+
+ credentials, err := dc.handleAuthenticateCommand(msg)
+ if err != nil {
+ return err
+ }
+
+ if credentials != nil {
+ if uc.saslClient != nil {
+ dc.endSASL(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_SASLFAIL,
+ Params: []string{dc.nick, "Another authentication attempt is already in progress"},
+ })
+ return nil
+ }
+
+ uc.logger.Printf("starting post-registration SASL PLAIN authentication with username %q", credentials.plainUsername)
+ uc.saslClient = sasl.NewPlainClient("", credentials.plainUsername, credentials.plainPassword)
+ uc.enqueueCommand(dc, &irc.Message{
+ Command: "AUTHENTICATE",
+ Params: []string{"PLAIN"},
+ })
+ }
+ case "REGISTER", "VERIFY":
+ // Check number of params here, since we'll use that to save the
+ // credentials on command success
+ if (msg.Command == "REGISTER" && len(msg.Params) < 3) || (msg.Command == "VERIFY" && len(msg.Params) < 2) {
+ return newNeedMoreParamsError(msg.Command)
+ }
+
+ uc := dc.upstream()
+ if uc == nil || !uc.caps["draft/account-registration"] {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{msg.Command, "TEMPORARILY_UNAVAILABLE", "*", "Upstream network account registration not supported"},
+ }}
+ }
+
+ uc.logger.Printf("starting %v with account name %v", msg.Command, msg.Params[0])
+ uc.enqueueCommand(dc, msg)
+ case "MONITOR":
+ // MONITOR is unsupported in multi-upstream mode
+ uc := dc.upstream()
+ if uc == nil {
+ return newUnknownCommandError(msg.Command)
+ }
+ if _, ok := uc.isupport["MONITOR"]; !ok {
+ return newUnknownCommandError(msg.Command)
+ }
+
+ var subcommand string
+ if err := parseMessageParams(msg, &subcommand); err != nil {
+ return err
+ }
+
+ switch strings.ToUpper(subcommand) {
+ case "+", "-":
+ var targets string
+ if err := parseMessageParams(msg, nil, &targets); err != nil {
+ return err
+ }
+ for _, target := range strings.Split(targets, ",") {
+ if subcommand == "+" {
+ // Hard limit, just to avoid having downstreams fill our map
+ if len(dc.monitored.innerMap) >= 1000 {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_MONLISTFULL,
+ Params: []string{dc.nick, "1000", target, "Bouncer monitor list is full"},
+ })
+ continue
+ }
+
+ dc.monitored.SetValue(target, nil)
+
+ if uc.monitored.Has(target) {
+ cmd := irc.RPL_MONOFFLINE
+ if online := uc.monitored.Value(target); online {
+ cmd = irc.RPL_MONONLINE
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: cmd,
+ Params: []string{dc.nick, target},
+ })
+ }
+ } else {
+ dc.monitored.Delete(target)
+ }
+ }
+ uc.updateMonitor()
+ case "C": // clear
+ dc.monitored = newCasemapMap(0)
+ uc.updateMonitor()
+ case "L": // list
+ // TODO: be less lazy and pack the list
+ for _, entry := range dc.monitored.innerMap {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_MONLIST,
+ Params: []string{dc.nick, entry.originalKey},
+ })
+ }
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_ENDOFMONLIST,
+ Params: []string{dc.nick, "End of MONITOR list"},
+ })
+ case "S": // status
+ // TODO: be less lazy and pack the lists
+ for _, entry := range dc.monitored.innerMap {
+ target := entry.originalKey
+
+ cmd := irc.RPL_MONOFFLINE
+ if online := uc.monitored.Value(target); online {
+ cmd = irc.RPL_MONONLINE
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: cmd,
+ Params: []string{dc.nick, target},
+ })
+ }
+ }
+ case "CHATHISTORY":
+ var subcommand string
+ if err := parseMessageParams(msg, &subcommand); err != nil {
+ return err
+ }
+ var target, limitStr string
+ var boundsStr [2]string
+ switch subcommand {
+ case "AFTER", "BEFORE", "LATEST":
+ if err := parseMessageParams(msg, nil, &target, &boundsStr[0], &limitStr); err != nil {
+ return err
+ }
+ case "BETWEEN":
+ if err := parseMessageParams(msg, nil, &target, &boundsStr[0], &boundsStr[1], &limitStr); err != nil {
+ return err
+ }
+ case "TARGETS":
+ if dc.network == nil {
+ // Either an unbound bouncer network, in which case we should return no targets,
+ // or a multi-upstream downstream, but we don't support CHATHISTORY TARGETS for those yet.
+ dc.SendBatch("draft/chathistory-targets", nil, nil, func(batchRef irc.TagValue) {})
+ return nil
+ }
+ if err := parseMessageParams(msg, nil, &boundsStr[0], &boundsStr[1], &limitStr); err != nil {
+ return err
+ }
+ default:
+ // TODO: support AROUND
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"CHATHISTORY", "INVALID_PARAMS", subcommand, "Unknown command"},
+ }}
+ }
+
+ // We don't save history for our service
+ if casemapASCII(target) == serviceNickCM {
+ dc.SendBatch("chathistory", []string{target}, nil, func(batchRef irc.TagValue) {})
+ return nil
+ }
+
+ store, ok := dc.user.msgStore.(chatHistoryMessageStore)
+ if !ok {
+ return ircError{&irc.Message{
+ Command: irc.ERR_UNKNOWNCOMMAND,
+ Params: []string{dc.nick, "CHATHISTORY", "Unknown command"},
+ }}
+ }
+
+ network, entity, err := dc.unmarshalEntityNetwork(target)
+ if err != nil {
+ return err
+ }
+ entity = network.casemap(entity)
+
+ // TODO: support msgid criteria
+ var bounds [2]time.Time
+ bounds[0] = parseChatHistoryBound(boundsStr[0])
+ if subcommand == "LATEST" && boundsStr[0] == "*" {
+ bounds[0] = time.Now()
+ } else if bounds[0].IsZero() {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"CHATHISTORY", "INVALID_PARAMS", subcommand, boundsStr[0], "Invalid first bound"},
+ }}
+ }
+
+ if boundsStr[1] != "" {
+ bounds[1] = parseChatHistoryBound(boundsStr[1])
+ if bounds[1].IsZero() {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"CHATHISTORY", "INVALID_PARAMS", subcommand, boundsStr[1], "Invalid second bound"},
+ }}
+ }
+ }
+
+ limit, err := strconv.Atoi(limitStr)
+ if err != nil || limit < 0 || limit > chatHistoryLimit {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"CHATHISTORY", "INVALID_PARAMS", subcommand, limitStr, "Invalid limit"},
+ }}
+ }
+
+ eventPlayback := dc.caps["draft/event-playback"]
+
+ var history []*irc.Message
+ switch subcommand {
+ case "BEFORE", "LATEST":
+ history, err = store.LoadBeforeTime(ctx, &network.Network, entity, bounds[0], time.Time{}, limit, eventPlayback)
+ case "AFTER":
+ history, err = store.LoadAfterTime(ctx, &network.Network, entity, bounds[0], time.Now(), limit, eventPlayback)
+ case "BETWEEN":
+ if bounds[0].Before(bounds[1]) {
+ history, err = store.LoadAfterTime(ctx, &network.Network, entity, bounds[0], bounds[1], limit, eventPlayback)
+ } else {
+ history, err = store.LoadBeforeTime(ctx, &network.Network, entity, bounds[0], bounds[1], limit, eventPlayback)
+ }
+ case "TARGETS":
+ // TODO: support TARGETS in multi-upstream mode
+ targets, err := store.ListTargets(ctx, &network.Network, bounds[0], bounds[1], limit, eventPlayback)
+ if err != nil {
+ dc.logger.Printf("failed fetching targets for chathistory: %v", err)
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"CHATHISTORY", "MESSAGE_ERROR", subcommand, "Failed to retrieve targets"},
+ }}
+ }
+
+ dc.SendBatch("draft/chathistory-targets", nil, nil, func(batchRef irc.TagValue) {
+ for _, target := range targets {
+ if ch := network.channels.Value(target.Name); ch != nil && ch.Detached {
+ continue
+ }
+
+ dc.SendMessage(&irc.Message{
+ Tags: irc.Tags{"batch": batchRef},
+ Prefix: dc.srv.prefix(),
+ Command: "CHATHISTORY",
+ Params: []string{"TARGETS", target.Name, formatServerTime(target.LatestMessage)},
+ })
+ }
+ })
+
+ return nil
+ }
+ if err != nil {
+ dc.logger.Printf("failed fetching %q messages for chathistory: %v", target, err)
+ return newChatHistoryError(subcommand, target)
+ }
+
+ dc.SendBatch("chathistory", []string{target}, nil, func(batchRef irc.TagValue) {
+ for _, msg := range history {
+ msg.Tags["batch"] = batchRef
+ dc.SendMessage(dc.marshalMessage(msg, network))
+ }
+ })
+ case "READ":
+ var target, criteria string
+ if err := parseMessageParams(msg, &target); err != nil {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"READ", "NEED_MORE_PARAMS", "Missing parameters"},
+ }}
+ }
+ if len(msg.Params) > 1 {
+ criteria = msg.Params[1]
+ }
+
+ // We don't save read receipts for our service
+ if casemapASCII(target) == serviceNickCM {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.prefix(),
+ Command: "READ",
+ Params: []string{target, "*"},
+ })
+ return nil
+ }
+
+ uc, entity, err := dc.unmarshalEntity(target)
+ if err != nil {
+ return err
+ }
+ entityCM := uc.network.casemap(entity)
+
+ r, err := dc.srv.db.GetReadReceipt(ctx, uc.network.ID, entityCM)
+ if err != nil {
+ dc.logger.Printf("failed to get the read receipt for %q: %v", entity, err)
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"READ", "INTERNAL_ERROR", target, "Internal error"},
+ }}
+ } else if r == nil {
+ r = &ReadReceipt{
+ Target: entityCM,
+ }
+ }
+
+ broadcast := false
+ if len(criteria) > 0 {
+ // TODO: support msgid criteria
+ criteriaParts := strings.SplitN(criteria, "=", 2)
+ if len(criteriaParts) != 2 || criteriaParts[0] != "timestamp" {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"READ", "INVALID_PARAMS", criteria, "Unknown criteria"},
+ }}
+ }
+
+ timestamp, err := time.Parse(serverTimeLayout, criteriaParts[1])
+ if err != nil {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"READ", "INVALID_PARAMS", criteria, "Invalid criteria"},
+ }}
+ }
+ now := time.Now()
+ if timestamp.After(now) {
+ timestamp = now
+ }
+ if r.Timestamp.Before(timestamp) {
+ r.Timestamp = timestamp
+ if err := dc.srv.db.StoreReadReceipt(ctx, uc.network.ID, r); err != nil {
+ dc.logger.Printf("failed to store receipt for %q: %v", entity, err)
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"READ", "INTERNAL_ERROR", target, "Internal error"},
+ }}
+ }
+ broadcast = true
+ }
+ }
+
+ timestampStr := "*"
+ if !r.Timestamp.IsZero() {
+ timestampStr = fmt.Sprintf("timestamp=%s", formatServerTime(r.Timestamp))
+ }
+ uc.forEachDownstream(func(d *downstreamConn) {
+ if broadcast || dc.id == d.id {
+ d.SendMessage(&irc.Message{
+ Prefix: d.prefix(),
+ Command: "READ",
+ Params: []string{d.marshalEntity(uc.network, entity), timestampStr},
+ })
+ }
+ })
+ case "BOUNCER":
+ var subcommand string
+ if err := parseMessageParams(msg, &subcommand); err != nil {
+ return err
+ }
+
+ switch strings.ToUpper(subcommand) {
+ case "BIND":
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"BOUNCER", "REGISTRATION_IS_COMPLETED", "BIND", "Cannot bind to a network after registration"},
+ }}
+ case "LISTNETWORKS":
+ dc.SendBatch("soju.im/bouncer-networks", nil, nil, func(batchRef irc.TagValue) {
+ for _, network := range dc.user.networks {
+ idStr := fmt.Sprintf("%v", network.ID)
+ attrs := getNetworkAttrs(network)
+ dc.SendMessage(&irc.Message{
+ Tags: irc.Tags{"batch": batchRef},
+ Prefix: dc.srv.prefix(),
+ Command: "BOUNCER",
+ Params: []string{"NETWORK", idStr, attrs.String()},
+ })
+ }
+ })
+ case "ADDNETWORK":
+ var attrsStr string
+ if err := parseMessageParams(msg, nil, &attrsStr); err != nil {
+ return err
+ }
+ attrs := irc.ParseTags(attrsStr)
+
+ record := &Network{Nick: dc.nick, Enabled: true}
+ if err := updateNetworkAttrs(record, attrs, subcommand); err != nil {
+ return err
+ }
+
+ if record.Nick == dc.user.Username {
+ record.Nick = ""
+ }
+ if record.Realname == dc.user.Realname {
+ record.Realname = ""
+ }
+
+ network, err := dc.user.createNetwork(ctx, record)
+ if err != nil {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"BOUNCER", "UNKNOWN_ERROR", subcommand, fmt.Sprintf("Failed to create network: %v", err)},
+ }}
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "BOUNCER",
+ Params: []string{"ADDNETWORK", fmt.Sprintf("%v", network.ID)},
+ })
+ case "CHANGENETWORK":
+ var idStr, attrsStr string
+ if err := parseMessageParams(msg, nil, &idStr, &attrsStr); err != nil {
+ return err
+ }
+ id, err := parseBouncerNetID(subcommand, idStr)
+ if err != nil {
+ return err
+ }
+ attrs := irc.ParseTags(attrsStr)
+
+ net := dc.user.getNetworkByID(id)
+ if net == nil {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"BOUNCER", "INVALID_NETID", subcommand, idStr, "Invalid network ID"},
+ }}
+ }
+
+ record := net.Network // copy network record because we'll mutate it
+ if err := updateNetworkAttrs(&record, attrs, subcommand); err != nil {
+ return err
+ }
+
+ if record.Nick == dc.user.Username {
+ record.Nick = ""
+ }
+ if record.Realname == dc.user.Realname {
+ record.Realname = ""
+ }
+
+ _, err = dc.user.updateNetwork(ctx, &record)
+ if err != nil {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"BOUNCER", "UNKNOWN_ERROR", subcommand, fmt.Sprintf("Failed to update network: %v", err)},
+ }}
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "BOUNCER",
+ Params: []string{"CHANGENETWORK", idStr},
+ })
+ case "DELNETWORK":
+ var idStr string
+ if err := parseMessageParams(msg, nil, &idStr); err != nil {
+ return err
+ }
+ id, err := parseBouncerNetID(subcommand, idStr)
+ if err != nil {
+ return err
+ }
+
+ net := dc.user.getNetworkByID(id)
+ if net == nil {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"BOUNCER", "INVALID_NETID", subcommand, idStr, "Invalid network ID"},
+ }}
+ }
+
+ if err := dc.user.deleteNetwork(ctx, net.ID); err != nil {
+ return err
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "BOUNCER",
+ Params: []string{"DELNETWORK", idStr},
+ })
+ default:
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"BOUNCER", "UNKNOWN_COMMAND", subcommand, "Unknown subcommand"},
+ }}
+ }
+ default:
+ dc.logger.Printf("unhandled message: %v", msg)
+
+ // Only forward unknown commands in single-upstream mode
+ uc := dc.upstream()
+ if uc == nil {
+ return newUnknownCommandError(msg.Command)
+ }
+
+ uc.SendMessageLabeled(ctx, dc.id, msg)
+ }
+ return nil
+}
+
+func (dc *downstreamConn) handleNickServPRIVMSG(ctx context.Context, uc *upstreamConn, text string) {
+ username, password, ok := parseNickServCredentials(text, uc.nick)
+ if ok {
+ uc.network.autoSaveSASLPlain(ctx, username, password)
+ }
+}
+
+func parseNickServCredentials(text, nick string) (username, password string, ok bool) {
+ fields := strings.Fields(text)
+ if len(fields) < 2 {
+ return "", "", false
+ }
+ cmd := strings.ToUpper(fields[0])
+ params := fields[1:]
+ switch cmd {
+ case "REGISTER":
+ username = nick
+ password = params[0]
+ case "IDENTIFY":
+ if len(params) == 1 {
+ username = nick
+ password = params[0]
+ } else {
+ username = params[0]
+ password = params[1]
+ }
+ case "SET":
+ if len(params) == 2 && strings.EqualFold(params[0], "PASSWORD") {
+ username = nick
+ password = params[1]
+ }
+ default:
+ return "", "", false
+ }
+ return username, password, true
+}
--- /dev/null
+module marisa.chaotic.ninja/suika
+
+go 1.20
+
+require (
+ git.sr.ht/~emersion/go-scfg v0.0.0-20211215104734-c2c7a15d6c99
+ git.sr.ht/~sircmpwn/go-bare v0.0.0-20210406120253-ab86bc2846d9
+ github.com/emersion/go-sasl v0.0.0-20220912192320-0145f2c60ead
+ github.com/lib/pq v1.10.7
+ golang.org/x/crypto v0.7.0
+ golang.org/x/term v0.6.0
+ golang.org/x/time v0.3.0
+ gopkg.in/irc.v3 v3.1.4
+ modernc.org/sqlite v1.21.0
+)
+
+require (
+ github.com/dustin/go-humanize v1.0.0 // indirect
+ github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect
+ github.com/google/uuid v1.3.0 // indirect
+ github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect
+ github.com/mattn/go-isatty v0.0.16 // indirect
+ github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
+ github.com/stretchr/testify v1.8.0 // indirect
+ golang.org/x/mod v0.3.0 // indirect
+ golang.org/x/sys v0.6.0 // indirect
+ golang.org/x/tools v0.0.0-20201124115921-2c860bdd6e78 // indirect
+ golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
+ gopkg.in/yaml.v2 v2.4.0 // indirect
+ lukechampine.com/uint128 v1.2.0 // indirect
+ modernc.org/cc/v3 v3.40.0 // indirect
+ modernc.org/ccgo/v3 v3.16.13 // indirect
+ modernc.org/libc v1.22.3 // indirect
+ modernc.org/mathutil v1.5.0 // indirect
+ modernc.org/memory v1.5.0 // indirect
+ modernc.org/opt v0.1.3 // indirect
+ modernc.org/strutil v1.1.3 // indirect
+ modernc.org/token v1.0.1 // indirect
+)
--- /dev/null
+git.sr.ht/~emersion/go-scfg v0.0.0-20211215104734-c2c7a15d6c99 h1:1s8n5uisqkR+BzPgaum6xxIjKmzGrTykJdh+Y3f5Xao=
+git.sr.ht/~emersion/go-scfg v0.0.0-20211215104734-c2c7a15d6c99/go.mod h1:t+Ww6SR24yYnXzEWiNlOY0AFo5E9B73X++10lrSpp4U=
+git.sr.ht/~sircmpwn/getopt v0.0.0-20191230200459-23622cc906b3/go.mod h1:wMEGFFFNuPos7vHmWXfszqImLppbc0wEhh6JBfJIUgw=
+git.sr.ht/~sircmpwn/go-bare v0.0.0-20210406120253-ab86bc2846d9 h1:Ahny8Ud1LjVMMAlt8utUFKhhxJtwBAualvsbc/Sk7cE=
+git.sr.ht/~sircmpwn/go-bare v0.0.0-20210406120253-ab86bc2846d9/go.mod h1:BVJwbDfVjCjoFiKrhkei6NdGcZYpkDkdyCdg1ukytRA=
+github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
+github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo=
+github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
+github.com/emersion/go-sasl v0.0.0-20220912192320-0145f2c60ead h1:fI1Jck0vUrXT8bnphprS1EoVRe2Q5CKCX8iDlpqjQ/Y=
+github.com/emersion/go-sasl v0.0.0-20220912192320-0145f2c60ead/go.mod h1:iL2twTeMvZnrg54ZoPDNfJaJaqy0xIQFuBdrLsmspwQ=
+github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
+github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ=
+github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4=
+github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ=
+github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
+github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
+github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs=
+github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8=
+github.com/lib/pq v1.10.7 h1:p7ZhMD+KsSRozJr34udlUrhboJwWAgCg34+/ZZNvZZw=
+github.com/lib/pq v1.10.7/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
+github.com/mattn/go-isatty v0.0.16 h1:bq3VjFmv/sOjHtdEhmkEV4x1AJtvUvOJ2PFAZ5+peKQ=
+github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
+github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y=
+github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
+github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
+github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
+github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
+github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
+github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
+github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
+github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
+github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
+github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
+github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk=
+github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
+github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
+golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
+golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
+golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
+golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A=
+golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU=
+golang.org/x/mod v0.3.0 h1:RM4zey1++hCTbCVQfnWeKs9/IEsaBLA8vTkd0WVtmH4=
+golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
+golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
+golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
+golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
+golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
+golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ=
+golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/term v0.6.0 h1:clScbb1cHjoCkyRbWwBEUZ5H/tIFu5TAXIqaZD0Gcjw=
+golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U=
+golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
+golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
+golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4=
+golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
+golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
+golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
+golang.org/x/tools v0.0.0-20201124115921-2c860bdd6e78 h1:M8tBwCtWD/cZV9DZpFYRUgaymAYAr+aIUTWzDaM3uPs=
+golang.org/x/tools v0.0.0-20201124115921-2c860bdd6e78/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
+golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
+golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/irc.v3 v3.1.4 h1:DYGMRFbtseXEh+NadmMUFzMraqyuUj4I3iWYFEzDZPc=
+gopkg.in/irc.v3 v3.1.4/go.mod h1:shO2gz8+PVeS+4E6GAny88Z0YVVQSxQghdrMVGQsR9s=
+gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
+gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
+gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
+gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
+gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
+gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+lukechampine.com/uint128 v1.2.0 h1:mBi/5l91vocEN8otkC5bDLhi2KdCticRiwbdB0O+rjI=
+lukechampine.com/uint128 v1.2.0/go.mod h1:c4eWIwlEGaxC/+H1VguhU4PHXNWDCDMUlWdIWl2j1gk=
+modernc.org/cc/v3 v3.40.0 h1:P3g79IUS/93SYhtoeaHW+kRCIrYaxJ27MFPv+7kaTOw=
+modernc.org/cc/v3 v3.40.0/go.mod h1:/bTg4dnWkSXowUO6ssQKnOV0yMVxDYNIsIrzqTFDGH0=
+modernc.org/ccgo/v3 v3.16.13 h1:Mkgdzl46i5F/CNR/Kj80Ri59hC8TKAhZrYSaqvkwzUw=
+modernc.org/ccgo/v3 v3.16.13/go.mod h1:2Quk+5YgpImhPjv2Qsob1DnZ/4som1lJTodubIcoUkY=
+modernc.org/ccorpus v1.11.6 h1:J16RXiiqiCgua6+ZvQot4yUuUy8zxgqbqEEUuGPlISk=
+modernc.org/httpfs v1.0.6 h1:AAgIpFZRXuYnkjftxTAZwMIiwEqAfk8aVB2/oA6nAeM=
+modernc.org/libc v1.22.3 h1:D/g6O5ftAfavceqlLOFwaZuA5KYafKwmr30A6iSqoyY=
+modernc.org/libc v1.22.3/go.mod h1:MQrloYP209xa2zHome2a8HLiLm6k0UT8CoHpV74tOFw=
+modernc.org/mathutil v1.5.0 h1:rV0Ko/6SfM+8G+yKiyI830l3Wuz1zRutdslNoQ0kfiQ=
+modernc.org/mathutil v1.5.0/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E=
+modernc.org/memory v1.5.0 h1:N+/8c5rE6EqugZwHii4IFsaJ7MUhoWX07J5tC/iI5Ds=
+modernc.org/memory v1.5.0/go.mod h1:PkUhL0Mugw21sHPeskwZW4D6VscE/GQJOnIpCnW6pSU=
+modernc.org/opt v0.1.3 h1:3XOZf2yznlhC+ibLltsDGzABUGVx8J6pnFMS3E4dcq4=
+modernc.org/opt v0.1.3/go.mod h1:WdSiB5evDcignE70guQKxYUl14mgWtbClRi5wmkkTX0=
+modernc.org/sqlite v1.21.0 h1:4aP4MdUf15i3R3M2mx6Q90WHKz3nZLoz96zlB6tNdow=
+modernc.org/sqlite v1.21.0/go.mod h1:XwQ0wZPIh1iKb5mkvCJ3szzbhk+tykC8ZWqTRTgYRwI=
+modernc.org/strutil v1.1.3 h1:fNMm+oJklMGYfU9Ylcywl0CO5O6nTfaowNsh2wpPjzY=
+modernc.org/strutil v1.1.3/go.mod h1:MEHNA7PdEnEwLvspRMtWTNnp2nnyvMfkimT1NKNAGbw=
+modernc.org/tcl v1.15.1 h1:mOQwiEK4p7HruMZcwKTZPw/aqtGM4aY00uzWhlKKYws=
+modernc.org/token v1.0.1 h1:A3qvTqOwexpfZZeyI0FeGPDlSWX5pjZu9hF4lU+EKWg=
+modernc.org/token v1.0.1/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=
+modernc.org/z v1.7.0 h1:xkDw/KepgEjeizO2sNco+hqYkU12taxQFqPEmgm1GWE=
--- /dev/null
+package suika
+
+import (
+ "fmt"
+ "sort"
+ "strings"
+ "time"
+ "unicode"
+ "unicode/utf8"
+
+ "gopkg.in/irc.v3"
+)
+
+const (
+ rpl_statsping = "246"
+ rpl_localusers = "265"
+ rpl_globalusers = "266"
+ rpl_creationtime = "329"
+ rpl_topicwhotime = "333"
+ rpl_whospcrpl = "354"
+ rpl_whoisaccount = "330"
+ err_invalidcapcmd = "410"
+)
+
+const (
+ maxMessageLength = 512
+ maxMessageParams = 15
+ maxSASLLength = 400
+)
+
+// The server-time layout, as defined in the IRCv3 spec.
+const serverTimeLayout = "2006-01-02T15:04:05.000Z"
+
+func formatServerTime(t time.Time) string {
+ return t.UTC().Format(serverTimeLayout)
+}
+
+type userModes string
+
+func (ms userModes) Has(c byte) bool {
+ return strings.IndexByte(string(ms), c) >= 0
+}
+
+func (ms *userModes) Add(c byte) {
+ if !ms.Has(c) {
+ *ms += userModes(c)
+ }
+}
+
+func (ms *userModes) Del(c byte) {
+ i := strings.IndexByte(string(*ms), c)
+ if i >= 0 {
+ *ms = (*ms)[:i] + (*ms)[i+1:]
+ }
+}
+
+func (ms *userModes) Apply(s string) error {
+ var plusMinus byte
+ for i := 0; i < len(s); i++ {
+ switch c := s[i]; c {
+ case '+', '-':
+ plusMinus = c
+ default:
+ switch plusMinus {
+ case '+':
+ ms.Add(c)
+ case '-':
+ ms.Del(c)
+ default:
+ return fmt.Errorf("malformed modestring %q: missing plus/minus", s)
+ }
+ }
+ }
+ return nil
+}
+
+type channelModeType byte
+
+// standard channel mode types, as explained in https://modern.ircdocs.horse/#mode-message
+const (
+ // modes that add or remove an address to or from a list
+ modeTypeA channelModeType = iota
+ // modes that change a setting on a channel, and must always have a parameter
+ modeTypeB
+ // modes that change a setting on a channel, and must have a parameter when being set, and no parameter when being unset
+ modeTypeC
+ // modes that change a setting on a channel, and must not have a parameter
+ modeTypeD
+)
+
+var stdChannelModes = map[byte]channelModeType{
+ 'b': modeTypeA, // ban list
+ 'e': modeTypeA, // ban exception list
+ 'I': modeTypeA, // invite exception list
+ 'k': modeTypeB, // channel key
+ 'l': modeTypeC, // channel user limit
+ 'i': modeTypeD, // channel is invite-only
+ 'm': modeTypeD, // channel is moderated
+ 'n': modeTypeD, // channel has no external messages
+ 's': modeTypeD, // channel is secret
+ 't': modeTypeD, // channel has protected topic
+}
+
+type channelModes map[byte]string
+
+// applyChannelModes parses a mode string and mode arguments from a MODE message,
+// and applies the corresponding channel mode and user membership changes on that channel.
+//
+// If ch.modes is nil, channel modes are not updated.
+//
+// needMarshaling is a list of indexes of mode arguments that represent entities
+// that must be marshaled when sent downstream.
+func applyChannelModes(ch *upstreamChannel, modeStr string, arguments []string) (needMarshaling map[int]struct{}, err error) {
+ needMarshaling = make(map[int]struct{}, len(arguments))
+ nextArgument := 0
+ var plusMinus byte
+outer:
+ for i := 0; i < len(modeStr); i++ {
+ mode := modeStr[i]
+ if mode == '+' || mode == '-' {
+ plusMinus = mode
+ continue
+ }
+ if plusMinus != '+' && plusMinus != '-' {
+ return nil, fmt.Errorf("malformed modestring %q: missing plus/minus", modeStr)
+ }
+
+ for _, membership := range ch.conn.availableMemberships {
+ if membership.Mode == mode {
+ if nextArgument >= len(arguments) {
+ return nil, fmt.Errorf("malformed modestring %q: missing mode argument for %c%c", modeStr, plusMinus, mode)
+ }
+ member := arguments[nextArgument]
+ m := ch.Members.Value(member)
+ if m != nil {
+ if plusMinus == '+' {
+ m.Add(ch.conn.availableMemberships, membership)
+ } else {
+ // TODO: for upstreams without multi-prefix, query the user modes again
+ m.Remove(membership)
+ }
+ }
+ needMarshaling[nextArgument] = struct{}{}
+ nextArgument++
+ continue outer
+ }
+ }
+
+ mt, ok := ch.conn.availableChannelModes[mode]
+ if !ok {
+ continue
+ }
+ if mt == modeTypeA {
+ nextArgument++
+ } else if mt == modeTypeB || (mt == modeTypeC && plusMinus == '+') {
+ if plusMinus == '+' {
+ var argument string
+ // some sentitive arguments (such as channel keys) can be omitted for privacy
+ // (this will only happen for RPL_CHANNELMODEIS, never for MODE messages)
+ if nextArgument < len(arguments) {
+ argument = arguments[nextArgument]
+ }
+ if ch.modes != nil {
+ ch.modes[mode] = argument
+ }
+ } else {
+ delete(ch.modes, mode)
+ }
+ nextArgument++
+ } else if mt == modeTypeC || mt == modeTypeD {
+ if plusMinus == '+' {
+ if ch.modes != nil {
+ ch.modes[mode] = ""
+ }
+ } else {
+ delete(ch.modes, mode)
+ }
+ }
+ }
+ return needMarshaling, nil
+}
+
+func (cm channelModes) Format() (modeString string, parameters []string) {
+ var modesWithValues strings.Builder
+ var modesWithoutValues strings.Builder
+ parameters = make([]string, 0, 16)
+ for mode, value := range cm {
+ if value != "" {
+ modesWithValues.WriteString(string(mode))
+ parameters = append(parameters, value)
+ } else {
+ modesWithoutValues.WriteString(string(mode))
+ }
+ }
+ modeString = "+" + modesWithValues.String() + modesWithoutValues.String()
+ return
+}
+
+const stdChannelTypes = "#&+!"
+
+type channelStatus byte
+
+const (
+ channelPublic channelStatus = '='
+ channelSecret channelStatus = '@'
+ channelPrivate channelStatus = '*'
+)
+
+func parseChannelStatus(s string) (channelStatus, error) {
+ if len(s) > 1 {
+ return 0, fmt.Errorf("invalid channel status %q: more than one character", s)
+ }
+ switch cs := channelStatus(s[0]); cs {
+ case channelPublic, channelSecret, channelPrivate:
+ return cs, nil
+ default:
+ return 0, fmt.Errorf("invalid channel status %q: unknown status", s)
+ }
+}
+
+type membership struct {
+ Mode byte
+ Prefix byte
+}
+
+var stdMemberships = []membership{
+ {'q', '~'}, // founder
+ {'a', '&'}, // protected
+ {'o', '@'}, // operator
+ {'h', '%'}, // halfop
+ {'v', '+'}, // voice
+}
+
+// memberships always sorted by descending membership rank
+type memberships []membership
+
+func (m *memberships) Add(availableMemberships []membership, newMembership membership) {
+ l := *m
+ i := 0
+ for _, availableMembership := range availableMemberships {
+ if i >= len(l) {
+ break
+ }
+ if l[i] == availableMembership {
+ if availableMembership == newMembership {
+ // we already have this membership
+ return
+ }
+ i++
+ continue
+ }
+ if availableMembership == newMembership {
+ break
+ }
+ }
+ // insert newMembership at i
+ l = append(l, membership{})
+ copy(l[i+1:], l[i:])
+ l[i] = newMembership
+ *m = l
+}
+
+func (m *memberships) Remove(oldMembership membership) {
+ l := *m
+ for i, currentMembership := range l {
+ if currentMembership == oldMembership {
+ *m = append(l[:i], l[i+1:]...)
+ return
+ }
+ }
+}
+
+func (m memberships) Format(dc *downstreamConn) string {
+ if !dc.caps["multi-prefix"] {
+ if len(m) == 0 {
+ return ""
+ }
+ return string(m[0].Prefix)
+ }
+ prefixes := make([]byte, len(m))
+ for i, membership := range m {
+ prefixes[i] = membership.Prefix
+ }
+ return string(prefixes)
+}
+
+func parseMessageParams(msg *irc.Message, out ...*string) error {
+ if len(msg.Params) < len(out) {
+ return newNeedMoreParamsError(msg.Command)
+ }
+ for i := range out {
+ if out[i] != nil {
+ *out[i] = msg.Params[i]
+ }
+ }
+ return nil
+}
+
+func copyClientTags(tags irc.Tags) irc.Tags {
+ t := make(irc.Tags, len(tags))
+ for k, v := range tags {
+ if strings.HasPrefix(k, "+") {
+ t[k] = v
+ }
+ }
+ return t
+}
+
+type batch struct {
+ Type string
+ Params []string
+ Outer *batch // if not-nil, this batch is nested in Outer
+ Label string
+}
+
+func join(channels, keys []string) []*irc.Message {
+ // Put channels with a key first
+ js := joinSorter{channels, keys}
+ sort.Sort(&js)
+
+ // Two spaces because there are three words (JOIN, channels and keys)
+ maxLength := maxMessageLength - (len("JOIN") + 2)
+
+ var msgs []*irc.Message
+ var channelsBuf, keysBuf strings.Builder
+ for i, channel := range channels {
+ key := keys[i]
+
+ n := channelsBuf.Len() + keysBuf.Len() + 1 + len(channel)
+ if key != "" {
+ n += 1 + len(key)
+ }
+
+ if channelsBuf.Len() > 0 && n > maxLength {
+ // No room for the new channel in this message
+ params := []string{channelsBuf.String()}
+ if keysBuf.Len() > 0 {
+ params = append(params, keysBuf.String())
+ }
+ msgs = append(msgs, &irc.Message{Command: "JOIN", Params: params})
+ channelsBuf.Reset()
+ keysBuf.Reset()
+ }
+
+ if channelsBuf.Len() > 0 {
+ channelsBuf.WriteByte(',')
+ }
+ channelsBuf.WriteString(channel)
+ if key != "" {
+ if keysBuf.Len() > 0 {
+ keysBuf.WriteByte(',')
+ }
+ keysBuf.WriteString(key)
+ }
+ }
+ if channelsBuf.Len() > 0 {
+ params := []string{channelsBuf.String()}
+ if keysBuf.Len() > 0 {
+ params = append(params, keysBuf.String())
+ }
+ msgs = append(msgs, &irc.Message{Command: "JOIN", Params: params})
+ }
+
+ return msgs
+}
+
+func generateIsupport(prefix *irc.Prefix, nick string, tokens []string) []*irc.Message {
+ maxTokens := maxMessageParams - 2 // 2 reserved params: nick + text
+
+ var msgs []*irc.Message
+ for len(tokens) > 0 {
+ var msgTokens []string
+ if len(tokens) > maxTokens {
+ msgTokens = tokens[:maxTokens]
+ tokens = tokens[maxTokens:]
+ } else {
+ msgTokens = tokens
+ tokens = nil
+ }
+
+ msgs = append(msgs, &irc.Message{
+ Prefix: prefix,
+ Command: irc.RPL_ISUPPORT,
+ Params: append(append([]string{nick}, msgTokens...), "are supported"),
+ })
+ }
+
+ return msgs
+}
+
+func generateMOTD(prefix *irc.Prefix, nick string, motd string) []*irc.Message {
+ var msgs []*irc.Message
+ msgs = append(msgs, &irc.Message{
+ Prefix: prefix,
+ Command: irc.RPL_MOTDSTART,
+ Params: []string{nick, fmt.Sprintf("- Message of the Day -")},
+ })
+
+ for _, l := range strings.Split(motd, "\n") {
+ msgs = append(msgs, &irc.Message{
+ Prefix: prefix,
+ Command: irc.RPL_MOTD,
+ Params: []string{nick, l},
+ })
+ }
+
+ msgs = append(msgs, &irc.Message{
+ Prefix: prefix,
+ Command: irc.RPL_ENDOFMOTD,
+ Params: []string{nick, "End of /MOTD command."},
+ })
+
+ return msgs
+}
+
+func generateMonitor(subcmd string, targets []string) []*irc.Message {
+ maxLength := maxMessageLength - len("MONITOR "+subcmd+" ")
+
+ var msgs []*irc.Message
+ var buf []string
+ n := 0
+ for _, target := range targets {
+ if n+len(target)+1 > maxLength {
+ msgs = append(msgs, &irc.Message{
+ Command: "MONITOR",
+ Params: []string{subcmd, strings.Join(buf, ",")},
+ })
+ buf = buf[:0]
+ n = 0
+ }
+
+ buf = append(buf, target)
+ n += len(target) + 1
+ }
+
+ if len(buf) > 0 {
+ msgs = append(msgs, &irc.Message{
+ Command: "MONITOR",
+ Params: []string{subcmd, strings.Join(buf, ",")},
+ })
+ }
+
+ return msgs
+}
+
+type joinSorter struct {
+ channels []string
+ keys []string
+}
+
+func (js *joinSorter) Len() int {
+ return len(js.channels)
+}
+
+func (js *joinSorter) Less(i, j int) bool {
+ if (js.keys[i] != "") != (js.keys[j] != "") {
+ // Only one of the channels has a key
+ return js.keys[i] != ""
+ }
+ return js.channels[i] < js.channels[j]
+}
+
+func (js *joinSorter) Swap(i, j int) {
+ js.channels[i], js.channels[j] = js.channels[j], js.channels[i]
+ js.keys[i], js.keys[j] = js.keys[j], js.keys[i]
+}
+
+// parseCTCPMessage parses a CTCP message. CTCP is defined in
+// https://tools.ietf.org/html/draft-oakley-irc-ctcp-02
+func parseCTCPMessage(msg *irc.Message) (cmd string, params string, ok bool) {
+ if (msg.Command != "PRIVMSG" && msg.Command != "NOTICE") || len(msg.Params) < 2 {
+ return "", "", false
+ }
+ text := msg.Params[1]
+
+ if !strings.HasPrefix(text, "\x01") {
+ return "", "", false
+ }
+ text = strings.Trim(text, "\x01")
+
+ words := strings.SplitN(text, " ", 2)
+ cmd = strings.ToUpper(words[0])
+ if len(words) > 1 {
+ params = words[1]
+ }
+
+ return cmd, params, true
+}
+
+type casemapping func(string) string
+
+func casemapNone(name string) string {
+ return name
+}
+
+// CasemapASCII of name is the canonical representation of name according to the
+// ascii casemapping.
+func casemapASCII(name string) string {
+ nameBytes := []byte(name)
+ for i, r := range nameBytes {
+ if 'A' <= r && r <= 'Z' {
+ nameBytes[i] = r + 'a' - 'A'
+ }
+ }
+ return string(nameBytes)
+}
+
+// casemapRFC1459 of name is the canonical representation of name according to the
+// rfc1459 casemapping.
+func casemapRFC1459(name string) string {
+ nameBytes := []byte(name)
+ for i, r := range nameBytes {
+ if 'A' <= r && r <= 'Z' {
+ nameBytes[i] = r + 'a' - 'A'
+ } else if r == '{' {
+ nameBytes[i] = '['
+ } else if r == '}' {
+ nameBytes[i] = ']'
+ } else if r == '\\' {
+ nameBytes[i] = '|'
+ } else if r == '~' {
+ nameBytes[i] = '^'
+ }
+ }
+ return string(nameBytes)
+}
+
+// casemapRFC1459Strict of name is the canonical representation of name
+// according to the rfc1459-strict casemapping.
+func casemapRFC1459Strict(name string) string {
+ nameBytes := []byte(name)
+ for i, r := range nameBytes {
+ if 'A' <= r && r <= 'Z' {
+ nameBytes[i] = r + 'a' - 'A'
+ } else if r == '{' {
+ nameBytes[i] = '['
+ } else if r == '}' {
+ nameBytes[i] = ']'
+ } else if r == '\\' {
+ nameBytes[i] = '|'
+ }
+ }
+ return string(nameBytes)
+}
+
+func parseCasemappingToken(tokenValue string) (casemap casemapping, ok bool) {
+ switch tokenValue {
+ case "ascii":
+ casemap = casemapASCII
+ case "rfc1459":
+ casemap = casemapRFC1459
+ case "rfc1459-strict":
+ casemap = casemapRFC1459Strict
+ default:
+ return nil, false
+ }
+ return casemap, true
+}
+
+func partialCasemap(higher casemapping, name string) string {
+ nameFullyCM := []byte(higher(name))
+ nameBytes := []byte(name)
+ for i, r := range nameBytes {
+ if !('A' <= r && r <= 'Z') && !('a' <= r && r <= 'z') {
+ nameBytes[i] = nameFullyCM[i]
+ }
+ }
+ return string(nameBytes)
+}
+
+type casemapMap struct {
+ innerMap map[string]casemapEntry
+ casemap casemapping
+}
+
+type casemapEntry struct {
+ originalKey string
+ value interface{}
+}
+
+func newCasemapMap(size int) casemapMap {
+ return casemapMap{
+ innerMap: make(map[string]casemapEntry, size),
+ casemap: casemapNone,
+ }
+}
+
+func (cm *casemapMap) OriginalKey(name string) (key string, ok bool) {
+ entry, ok := cm.innerMap[cm.casemap(name)]
+ if !ok {
+ return "", false
+ }
+ return entry.originalKey, true
+}
+
+func (cm *casemapMap) Has(name string) bool {
+ _, ok := cm.innerMap[cm.casemap(name)]
+ return ok
+}
+
+func (cm *casemapMap) Len() int {
+ return len(cm.innerMap)
+}
+
+func (cm *casemapMap) SetValue(name string, value interface{}) {
+ nameCM := cm.casemap(name)
+ entry, ok := cm.innerMap[nameCM]
+ if !ok {
+ cm.innerMap[nameCM] = casemapEntry{
+ originalKey: name,
+ value: value,
+ }
+ return
+ }
+ entry.value = value
+ cm.innerMap[nameCM] = entry
+}
+
+func (cm *casemapMap) Delete(name string) {
+ delete(cm.innerMap, cm.casemap(name))
+}
+
+func (cm *casemapMap) SetCasemapping(newCasemap casemapping) {
+ cm.casemap = newCasemap
+ newInnerMap := make(map[string]casemapEntry, len(cm.innerMap))
+ for _, entry := range cm.innerMap {
+ newInnerMap[cm.casemap(entry.originalKey)] = entry
+ }
+ cm.innerMap = newInnerMap
+}
+
+type upstreamChannelCasemapMap struct{ casemapMap }
+
+func (cm *upstreamChannelCasemapMap) Value(name string) *upstreamChannel {
+ entry, ok := cm.innerMap[cm.casemap(name)]
+ if !ok {
+ return nil
+ }
+ return entry.value.(*upstreamChannel)
+}
+
+type channelCasemapMap struct{ casemapMap }
+
+func (cm *channelCasemapMap) Value(name string) *Channel {
+ entry, ok := cm.innerMap[cm.casemap(name)]
+ if !ok {
+ return nil
+ }
+ return entry.value.(*Channel)
+}
+
+type membershipsCasemapMap struct{ casemapMap }
+
+func (cm *membershipsCasemapMap) Value(name string) *memberships {
+ entry, ok := cm.innerMap[cm.casemap(name)]
+ if !ok {
+ return nil
+ }
+ return entry.value.(*memberships)
+}
+
+type deliveredCasemapMap struct{ casemapMap }
+
+func (cm *deliveredCasemapMap) Value(name string) deliveredClientMap {
+ entry, ok := cm.innerMap[cm.casemap(name)]
+ if !ok {
+ return nil
+ }
+ return entry.value.(deliveredClientMap)
+}
+
+type monitorCasemapMap struct{ casemapMap }
+
+func (cm *monitorCasemapMap) Value(name string) (online bool) {
+ entry, ok := cm.innerMap[cm.casemap(name)]
+ if !ok {
+ return false
+ }
+ return entry.value.(bool)
+}
+
+func isWordBoundary(r rune) bool {
+ switch r {
+ case '-', '_', '|': // inspired from weechat.look.highlight_regex
+ return false
+ default:
+ return !unicode.IsLetter(r) && !unicode.IsNumber(r)
+ }
+}
+
+func isHighlight(text, nick string) bool {
+ for {
+ i := strings.Index(text, nick)
+ if i < 0 {
+ return false
+ }
+
+ left, _ := utf8.DecodeLastRuneInString(text[:i])
+ right, _ := utf8.DecodeRuneInString(text[i+len(nick):])
+ if isWordBoundary(left) && isWordBoundary(right) {
+ return true
+ }
+
+ text = text[i+len(nick):]
+ }
+}
+
+// parseChatHistoryBound parses the given CHATHISTORY parameter as a bound.
+// The zero time is returned on error.
+func parseChatHistoryBound(param string) time.Time {
+ parts := strings.SplitN(param, "=", 2)
+ if len(parts) != 2 {
+ return time.Time{}
+ }
+ switch parts[0] {
+ case "timestamp":
+ timestamp, err := time.Parse(serverTimeLayout, parts[1])
+ if err != nil {
+ return time.Time{}
+ }
+ return timestamp
+ default:
+ return time.Time{}
+ }
+}
+
+// whoxFields is the list of all WHOX field letters, by order of appearance in
+// RPL_WHOSPCRPL messages.
+var whoxFields = []byte("tcuihsnfdlaor")
+
+type whoxInfo struct {
+ Token string
+ Username string
+ Hostname string
+ Server string
+ Nickname string
+ Flags string
+ Account string
+ Realname string
+}
+
+func (info *whoxInfo) get(field byte) string {
+ switch field {
+ case 't':
+ return info.Token
+ case 'c':
+ return "*"
+ case 'u':
+ return info.Username
+ case 'i':
+ return "255.255.255.255"
+ case 'h':
+ return info.Hostname
+ case 's':
+ return info.Server
+ case 'n':
+ return info.Nickname
+ case 'f':
+ return info.Flags
+ case 'd':
+ return "0"
+ case 'l': // idle time
+ return "0"
+ case 'a':
+ account := "0" // WHOX uses "0" to mean "no account"
+ if info.Account != "" && info.Account != "*" {
+ account = info.Account
+ }
+ return account
+ case 'o':
+ return "0"
+ case 'r':
+ return info.Realname
+ }
+ return ""
+}
+
+func generateWHOXReply(prefix *irc.Prefix, nick, fields string, info *whoxInfo) *irc.Message {
+ if fields == "" {
+ return &irc.Message{
+ Prefix: prefix,
+ Command: irc.RPL_WHOREPLY,
+ Params: []string{nick, "*", info.Username, info.Hostname, info.Server, info.Nickname, info.Flags, "0 " + info.Realname},
+ }
+ }
+
+ fieldSet := make(map[byte]bool)
+ for i := 0; i < len(fields); i++ {
+ fieldSet[fields[i]] = true
+ }
+
+ var values []string
+ for _, field := range whoxFields {
+ if !fieldSet[field] {
+ continue
+ }
+ values = append(values, info.get(field))
+ }
+
+ return &irc.Message{
+ Prefix: prefix,
+ Command: rpl_whospcrpl,
+ Params: append([]string{nick}, values...),
+ }
+}
+
+var isupportEncoder = strings.NewReplacer(" ", "\\x20", "\\", "\\x5C")
+
+func encodeISUPPORT(s string) string {
+ return isupportEncoder.Replace(s)
+}
--- /dev/null
+package suika
+
+import (
+ "testing"
+)
+
+func TestIsHighlight(t *testing.T) {
+ nick := "SojuUser"
+ testCases := []struct {
+ name string
+ text string
+ hl bool
+ }{
+ {"noContains", "hi there Soju User!", false},
+ {"middle", "hi there SojuUser!", true},
+ {"start", "SojuUser: how are you doing?", true},
+ {"end", "maybe ask SojuUser", true},
+ {"inWord", "but OtherSojuUserSan is a different nick", false},
+ {"startWord", "and OtherSojuUser is another different nick", false},
+ {"endWord", "and SojuUserSan is yet a different nick", false},
+ {"underscore", "and SojuUser_san has nothing to do with me", false},
+ {"zeroWidthSpace", "writing S\u200BojuUser shouldn't trigger a highlight", false},
+ }
+
+ for _, tc := range testCases {
+ tc := tc // capture range variable
+ t.Run(tc.name, func(t *testing.T) {
+ hl := isHighlight(tc.text, nick)
+ if hl != tc.hl {
+ t.Errorf("isHighlight(%q, %q) = %v, but want %v", tc.text, nick, hl, tc.hl)
+ }
+ })
+ }
+}
--- /dev/null
+package suika
+
+import (
+ "bytes"
+ "context"
+ "encoding/base64"
+ "fmt"
+ "time"
+
+ "git.sr.ht/~sircmpwn/go-bare"
+ "gopkg.in/irc.v3"
+)
+
+// messageStore is a per-user store for IRC messages.
+type messageStore interface {
+ Close() error
+ // LastMsgID queries the last message ID for the given network, entity and
+ // date. The message ID returned may not refer to a valid message, but can be
+ // used in history queries.
+ LastMsgID(network *Network, entity string, t time.Time) (string, error)
+ // LoadLatestID queries the latest non-event messages for the given network,
+ // entity and date, up to a count of limit messages, sorted from oldest to newest.
+ LoadLatestID(ctx context.Context, network *Network, entity, id string, limit int) ([]*irc.Message, error)
+ Append(network *Network, entity string, msg *irc.Message) (id string, err error)
+}
+
+type chatHistoryTarget struct {
+ Name string
+ LatestMessage time.Time
+}
+
+// chatHistoryMessageStore is a message store that supports chat history
+// operations.
+type chatHistoryMessageStore interface {
+ messageStore
+
+ // ListTargets lists channels and nicknames by time of the latest message.
+ // It returns up to limit targets, starting from start and ending on end,
+ // both excluded. end may be before or after start.
+ // If events is false, only PRIVMSG/NOTICE messages are considered.
+ ListTargets(ctx context.Context, network *Network, start, end time.Time, limit int, events bool) ([]chatHistoryTarget, error)
+ // LoadBeforeTime loads up to limit messages before start down to end. The
+ // returned messages must be between and excluding the provided bounds.
+ // end is before start.
+ // If events is false, only PRIVMSG/NOTICE messages are considered.
+ LoadBeforeTime(ctx context.Context, network *Network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error)
+ // LoadBeforeTime loads up to limit messages after start up to end. The
+ // returned messages must be between and excluding the provided bounds.
+ // end is after start.
+ // If events is false, only PRIVMSG/NOTICE messages are considered.
+ LoadAfterTime(ctx context.Context, network *Network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error)
+}
+
+type msgIDType uint
+
+const (
+ msgIDNone msgIDType = iota
+ msgIDMemory
+ msgIDFS
+)
+
+const msgIDVersion uint = 0
+
+type msgIDHeader struct {
+ Version uint
+ Network bare.Int
+ Target string
+ Type msgIDType
+}
+
+type msgIDBody interface {
+ msgIDType() msgIDType
+}
+
+func formatMsgID(netID int64, target string, body msgIDBody) string {
+ var buf bytes.Buffer
+ w := bare.NewWriter(&buf)
+
+ header := msgIDHeader{
+ Version: msgIDVersion,
+ Network: bare.Int(netID),
+ Target: target,
+ Type: body.msgIDType(),
+ }
+ if err := bare.MarshalWriter(w, &header); err != nil {
+ panic(err)
+ }
+ if err := bare.MarshalWriter(w, body); err != nil {
+ panic(err)
+ }
+ return base64.RawURLEncoding.EncodeToString(buf.Bytes())
+}
+
+func parseMsgID(s string, body msgIDBody) (netID int64, target string, err error) {
+ b, err := base64.RawURLEncoding.DecodeString(s)
+ if err != nil {
+ return 0, "", fmt.Errorf("invalid internal message ID: %v", err)
+ }
+
+ r := bare.NewReader(bytes.NewReader(b))
+
+ var header msgIDHeader
+ if err := bare.UnmarshalBareReader(r, &header); err != nil {
+ return 0, "", fmt.Errorf("invalid internal message ID: %v", err)
+ }
+
+ if header.Version != msgIDVersion {
+ return 0, "", fmt.Errorf("invalid internal message ID: got version %v, want %v", header.Version, msgIDVersion)
+ }
+
+ if body != nil {
+ typ := body.msgIDType()
+ if header.Type != typ {
+ return 0, "", fmt.Errorf("invalid internal message ID: got type %v, want %v", header.Type, typ)
+ }
+
+ if err := bare.UnmarshalBareReader(r, body); err != nil {
+ return 0, "", fmt.Errorf("invalid internal message ID: %v", err)
+ }
+ }
+
+ return int64(header.Network), header.Target, nil
+}
--- /dev/null
+package suika
+
+import (
+ "bufio"
+ "context"
+ "fmt"
+ "io"
+ "os"
+ "path/filepath"
+ "sort"
+ "strings"
+ "time"
+
+ "git.sr.ht/~sircmpwn/go-bare"
+ "gopkg.in/irc.v3"
+)
+
+const (
+ fsMessageStoreMaxFiles = 20
+ fsMessageStoreMaxTries = 100
+)
+
+func escapeFilename(unsafe string) (safe string) {
+ if unsafe == "." {
+ return "-"
+ } else if unsafe == ".." {
+ return "--"
+ } else {
+ return strings.NewReplacer("/", "-", "\\", "-").Replace(unsafe)
+ }
+}
+
+type date struct {
+ Year, Month, Day int
+}
+
+func newDate(t time.Time) date {
+ year, month, day := t.Date()
+ return date{year, int(month), day}
+}
+
+func (d date) Time() time.Time {
+ return time.Date(d.Year, time.Month(d.Month), d.Day, 0, 0, 0, 0, time.Local)
+}
+
+type fsMsgID struct {
+ Date date
+ Offset bare.Int
+}
+
+func (fsMsgID) msgIDType() msgIDType {
+ return msgIDFS
+}
+
+func parseFSMsgID(s string) (netID int64, entity string, t time.Time, offset int64, err error) {
+ var id fsMsgID
+ netID, entity, err = parseMsgID(s, &id)
+ if err != nil {
+ return 0, "", time.Time{}, 0, err
+ }
+ return netID, entity, id.Date.Time(), int64(id.Offset), nil
+}
+
+func formatFSMsgID(netID int64, entity string, t time.Time, offset int64) string {
+ id := fsMsgID{
+ Date: newDate(t),
+ Offset: bare.Int(offset),
+ }
+ return formatMsgID(netID, entity, &id)
+}
+
+type fsMessageStoreFile struct {
+ *os.File
+ lastUse time.Time
+}
+
+// fsMessageStore is a per-user on-disk store for IRC messages.
+//
+// It mimicks the ZNC log layout and format. See the ZNC source:
+// https://github.com/znc/znc/blob/master/modules/log.cpp
+type fsMessageStore struct {
+ root string
+ user *User
+
+ // Write-only files used by Append
+ files map[string]*fsMessageStoreFile // indexed by entity
+}
+
+var _ messageStore = (*fsMessageStore)(nil)
+var _ chatHistoryMessageStore = (*fsMessageStore)(nil)
+
+func newFSMessageStore(root string, user *User) *fsMessageStore {
+ return &fsMessageStore{
+ root: filepath.Join(root, escapeFilename(user.Username)),
+ user: user,
+ files: make(map[string]*fsMessageStoreFile),
+ }
+}
+
+func (ms *fsMessageStore) logPath(network *Network, entity string, t time.Time) string {
+ year, month, day := t.Date()
+ filename := fmt.Sprintf("%04d-%02d-%02d.log", year, month, day)
+ return filepath.Join(ms.root, escapeFilename(network.GetName()), escapeFilename(entity), filename)
+}
+
+// nextMsgID queries the message ID for the next message to be written to f.
+func nextFSMsgID(network *Network, entity string, t time.Time, f *os.File) (string, error) {
+ offset, err := f.Seek(0, io.SeekEnd)
+ if err != nil {
+ return "", fmt.Errorf("failed to query next FS message ID: %v", err)
+ }
+ return formatFSMsgID(network.ID, entity, t, offset), nil
+}
+
+func (ms *fsMessageStore) LastMsgID(network *Network, entity string, t time.Time) (string, error) {
+ p := ms.logPath(network, entity, t)
+ fi, err := os.Stat(p)
+ if os.IsNotExist(err) {
+ return formatFSMsgID(network.ID, entity, t, -1), nil
+ } else if err != nil {
+ return "", fmt.Errorf("failed to query last FS message ID: %v", err)
+ }
+ return formatFSMsgID(network.ID, entity, t, fi.Size()-1), nil
+}
+
+func (ms *fsMessageStore) Append(network *Network, entity string, msg *irc.Message) (string, error) {
+ s := formatMessage(msg)
+ if s == "" {
+ return "", nil
+ }
+
+ var t time.Time
+ if tag, ok := msg.Tags["time"]; ok {
+ var err error
+ t, err = time.Parse(serverTimeLayout, string(tag))
+ if err != nil {
+ return "", fmt.Errorf("failed to parse message time tag: %v", err)
+ }
+ t = t.In(time.Local)
+ } else {
+ t = time.Now()
+ }
+
+ f := ms.files[entity]
+
+ // TODO: handle non-monotonic clock behaviour
+ path := ms.logPath(network, entity, t)
+ if f == nil || f.Name() != path {
+ dir := filepath.Dir(path)
+ if err := os.MkdirAll(dir, 0750); err != nil {
+ return "", fmt.Errorf("failed to create message logs directory %q: %v", dir, err)
+ }
+
+ ff, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0640)
+ if err != nil {
+ return "", fmt.Errorf("failed to open message log file %q: %v", path, err)
+ }
+
+ if f != nil {
+ f.Close()
+ }
+ f = &fsMessageStoreFile{File: ff}
+ ms.files[entity] = f
+ }
+
+ f.lastUse = time.Now()
+
+ if len(ms.files) > fsMessageStoreMaxFiles {
+ entities := make([]string, 0, len(ms.files))
+ for name := range ms.files {
+ entities = append(entities, name)
+ }
+ sort.Slice(entities, func(i, j int) bool {
+ a, b := entities[i], entities[j]
+ return ms.files[a].lastUse.Before(ms.files[b].lastUse)
+ })
+ entities = entities[0 : len(entities)-fsMessageStoreMaxFiles]
+ for _, name := range entities {
+ ms.files[name].Close()
+ delete(ms.files, name)
+ }
+ }
+
+ msgID, err := nextFSMsgID(network, entity, t, f.File)
+ if err != nil {
+ return "", fmt.Errorf("failed to generate message ID: %v", err)
+ }
+
+ _, err = fmt.Fprintf(f, "[%02d:%02d:%02d] %s\n", t.Hour(), t.Minute(), t.Second(), s)
+ if err != nil {
+ return "", fmt.Errorf("failed to log message to %q: %v", f.Name(), err)
+ }
+
+ return msgID, nil
+}
+
+func (ms *fsMessageStore) Close() error {
+ var closeErr error
+ for _, f := range ms.files {
+ if err := f.Close(); err != nil {
+ closeErr = fmt.Errorf("failed to close message store: %v", err)
+ }
+ }
+ return closeErr
+}
+
+// formatMessage formats a message log line. It assumes a well-formed IRC
+// message.
+func formatMessage(msg *irc.Message) string {
+ switch strings.ToUpper(msg.Command) {
+ case "NICK":
+ return fmt.Sprintf("*** %s is now known as %s", msg.Prefix.Name, msg.Params[0])
+ case "JOIN":
+ return fmt.Sprintf("*** Joins: %s (%s@%s)", msg.Prefix.Name, msg.Prefix.User, msg.Prefix.Host)
+ case "PART":
+ var reason string
+ if len(msg.Params) > 1 {
+ reason = msg.Params[1]
+ }
+ return fmt.Sprintf("*** Parts: %s (%s@%s) (%s)", msg.Prefix.Name, msg.Prefix.User, msg.Prefix.Host, reason)
+ case "KICK":
+ nick := msg.Params[1]
+ var reason string
+ if len(msg.Params) > 2 {
+ reason = msg.Params[2]
+ }
+ return fmt.Sprintf("*** %s was kicked by %s (%s)", nick, msg.Prefix.Name, reason)
+ case "QUIT":
+ var reason string
+ if len(msg.Params) > 0 {
+ reason = msg.Params[0]
+ }
+ return fmt.Sprintf("*** Quits: %s (%s@%s) (%s)", msg.Prefix.Name, msg.Prefix.User, msg.Prefix.Host, reason)
+ case "TOPIC":
+ var topic string
+ if len(msg.Params) > 1 {
+ topic = msg.Params[1]
+ }
+ return fmt.Sprintf("*** %s changes topic to '%s'", msg.Prefix.Name, topic)
+ case "MODE":
+ return fmt.Sprintf("*** %s sets mode: %s", msg.Prefix.Name, strings.Join(msg.Params[1:], " "))
+ case "NOTICE":
+ return fmt.Sprintf("-%s- %s", msg.Prefix.Name, msg.Params[1])
+ case "PRIVMSG":
+ if cmd, params, ok := parseCTCPMessage(msg); ok && cmd == "ACTION" {
+ return fmt.Sprintf("* %s %s", msg.Prefix.Name, params)
+ } else {
+ return fmt.Sprintf("<%s> %s", msg.Prefix.Name, msg.Params[1])
+ }
+ default:
+ return ""
+ }
+}
+
+func (ms *fsMessageStore) parseMessage(line string, network *Network, entity string, ref time.Time, events bool) (*irc.Message, time.Time, error) {
+ var hour, minute, second int
+ _, err := fmt.Sscanf(line, "[%02d:%02d:%02d] ", &hour, &minute, &second)
+ if err != nil {
+ return nil, time.Time{}, fmt.Errorf("malformed timestamp prefix: %v", err)
+ }
+ line = line[11:]
+
+ var cmd string
+ var prefix *irc.Prefix
+ var params []string
+ if events && strings.HasPrefix(line, "*** ") {
+ parts := strings.SplitN(line[4:], " ", 2)
+ if len(parts) != 2 {
+ return nil, time.Time{}, nil
+ }
+ switch parts[0] {
+ case "Joins:", "Parts:", "Quits:":
+ args := strings.SplitN(parts[1], " ", 3)
+ if len(args) < 2 {
+ return nil, time.Time{}, nil
+ }
+ nick := args[0]
+ mask := strings.TrimSuffix(strings.TrimPrefix(args[1], "("), ")")
+ maskParts := strings.SplitN(mask, "@", 2)
+ if len(maskParts) != 2 {
+ return nil, time.Time{}, nil
+ }
+ prefix = &irc.Prefix{
+ Name: nick,
+ User: maskParts[0],
+ Host: maskParts[1],
+ }
+ var reason string
+ if len(args) > 2 {
+ reason = strings.TrimSuffix(strings.TrimPrefix(args[2], "("), ")")
+ }
+ switch parts[0] {
+ case "Joins:":
+ cmd = "JOIN"
+ params = []string{entity}
+ case "Parts:":
+ cmd = "PART"
+ if reason != "" {
+ params = []string{entity, reason}
+ } else {
+ params = []string{entity}
+ }
+ case "Quits:":
+ cmd = "QUIT"
+ if reason != "" {
+ params = []string{reason}
+ }
+ }
+ default:
+ nick := parts[0]
+ rem := parts[1]
+ if r := strings.TrimPrefix(rem, "is now known as "); r != rem {
+ cmd = "NICK"
+ prefix = &irc.Prefix{
+ Name: nick,
+ }
+ params = []string{r}
+ } else if r := strings.TrimPrefix(rem, "was kicked by "); r != rem {
+ args := strings.SplitN(r, " ", 2)
+ if len(args) != 2 {
+ return nil, time.Time{}, nil
+ }
+ cmd = "KICK"
+ prefix = &irc.Prefix{
+ Name: args[0],
+ }
+ reason := strings.TrimSuffix(strings.TrimPrefix(args[1], "("), ")")
+ params = []string{entity, nick}
+ if reason != "" {
+ params = append(params, reason)
+ }
+ } else if r := strings.TrimPrefix(rem, "changes topic to "); r != rem {
+ cmd = "TOPIC"
+ prefix = &irc.Prefix{
+ Name: nick,
+ }
+ topic := strings.TrimSuffix(strings.TrimPrefix(r, "'"), "'")
+ params = []string{entity, topic}
+ } else if r := strings.TrimPrefix(rem, "sets mode: "); r != rem {
+ cmd = "MODE"
+ prefix = &irc.Prefix{
+ Name: nick,
+ }
+ params = append([]string{entity}, strings.Split(r, " ")...)
+ } else {
+ return nil, time.Time{}, nil
+ }
+ }
+ } else {
+ var sender, text string
+ if strings.HasPrefix(line, "<") {
+ cmd = "PRIVMSG"
+ parts := strings.SplitN(line[1:], "> ", 2)
+ if len(parts) != 2 {
+ return nil, time.Time{}, nil
+ }
+ sender, text = parts[0], parts[1]
+ } else if strings.HasPrefix(line, "-") {
+ cmd = "NOTICE"
+ parts := strings.SplitN(line[1:], "- ", 2)
+ if len(parts) != 2 {
+ return nil, time.Time{}, nil
+ }
+ sender, text = parts[0], parts[1]
+ } else if strings.HasPrefix(line, "* ") {
+ cmd = "PRIVMSG"
+ parts := strings.SplitN(line[2:], " ", 2)
+ if len(parts) != 2 {
+ return nil, time.Time{}, nil
+ }
+ sender, text = parts[0], "\x01ACTION "+parts[1]+"\x01"
+ } else {
+ return nil, time.Time{}, nil
+ }
+
+ prefix = &irc.Prefix{Name: sender}
+ if entity == sender {
+ // This is a direct message from a user to us. We don't store own
+ // our nickname in the logs, so grab it from the network settings.
+ // Not very accurate since this may not match our nick at the time
+ // the message was received, but we can't do a lot better.
+ entity = GetNick(ms.user, network)
+ }
+ params = []string{entity, text}
+ }
+
+ year, month, day := ref.Date()
+ t := time.Date(year, month, day, hour, minute, second, 0, time.Local)
+
+ msg := &irc.Message{
+ Tags: map[string]irc.TagValue{
+ "time": irc.TagValue(formatServerTime(t)),
+ },
+ Prefix: prefix,
+ Command: cmd,
+ Params: params,
+ }
+ return msg, t, nil
+}
+
+func (ms *fsMessageStore) parseMessagesBefore(network *Network, entity string, ref time.Time, end time.Time, events bool, limit int, afterOffset int64) ([]*irc.Message, error) {
+ path := ms.logPath(network, entity, ref)
+ f, err := os.Open(path)
+ if err != nil {
+ if os.IsNotExist(err) {
+ return nil, nil
+ }
+ return nil, fmt.Errorf("failed to parse messages before ref: %v", err)
+ }
+ defer f.Close()
+
+ historyRing := make([]*irc.Message, limit)
+ cur := 0
+
+ sc := bufio.NewScanner(f)
+
+ if afterOffset >= 0 {
+ if _, err := f.Seek(afterOffset, io.SeekStart); err != nil {
+ return nil, nil
+ }
+ sc.Scan() // skip till next newline
+ }
+
+ for sc.Scan() {
+ msg, t, err := ms.parseMessage(sc.Text(), network, entity, ref, events)
+ if err != nil {
+ return nil, err
+ } else if msg == nil || !t.After(end) {
+ continue
+ } else if !t.Before(ref) {
+ break
+ }
+
+ historyRing[cur%limit] = msg
+ cur++
+ }
+ if sc.Err() != nil {
+ return nil, fmt.Errorf("failed to parse messages before ref: scanner error: %v", sc.Err())
+ }
+
+ n := limit
+ if cur < limit {
+ n = cur
+ }
+ start := (cur - n + limit) % limit
+
+ if start+n <= limit { // ring doesnt wrap
+ return historyRing[start : start+n], nil
+ } else { // ring wraps
+ history := make([]*irc.Message, n)
+ r := copy(history, historyRing[start:])
+ copy(history[r:], historyRing[:n-r])
+ return history, nil
+ }
+}
+
+func (ms *fsMessageStore) parseMessagesAfter(network *Network, entity string, ref time.Time, end time.Time, events bool, limit int) ([]*irc.Message, error) {
+ path := ms.logPath(network, entity, ref)
+ f, err := os.Open(path)
+ if err != nil {
+ if os.IsNotExist(err) {
+ return nil, nil
+ }
+ return nil, fmt.Errorf("failed to parse messages after ref: %v", err)
+ }
+ defer f.Close()
+
+ var history []*irc.Message
+ sc := bufio.NewScanner(f)
+ for sc.Scan() && len(history) < limit {
+ msg, t, err := ms.parseMessage(sc.Text(), network, entity, ref, events)
+ if err != nil {
+ return nil, err
+ } else if msg == nil || !t.After(ref) {
+ continue
+ } else if !t.Before(end) {
+ break
+ }
+
+ history = append(history, msg)
+ }
+ if sc.Err() != nil {
+ return nil, fmt.Errorf("failed to parse messages after ref: scanner error: %v", sc.Err())
+ }
+
+ return history, nil
+}
+
+func (ms *fsMessageStore) LoadBeforeTime(ctx context.Context, network *Network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) {
+ start = start.In(time.Local)
+ end = end.In(time.Local)
+ history := make([]*irc.Message, limit)
+ remaining := limit
+ tries := 0
+ for remaining > 0 && tries < fsMessageStoreMaxTries && end.Before(start) {
+ buf, err := ms.parseMessagesBefore(network, entity, start, end, events, remaining, -1)
+ if err != nil {
+ return nil, err
+ }
+ if len(buf) == 0 {
+ tries++
+ } else {
+ tries = 0
+ }
+ copy(history[remaining-len(buf):], buf)
+ remaining -= len(buf)
+ year, month, day := start.Date()
+ start = time.Date(year, month, day, 0, 0, 0, 0, start.Location()).Add(-1)
+
+ if err := ctx.Err(); err != nil {
+ return nil, err
+ }
+ }
+
+ return history[remaining:], nil
+}
+
+func (ms *fsMessageStore) LoadAfterTime(ctx context.Context, network *Network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) {
+ start = start.In(time.Local)
+ end = end.In(time.Local)
+ var history []*irc.Message
+ remaining := limit
+ tries := 0
+ for remaining > 0 && tries < fsMessageStoreMaxTries && start.Before(end) {
+ buf, err := ms.parseMessagesAfter(network, entity, start, end, events, remaining)
+ if err != nil {
+ return nil, err
+ }
+ if len(buf) == 0 {
+ tries++
+ } else {
+ tries = 0
+ }
+ history = append(history, buf...)
+ remaining -= len(buf)
+ year, month, day := start.Date()
+ start = time.Date(year, month, day+1, 0, 0, 0, 0, start.Location())
+
+ if err := ctx.Err(); err != nil {
+ return nil, err
+ }
+ }
+ return history, nil
+}
+
+func (ms *fsMessageStore) LoadLatestID(ctx context.Context, network *Network, entity, id string, limit int) ([]*irc.Message, error) {
+ var afterTime time.Time
+ var afterOffset int64
+ if id != "" {
+ var idNet int64
+ var idEntity string
+ var err error
+ idNet, idEntity, afterTime, afterOffset, err = parseFSMsgID(id)
+ if err != nil {
+ return nil, err
+ }
+ if idNet != network.ID || idEntity != entity {
+ return nil, fmt.Errorf("cannot find message ID: message ID doesn't match network/entity")
+ }
+ }
+
+ history := make([]*irc.Message, limit)
+ t := time.Now()
+ remaining := limit
+ tries := 0
+ for remaining > 0 && tries < fsMessageStoreMaxTries && !truncateDay(t).Before(afterTime) {
+ var offset int64 = -1
+ if afterOffset >= 0 && truncateDay(t).Equal(afterTime) {
+ offset = afterOffset
+ }
+
+ buf, err := ms.parseMessagesBefore(network, entity, t, time.Time{}, false, remaining, offset)
+ if err != nil {
+ return nil, err
+ }
+ if len(buf) == 0 {
+ tries++
+ } else {
+ tries = 0
+ }
+ copy(history[remaining-len(buf):], buf)
+ remaining -= len(buf)
+ year, month, day := t.Date()
+ t = time.Date(year, month, day, 0, 0, 0, 0, t.Location()).Add(-1)
+
+ if err := ctx.Err(); err != nil {
+ return nil, err
+ }
+ }
+
+ return history[remaining:], nil
+}
+
+func (ms *fsMessageStore) ListTargets(ctx context.Context, network *Network, start, end time.Time, limit int, events bool) ([]chatHistoryTarget, error) {
+ start = start.In(time.Local)
+ end = end.In(time.Local)
+ rootPath := filepath.Join(ms.root, escapeFilename(network.GetName()))
+ root, err := os.Open(rootPath)
+ if os.IsNotExist(err) {
+ return nil, nil
+ } else if err != nil {
+ return nil, err
+ }
+
+ // The returned targets are escaped, and there is no way to un-escape
+ // TODO: switch to ReadDir (Go 1.16+)
+ targetNames, err := root.Readdirnames(0)
+ root.Close()
+ if err != nil {
+ return nil, err
+ }
+
+ var targets []chatHistoryTarget
+ for _, target := range targetNames {
+ // target is already escaped here
+ targetPath := filepath.Join(rootPath, target)
+ targetDir, err := os.Open(targetPath)
+ if err != nil {
+ return nil, err
+ }
+
+ entries, err := targetDir.Readdir(0)
+ targetDir.Close()
+ if err != nil {
+ return nil, err
+ }
+
+ // We use mtime here, which may give imprecise or incorrect results
+ var t time.Time
+ for _, entry := range entries {
+ if entry.ModTime().After(t) {
+ t = entry.ModTime()
+ }
+ }
+
+ // The timestamps we get from logs have second granularity
+ t = truncateSecond(t)
+
+ // Filter out targets that don't fullfil the time bounds
+ if !isTimeBetween(t, start, end) {
+ continue
+ }
+
+ targets = append(targets, chatHistoryTarget{
+ Name: target,
+ LatestMessage: t,
+ })
+
+ if err := ctx.Err(); err != nil {
+ return nil, err
+ }
+ }
+
+ // Sort targets by latest message time, backwards or forwards depending on
+ // the order of the time bounds
+ sort.Slice(targets, func(i, j int) bool {
+ t1, t2 := targets[i].LatestMessage, targets[j].LatestMessage
+ if start.Before(end) {
+ return t1.Before(t2)
+ } else {
+ return !t1.Before(t2)
+ }
+ })
+
+ // Truncate the result if necessary
+ if len(targets) > limit {
+ targets = targets[:limit]
+ }
+
+ return targets, nil
+}
+
+func (ms *fsMessageStore) RenameNetwork(oldNet, newNet *Network) error {
+ oldDir := filepath.Join(ms.root, escapeFilename(oldNet.GetName()))
+ newDir := filepath.Join(ms.root, escapeFilename(newNet.GetName()))
+ // Avoid loosing data by overwriting an existing directory
+ if _, err := os.Stat(newDir); err == nil {
+ return fmt.Errorf("destination %q already exists", newDir)
+ }
+ return os.Rename(oldDir, newDir)
+}
+
+func truncateDay(t time.Time) time.Time {
+ year, month, day := t.Date()
+ return time.Date(year, month, day, 0, 0, 0, 0, t.Location())
+}
+
+func truncateSecond(t time.Time) time.Time {
+ year, month, day := t.Date()
+ return time.Date(year, month, day, t.Hour(), t.Minute(), t.Second(), 0, t.Location())
+}
+
+func isTimeBetween(t, start, end time.Time) bool {
+ if end.Before(start) {
+ end, start = start, end
+ }
+ return start.Before(t) && t.Before(end)
+}
--- /dev/null
+package suika
+
+import (
+ "context"
+ "fmt"
+ "time"
+
+ "git.sr.ht/~sircmpwn/go-bare"
+ "gopkg.in/irc.v3"
+)
+
+const messageRingBufferCap = 4096
+
+type memoryMsgID struct {
+ Seq bare.Uint
+}
+
+func (memoryMsgID) msgIDType() msgIDType {
+ return msgIDMemory
+}
+
+func parseMemoryMsgID(s string) (netID int64, entity string, seq uint64, err error) {
+ var id memoryMsgID
+ netID, entity, err = parseMsgID(s, &id)
+ if err != nil {
+ return 0, "", 0, err
+ }
+ return netID, entity, uint64(id.Seq), nil
+}
+
+func formatMemoryMsgID(netID int64, entity string, seq uint64) string {
+ id := memoryMsgID{bare.Uint(seq)}
+ return formatMsgID(netID, entity, &id)
+}
+
+type ringBufferKey struct {
+ networkID int64
+ entity string
+}
+
+type memoryMessageStore struct {
+ buffers map[ringBufferKey]*messageRingBuffer
+}
+
+var _ messageStore = (*memoryMessageStore)(nil)
+
+func newMemoryMessageStore() *memoryMessageStore {
+ return &memoryMessageStore{
+ buffers: make(map[ringBufferKey]*messageRingBuffer),
+ }
+}
+
+func (ms *memoryMessageStore) Close() error {
+ ms.buffers = nil
+ return nil
+}
+
+func (ms *memoryMessageStore) get(network *Network, entity string) *messageRingBuffer {
+ k := ringBufferKey{networkID: network.ID, entity: entity}
+ if rb, ok := ms.buffers[k]; ok {
+ return rb
+ }
+ rb := newMessageRingBuffer(messageRingBufferCap)
+ ms.buffers[k] = rb
+ return rb
+}
+
+func (ms *memoryMessageStore) LastMsgID(network *Network, entity string, t time.Time) (string, error) {
+ var seq uint64
+ k := ringBufferKey{networkID: network.ID, entity: entity}
+ if rb, ok := ms.buffers[k]; ok {
+ seq = rb.cur
+ }
+ return formatMemoryMsgID(network.ID, entity, seq), nil
+}
+
+func (ms *memoryMessageStore) Append(network *Network, entity string, msg *irc.Message) (string, error) {
+ switch msg.Command {
+ case "PRIVMSG", "NOTICE":
+ // Only append these messages, because LoadLatestID shouldn't return
+ // other kinds of message.
+ default:
+ return "", nil
+ }
+
+ k := ringBufferKey{networkID: network.ID, entity: entity}
+ rb, ok := ms.buffers[k]
+ if !ok {
+ rb = newMessageRingBuffer(messageRingBufferCap)
+ ms.buffers[k] = rb
+ }
+
+ seq := rb.Append(msg)
+ return formatMemoryMsgID(network.ID, entity, seq), nil
+}
+
+func (ms *memoryMessageStore) LoadLatestID(ctx context.Context, network *Network, entity, id string, limit int) ([]*irc.Message, error) {
+ _, _, seq, err := parseMemoryMsgID(id)
+ if err != nil {
+ return nil, err
+ }
+
+ k := ringBufferKey{networkID: network.ID, entity: entity}
+ rb, ok := ms.buffers[k]
+ if !ok {
+ return nil, nil
+ }
+
+ return rb.LoadLatestSeq(seq, limit)
+}
+
+type messageRingBuffer struct {
+ buf []*irc.Message
+ cur uint64
+}
+
+func newMessageRingBuffer(capacity int) *messageRingBuffer {
+ return &messageRingBuffer{
+ buf: make([]*irc.Message, capacity),
+ cur: 1,
+ }
+}
+
+func (rb *messageRingBuffer) cap() uint64 {
+ return uint64(len(rb.buf))
+}
+
+func (rb *messageRingBuffer) Append(msg *irc.Message) uint64 {
+ seq := rb.cur
+ i := int(seq % rb.cap())
+ rb.buf[i] = msg
+ rb.cur++
+ return seq
+}
+
+func (rb *messageRingBuffer) LoadLatestSeq(seq uint64, limit int) ([]*irc.Message, error) {
+ if seq > rb.cur {
+ return nil, fmt.Errorf("loading messages from sequence number (%v) greater than current (%v)", seq, rb.cur)
+ } else if seq == rb.cur {
+ return nil, nil
+ }
+
+ // The query excludes the message with the sequence number seq
+ diff := rb.cur - seq - 1
+ if diff > rb.cap() {
+ // We dropped diff - cap entries
+ diff = rb.cap()
+ }
+ if int(diff) > limit {
+ diff = uint64(limit)
+ }
+
+ l := make([]*irc.Message, int(diff))
+ for i := 0; i < int(diff); i++ {
+ j := int((rb.cur - diff + uint64(i)) % rb.cap())
+ l[i] = rb.buf[j]
+ }
+
+ return l, nil
+}
--- /dev/null
+//go:build !go1.16
+// +build !go1.16
+
+package suika
+
+import (
+ "strings"
+)
+
+func isErrClosed(err error) bool {
+ return err != nil && strings.Contains(err.Error(), "use of closed network connection")
+}
--- /dev/null
+//go:build go1.16
+// +build go1.16
+
+package suika
+
+import (
+ "errors"
+ "net"
+)
+
+func isErrClosed(err error) bool {
+ return errors.Is(err, net.ErrClosed)
+}
--- /dev/null
+package suika
+
+import (
+ "math/rand"
+ "time"
+)
+
+// backoffer implements a simple exponential backoff.
+type backoffer struct {
+ min, max, jitter time.Duration
+ n int64
+}
+
+func newBackoffer(min, max, jitter time.Duration) *backoffer {
+ return &backoffer{min: min, max: max, jitter: jitter}
+}
+
+func (b *backoffer) Reset() {
+ b.n = 0
+}
+
+func (b *backoffer) Next() time.Duration {
+ if b.n == 0 {
+ b.n = 1
+ return 0
+ }
+
+ d := time.Duration(b.n) * b.min
+ if d > b.max {
+ d = b.max
+ } else {
+ b.n *= 2
+ }
+
+ if b.jitter != 0 {
+ d += time.Duration(rand.Int63n(int64(b.jitter)))
+ }
+
+ return d
+}
--- /dev/null
+#!/bin/sh
+# $TheSupernovaDuo$
+# vim: ft=sh
+
+# PROVIDE: suika
+# REQUIRE: DAEMON
+# BEFORE: LOGIN
+# KEYWORD: shutdown
+
+. /etc/rc.subr
+
+name="suika"
+desc="A drunk IRC bouncer"
+rcvar="suika_enable"
+
+: ${suika_user="ircd"}
+
+command="%%PREFIX%%/bin/suika"
+pidfile="/var/run/suika.pid"
+required_files="%%PREFIX%%/etc/suika/config"
+
+start_cmd="suika_start"
+
+suika_start() {
+ /usr/sbin/daemon -f -p ${pidfile} -u ${suika_user} -l daemon ${command} --config ${required_files}
+}
+
+load_rc_config "$name"
+run_rc_command "$1"
--- /dev/null
+# $TheSupernovaDuo$
+cmd: %%PREFIX%%/bin/suika --config %%PREFIX%%/etc/suika/config
+user: ircd
--- /dev/null
+#!/bin/sh
+# $TheSupernovaDuo$
+# vim: ft=sh
+
+# PROVIDE: suika
+# REQUIRE: DAEMON
+# BEFORE: LOGIN
+# KEYWORD: shutdown
+
+. /etc/rc.subr
+
+name="suika"
+rcvar="${name}"
+command="%%PREFIX/bin/${name}"
+command_args="--config %%PREFIX%%/etc/suika/config"
+pidfile="/var/run/${name}.pid"
+start_cmd="${name}_start"
+
+suika_start() {
+ printf "Starting %s..." "${name}"
+ ${command} ${command_args}
+ pgrep -n ${name} > ${pidfile}
+}
+
+load_rc_config ${name}
+run_rc_command "$1"
+
+
--- /dev/null
+#!/bin/ksh
+# $TheSupernovaDuo$
+# vim: ft=sh
+
+daemon="%%PREFIX%%/bin/suika"
+daemon_args="--config %%PREFIX%%/etc/suika/config"
+
+. /etc/rc.d/rc.subr
+
+rc_bg=YES
+
+rc_cmd "$1"
--- /dev/null
+# $TheSupernovaDuo$
+# vim: ft=confini
+[Unit]
+Description=A drunk IRC bouncer
+After=network.target
+Wants=network.target
+StartLimitBurst=5
+StartLimitIntervalSec=1
+[Service]
+Type=simple
+Restart=on-abnormal
+RestartSec=1
+User=suika
+ExecStart=%%PREFIX%%/bin/suika --config %%PREFIX%%/etc/suika/config
+[Install]
+WantedBy=multi-user.target
--- /dev/null
+package suika
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "log"
+ "net"
+ "runtime/debug"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "gopkg.in/irc.v3"
+)
+
+// TODO: make configurable
+var (
+ retryConnectMinDelay = time.Minute
+ retryConnectMaxDelay = 10 * time.Minute
+ retryConnectJitter = time.Minute
+ connectTimeout = 15 * time.Second
+ writeTimeout = 10 * time.Second
+ upstreamMessageDelay = 2 * time.Second
+ upstreamMessageBurst = 10
+ backlogTimeout = 10 * time.Second
+ handleDownstreamMessageTimeout = 10 * time.Second
+ downstreamRegisterTimeout = 30 * time.Second
+ chatHistoryLimit = 1000
+ backlogLimit = 4000
+)
+
+type Logger interface {
+ Printf(format string, v ...interface{})
+ Debugf(format string, v ...interface{})
+}
+
+type logger struct {
+ *log.Logger
+ debug bool
+}
+
+func (l logger) Debugf(format string, v ...interface{}) {
+ if !l.debug {
+ return
+ }
+ l.Logger.Printf(format, v...)
+}
+
+func NewLogger(out io.Writer, debug bool) Logger {
+ return logger{
+ Logger: log.New(log.Writer(), "", log.LstdFlags),
+ debug: debug,
+ }
+}
+
+type prefixLogger struct {
+ logger Logger
+ prefix string
+}
+
+var _ Logger = (*prefixLogger)(nil)
+
+func (l *prefixLogger) Printf(format string, v ...interface{}) {
+ v = append([]interface{}{l.prefix}, v...)
+ l.logger.Printf("%v"+format, v...)
+}
+
+func (l *prefixLogger) Debugf(format string, v ...interface{}) {
+ v = append([]interface{}{l.prefix}, v...)
+ l.logger.Debugf("%v"+format, v...)
+}
+
+type int64Gauge struct {
+ v int64 // atomic
+}
+
+func (g *int64Gauge) Add(delta int64) {
+ atomic.AddInt64(&g.v, delta)
+}
+
+func (g *int64Gauge) Value() int64 {
+ return atomic.LoadInt64(&g.v)
+}
+
+func (g *int64Gauge) Float64() float64 {
+ return float64(g.Value())
+}
+
+type retryListener struct {
+ net.Listener
+ Logger Logger
+
+ delay time.Duration
+}
+
+func (ln *retryListener) Accept() (net.Conn, error) {
+ for {
+ conn, err := ln.Listener.Accept()
+ if ne, ok := err.(net.Error); ok && ne.Temporary() {
+ if ln.delay == 0 {
+ ln.delay = 5 * time.Millisecond
+ } else {
+ ln.delay *= 2
+ }
+ if max := 1 * time.Second; ln.delay > max {
+ ln.delay = max
+ }
+ if ln.Logger != nil {
+ ln.Logger.Printf("accept error (retrying in %v): %v", ln.delay, err)
+ }
+ time.Sleep(ln.delay)
+ } else {
+ ln.delay = 0
+ return conn, err
+ }
+ }
+}
+
+type Config struct {
+ Hostname string
+ Title string
+ LogPath string
+ MaxUserNetworks int
+ MultiUpstream bool
+ MOTD string
+ UpstreamUserIPs []*net.IPNet
+}
+
+type Server struct {
+ Logger Logger
+
+ config atomic.Value // *Config
+ db Database
+ stopWG sync.WaitGroup
+
+ lock sync.Mutex
+ listeners map[net.Listener]struct{}
+ users map[string]*user
+}
+
+func NewServer(db Database) *Server {
+ srv := &Server{
+ Logger: NewLogger(log.Writer(), true),
+ db: db,
+ listeners: make(map[net.Listener]struct{}),
+ users: make(map[string]*user),
+ }
+ srv.config.Store(&Config{
+ Hostname: "localhost",
+ MaxUserNetworks: -1,
+ MultiUpstream: true,
+ })
+ return srv
+}
+
+func (s *Server) prefix() *irc.Prefix {
+ return &irc.Prefix{Name: s.Config().Hostname}
+}
+
+func (s *Server) Config() *Config {
+ return s.config.Load().(*Config)
+}
+
+func (s *Server) SetConfig(cfg *Config) {
+ s.config.Store(cfg)
+}
+
+func (s *Server) Start() error {
+ users, err := s.db.ListUsers(context.TODO())
+ if err != nil {
+ return err
+ }
+
+ s.lock.Lock()
+ for i := range users {
+ s.addUserLocked(&users[i])
+ }
+ s.lock.Unlock()
+
+ return nil
+}
+
+func (s *Server) Shutdown() {
+ s.lock.Lock()
+ for ln := range s.listeners {
+ if err := ln.Close(); err != nil {
+ s.Logger.Printf("failed to stop listener: %v", err)
+ }
+ }
+ for _, u := range s.users {
+ u.events <- eventStop{}
+ }
+ s.lock.Unlock()
+
+ s.stopWG.Wait()
+
+ if err := s.db.Close(); err != nil {
+ s.Logger.Printf("failed to close DB: %v", err)
+ }
+}
+
+func (s *Server) createUser(ctx context.Context, user *User) (*user, error) {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+
+ if _, ok := s.users[user.Username]; ok {
+ return nil, fmt.Errorf("user %q already exists", user.Username)
+ }
+
+ err := s.db.StoreUser(ctx, user)
+ if err != nil {
+ return nil, fmt.Errorf("could not create user in db: %v", err)
+ }
+
+ return s.addUserLocked(user), nil
+}
+
+func (s *Server) forEachUser(f func(*user)) {
+ s.lock.Lock()
+ for _, u := range s.users {
+ f(u)
+ }
+ s.lock.Unlock()
+}
+
+func (s *Server) getUser(name string) *user {
+ s.lock.Lock()
+ u := s.users[name]
+ s.lock.Unlock()
+ return u
+}
+
+func (s *Server) addUserLocked(user *User) *user {
+ s.Logger.Printf("starting bouncer for user %q", user.Username)
+ u := newUser(s, user)
+ s.users[u.Username] = u
+
+ s.stopWG.Add(1)
+
+ go func() {
+ defer func() {
+ if err := recover(); err != nil {
+ s.Logger.Printf("panic serving user %q: %v\n%v", user.Username, err, debug.Stack())
+ }
+
+ s.lock.Lock()
+ delete(s.users, u.Username)
+ s.lock.Unlock()
+
+ s.stopWG.Done()
+ }()
+
+ u.run()
+ }()
+
+ return u
+}
+
+var lastDownstreamID uint64 = 0
+
+func (s *Server) handle(ic ircConn) {
+ defer func() {
+ if err := recover(); err != nil {
+ s.Logger.Printf("panic serving downstream %q: %v\n%v", ic.RemoteAddr(), err, debug.Stack())
+ }
+ }()
+
+ id := atomic.AddUint64(&lastDownstreamID, 1)
+ dc := newDownstreamConn(s, ic, id)
+ if err := dc.runUntilRegistered(); err != nil {
+ if !errors.Is(err, io.EOF) {
+ dc.logger.Printf("%v", err)
+ }
+ } else {
+ dc.user.events <- eventDownstreamConnected{dc}
+ if err := dc.readMessages(dc.user.events); err != nil {
+ dc.logger.Printf("%v", err)
+ }
+ dc.user.events <- eventDownstreamDisconnected{dc}
+ }
+ dc.Close()
+}
+
+func (s *Server) Serve(ln net.Listener) error {
+ ln = &retryListener{
+ Listener: ln,
+ Logger: &prefixLogger{logger: s.Logger, prefix: fmt.Sprintf("listener %v: ", ln.Addr())},
+ }
+
+ s.lock.Lock()
+ s.listeners[ln] = struct{}{}
+ s.lock.Unlock()
+
+ s.stopWG.Add(1)
+
+ defer func() {
+ s.lock.Lock()
+ delete(s.listeners, ln)
+ s.lock.Unlock()
+
+ s.stopWG.Done()
+ }()
+
+ for {
+ conn, err := ln.Accept()
+ if isErrClosed(err) {
+ return nil
+ } else if err != nil {
+ return fmt.Errorf("failed to accept connection: %v", err)
+ }
+
+ go s.handle(newNetIRCConn(conn))
+ }
+}
+
+type ServerStats struct {
+ Users int
+ Downstreams int64
+ Upstreams int64
+}
+
+func (s *Server) Stats() *ServerStats {
+ var stats ServerStats
+ s.lock.Lock()
+ stats.Users = len(s.users)
+ s.lock.Unlock()
+ return &stats
+}
--- /dev/null
+package suika
+
+import (
+ "context"
+ "net"
+ "testing"
+
+ "golang.org/x/crypto/bcrypt"
+ "gopkg.in/irc.v3"
+)
+
+var testServerPrefix = &irc.Prefix{Name: "suika-test-server"}
+
+const (
+ testUsername = "suika-test-user"
+ testPassword = testUsername
+)
+
+func createTempSqliteDB(t *testing.T) Database {
+ db, err := OpenDB("sqlite3", ":memory:")
+ if err != nil {
+ t.Fatalf("failed to create temporary SQLite database: %v", err)
+ }
+ // :memory: will open a separate database for each new connection. Make
+ // sure the sql package only uses a single connection. An alternative
+ // solution is to use "file::memory:?cache=shared".
+ db.(*SqliteDB).db.SetMaxOpenConns(1)
+ return db
+}
+
+func createTempPostgresDB(t *testing.T) Database {
+ db := &PostgresDB{db: openTempPostgresDB(t)}
+ if err := db.upgrade(); err != nil {
+ t.Fatalf("failed to upgrade PostgreSQL database: %v", err)
+ }
+
+ return db
+}
+
+func createTestUser(t *testing.T, db Database) *User {
+ hashed, err := bcrypt.GenerateFromPassword([]byte(testPassword), bcrypt.DefaultCost)
+ if err != nil {
+ t.Fatalf("failed to generate bcrypt hash: %v", err)
+ }
+
+ record := &User{Username: testUsername, Password: string(hashed)}
+ if err := db.StoreUser(context.Background(), record); err != nil {
+ t.Fatalf("failed to store test user: %v", err)
+ }
+
+ return record
+}
+
+func createTestDownstream(t *testing.T, srv *Server) ircConn {
+ c1, c2 := net.Pipe()
+ go srv.handle(newNetIRCConn(c1))
+ return newNetIRCConn(c2)
+}
+
+func createTestUpstream(t *testing.T, db Database, user *User) (*Network, net.Listener) {
+ ln, err := net.Listen("tcp", "localhost:0")
+ if err != nil {
+ t.Fatalf("failed to create TCP listener: %v", err)
+ }
+
+ network := &Network{
+ Name: "testnet",
+ Addr: "irc://" + ln.Addr().String(),
+ Nick: user.Username,
+ Enabled: true,
+ }
+ if err := db.StoreNetwork(context.Background(), user.ID, network); err != nil {
+ t.Fatalf("failed to store test network: %v", err)
+ }
+
+ return network, ln
+}
+
+func mustAccept(t *testing.T, ln net.Listener) ircConn {
+ c, err := ln.Accept()
+ if err != nil {
+ t.Fatalf("failed accepting connection: %v", err)
+ }
+ return newNetIRCConn(c)
+}
+
+func expectMessage(t *testing.T, c ircConn, cmd string) *irc.Message {
+ msg, err := c.ReadMessage()
+ if err != nil {
+ t.Fatalf("failed to read IRC message (want %q): %v", cmd, err)
+ }
+ if msg.Command != cmd {
+ t.Fatalf("invalid message received: want %q, got: %v", cmd, msg)
+ }
+ return msg
+}
+
+func registerDownstreamConn(t *testing.T, c ircConn, network *Network) {
+ c.WriteMessage(&irc.Message{
+ Command: "PASS",
+ Params: []string{testPassword},
+ })
+ c.WriteMessage(&irc.Message{
+ Command: "NICK",
+ Params: []string{testUsername},
+ })
+ c.WriteMessage(&irc.Message{
+ Command: "USER",
+ Params: []string{testUsername + "/" + network.Name, "0", "*", testUsername},
+ })
+
+ expectMessage(t, c, irc.RPL_WELCOME)
+}
+
+func registerUpstreamConn(t *testing.T, c ircConn) {
+ msg := expectMessage(t, c, "CAP")
+ if msg.Params[0] != "LS" {
+ t.Fatalf("invalid CAP LS: got: %v", msg)
+ }
+ msg = expectMessage(t, c, "NICK")
+ nick := msg.Params[0]
+ if nick != testUsername {
+ t.Fatalf("invalid NICK: want %q, got: %v", testUsername, msg)
+ }
+ expectMessage(t, c, "USER")
+
+ c.WriteMessage(&irc.Message{
+ Prefix: testServerPrefix,
+ Command: irc.RPL_WELCOME,
+ Params: []string{nick, "Welcome!"},
+ })
+ c.WriteMessage(&irc.Message{
+ Prefix: testServerPrefix,
+ Command: irc.RPL_YOURHOST,
+ Params: []string{nick, "Your host is suika-test-server"},
+ })
+ c.WriteMessage(&irc.Message{
+ Prefix: testServerPrefix,
+ Command: irc.RPL_CREATED,
+ Params: []string{nick, "Who cares when the server was created?"},
+ })
+ c.WriteMessage(&irc.Message{
+ Prefix: testServerPrefix,
+ Command: irc.RPL_MYINFO,
+ Params: []string{nick, testServerPrefix.Name, "suika", "aiwroO", "OovaimnqpsrtklbeI"},
+ })
+ c.WriteMessage(&irc.Message{
+ Prefix: testServerPrefix,
+ Command: irc.ERR_NOMOTD,
+ Params: []string{nick, "No MOTD"},
+ })
+}
+
+func testServer(t *testing.T, db Database) {
+ user := createTestUser(t, db)
+ network, upstream := createTestUpstream(t, db, user)
+ defer upstream.Close()
+
+ srv := NewServer(db)
+ if err := srv.Start(); err != nil {
+ t.Fatalf("failed to start server: %v", err)
+ }
+ defer srv.Shutdown()
+
+ uc := mustAccept(t, upstream)
+ defer uc.Close()
+ registerUpstreamConn(t, uc)
+
+ dc := createTestDownstream(t, srv)
+ defer dc.Close()
+ registerDownstreamConn(t, dc, network)
+
+ noticeText := "This is a very important server notice."
+ uc.WriteMessage(&irc.Message{
+ Prefix: testServerPrefix,
+ Command: "NOTICE",
+ Params: []string{testUsername, noticeText},
+ })
+
+ var msg *irc.Message
+ for {
+ var err error
+ msg, err = dc.ReadMessage()
+ if err != nil {
+ t.Fatalf("failed to read IRC message: %v", err)
+ }
+ if msg.Command == "NOTICE" {
+ break
+ }
+ }
+
+ if msg.Params[1] != noticeText {
+ t.Fatalf("invalid NOTICE text: want %q, got: %v", noticeText, msg)
+ }
+}
+
+func TestServer(t *testing.T) {
+ t.Run("sqlite", func(t *testing.T) {
+ db := createTempSqliteDB(t)
+ testServer(t, db)
+ })
+
+ t.Run("postgres", func(t *testing.T) {
+ db := createTempPostgresDB(t)
+ testServer(t, db)
+ })
+}
--- /dev/null
+package suika
+
+import (
+ "context"
+ "crypto/sha1"
+ "crypto/sha256"
+ "crypto/sha512"
+ "encoding/hex"
+ "flag"
+ "fmt"
+ "io/ioutil"
+ "sort"
+ "strconv"
+ "strings"
+ "time"
+ "unicode"
+
+ "golang.org/x/crypto/bcrypt"
+ "gopkg.in/irc.v3"
+)
+
+const (
+ serviceNick = "BouncerServ"
+ serviceNickCM = "bouncerserv"
+ serviceRealname = "suika bouncer service"
+)
+
+// maxRSABits is the maximum number of RSA key bits used when generating a new
+// private key.
+const maxRSABits = 8192
+
+var servicePrefix = &irc.Prefix{
+ Name: serviceNick,
+ User: serviceNick,
+ Host: serviceNick,
+}
+
+type serviceCommandSet map[string]*serviceCommand
+
+type serviceCommand struct {
+ usage string
+ desc string
+ handle func(ctx context.Context, dc *downstreamConn, params []string) error
+ children serviceCommandSet
+ admin bool
+}
+
+func sendServiceNOTICE(dc *downstreamConn, text string) {
+ dc.SendMessage(&irc.Message{
+ Prefix: servicePrefix,
+ Command: "NOTICE",
+ Params: []string{dc.nick, text},
+ })
+}
+
+func sendServicePRIVMSG(dc *downstreamConn, text string) {
+ dc.SendMessage(&irc.Message{
+ Prefix: servicePrefix,
+ Command: "PRIVMSG",
+ Params: []string{dc.nick, text},
+ })
+}
+
+func splitWords(s string) ([]string, error) {
+ var words []string
+ var lastWord strings.Builder
+ escape := false
+ prev := ' '
+ wordDelim := ' '
+
+ for _, r := range s {
+ if escape {
+ // last char was a backslash, write the byte as-is.
+ lastWord.WriteRune(r)
+ escape = false
+ } else if r == '\\' {
+ escape = true
+ } else if wordDelim == ' ' && unicode.IsSpace(r) {
+ // end of last word
+ if !unicode.IsSpace(prev) {
+ words = append(words, lastWord.String())
+ lastWord.Reset()
+ }
+ } else if r == wordDelim {
+ // wordDelim is either " or ', switch back to
+ // space-delimited words.
+ wordDelim = ' '
+ } else if r == '"' || r == '\'' {
+ if wordDelim == ' ' {
+ // start of (double-)quoted word
+ wordDelim = r
+ } else {
+ // either wordDelim is " and r is ' or vice-versa
+ lastWord.WriteRune(r)
+ }
+ } else {
+ lastWord.WriteRune(r)
+ }
+
+ prev = r
+ }
+
+ if !unicode.IsSpace(prev) {
+ words = append(words, lastWord.String())
+ }
+
+ if wordDelim != ' ' {
+ return nil, fmt.Errorf("unterminated quoted string")
+ }
+ if escape {
+ return nil, fmt.Errorf("unterminated backslash sequence")
+ }
+
+ return words, nil
+}
+
+func handleServicePRIVMSG(ctx context.Context, dc *downstreamConn, text string) {
+ words, err := splitWords(text)
+ if err != nil {
+ sendServicePRIVMSG(dc, fmt.Sprintf(`error: failed to parse command: %v`, err))
+ return
+ }
+
+ cmd, params, err := serviceCommands.Get(words)
+ if err != nil {
+ sendServicePRIVMSG(dc, fmt.Sprintf(`error: %v (type "help" for a list of commands)`, err))
+ return
+ }
+ if cmd.admin && !dc.user.Admin {
+ sendServicePRIVMSG(dc, "error: you must be an admin to use this command")
+ return
+ }
+
+ if cmd.handle == nil {
+ if len(cmd.children) > 0 {
+ var l []string
+ appendServiceCommandSetHelp(cmd.children, words, dc.user.Admin, &l)
+ sendServicePRIVMSG(dc, "available commands: "+strings.Join(l, ", "))
+ } else {
+ // Pretend the command does not exist if it has neither children nor handler.
+ // This is obviously a bug but it is better to not die anyway.
+ dc.logger.Printf("command without handler and subcommands invoked:", words[0])
+ sendServicePRIVMSG(dc, fmt.Sprintf("command %q not found", words[0]))
+ }
+ return
+ }
+
+ if err := cmd.handle(ctx, dc, params); err != nil {
+ sendServicePRIVMSG(dc, fmt.Sprintf("error: %v", err))
+ }
+}
+
+func (cmds serviceCommandSet) Get(params []string) (*serviceCommand, []string, error) {
+ if len(params) == 0 {
+ return nil, nil, fmt.Errorf("no command specified")
+ }
+
+ name := params[0]
+ params = params[1:]
+
+ cmd, ok := cmds[name]
+ if !ok {
+ for k := range cmds {
+ if !strings.HasPrefix(k, name) {
+ continue
+ }
+ if cmd != nil {
+ return nil, params, fmt.Errorf("command %q is ambiguous", name)
+ }
+ cmd = cmds[k]
+ }
+ }
+ if cmd == nil {
+ return nil, params, fmt.Errorf("command %q not found", name)
+ }
+
+ if len(params) == 0 || len(cmd.children) == 0 {
+ return cmd, params, nil
+ }
+ return cmd.children.Get(params)
+}
+
+func (cmds serviceCommandSet) Names() []string {
+ l := make([]string, 0, len(cmds))
+ for name := range cmds {
+ l = append(l, name)
+ }
+ sort.Strings(l)
+ return l
+}
+
+var serviceCommands serviceCommandSet
+
+func init() {
+ serviceCommands = serviceCommandSet{
+ "help": {
+ usage: "[command]",
+ desc: "print help message",
+ handle: handleServiceHelp,
+ },
+ "network": {
+ children: serviceCommandSet{
+ "create": {
+ usage: "-addr <addr> [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick] [-enabled enabled] [-connect-command command]...",
+ desc: "add a new network",
+ handle: handleServiceNetworkCreate,
+ },
+ "status": {
+ desc: "show a list of saved networks and their current status",
+ handle: handleServiceNetworkStatus,
+ },
+ "update": {
+ usage: "[name] [-addr addr] [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick] [-enabled enabled] [-connect-command command]...",
+ desc: "update a network",
+ handle: handleServiceNetworkUpdate,
+ },
+ "delete": {
+ usage: "[name]",
+ desc: "delete a network",
+ handle: handleServiceNetworkDelete,
+ },
+ "quote": {
+ usage: "[name] <command>",
+ desc: "send a raw line to a network",
+ handle: handleServiceNetworkQuote,
+ },
+ },
+ },
+ "certfp": {
+ children: serviceCommandSet{
+ "generate": {
+ usage: "[-key-type rsa|ecdsa|ed25519] [-bits N] [-network name]",
+ desc: "generate a new self-signed certificate, defaults to using RSA-3072 key",
+ handle: handleServiceCertFPGenerate,
+ },
+ "fingerprint": {
+ usage: "[-network name]",
+ desc: "show fingerprints of certificate",
+ handle: handleServiceCertFPFingerprints,
+ },
+ },
+ },
+ "sasl": {
+ children: serviceCommandSet{
+ "status": {
+ usage: "[-network name]",
+ desc: "show SASL status",
+ handle: handleServiceSASLStatus,
+ },
+ "set-plain": {
+ usage: "[-network name] <username> <password>",
+ desc: "set SASL PLAIN credentials",
+ handle: handleServiceSASLSetPlain,
+ },
+ "reset": {
+ usage: "[-network name]",
+ desc: "disable SASL authentication and remove stored credentials",
+ handle: handleServiceSASLReset,
+ },
+ },
+ },
+ "user": {
+ children: serviceCommandSet{
+ "create": {
+ usage: "-username <username> -password <password> [-realname <realname>] [-admin]",
+ desc: "create a new suika user",
+ handle: handleUserCreate,
+ admin: true,
+ },
+ "update": {
+ usage: "[-password <password>] [-realname <realname>]",
+ desc: "update the current user",
+ handle: handleUserUpdate,
+ },
+ "delete": {
+ usage: "<username>",
+ desc: "delete a user",
+ handle: handleUserDelete,
+ admin: true,
+ },
+ },
+ },
+ "channel": {
+ children: serviceCommandSet{
+ "status": {
+ usage: "[-network name]",
+ desc: "show a list of saved channels and their current status",
+ handle: handleServiceChannelStatus,
+ },
+ "update": {
+ usage: "<name> [-relay-detached <default|none|highlight|message>] [-reattach-on <default|none|highlight|message>] [-detach-after <duration>] [-detach-on <default|none|highlight|message>]",
+ desc: "update a channel",
+ handle: handleServiceChannelUpdate,
+ },
+ },
+ },
+ "server": {
+ children: serviceCommandSet{
+ "status": {
+ desc: "show server statistics",
+ handle: handleServiceServerStatus,
+ admin: true,
+ },
+ "notice": {
+ desc: "broadcast a notice to all connected bouncer users",
+ handle: handleServiceServerNotice,
+ admin: true,
+ },
+ },
+ admin: true,
+ },
+ }
+}
+
+func appendServiceCommandSetHelp(cmds serviceCommandSet, prefix []string, admin bool, l *[]string) {
+ for _, name := range cmds.Names() {
+ cmd := cmds[name]
+ if cmd.admin && !admin {
+ continue
+ }
+ words := append(prefix, name)
+ if len(cmd.children) == 0 {
+ s := strings.Join(words, " ")
+ *l = append(*l, s)
+ } else {
+ appendServiceCommandSetHelp(cmd.children, words, admin, l)
+ }
+ }
+}
+
+func handleServiceHelp(ctx context.Context, dc *downstreamConn, params []string) error {
+ if len(params) > 0 {
+ cmd, rest, err := serviceCommands.Get(params)
+ if err != nil {
+ return err
+ }
+ words := params[:len(params)-len(rest)]
+
+ if len(cmd.children) > 0 {
+ var l []string
+ appendServiceCommandSetHelp(cmd.children, words, dc.user.Admin, &l)
+ sendServicePRIVMSG(dc, "available commands: "+strings.Join(l, ", "))
+ } else {
+ text := strings.Join(words, " ")
+ if cmd.usage != "" {
+ text += " " + cmd.usage
+ }
+ text += ": " + cmd.desc
+
+ sendServicePRIVMSG(dc, text)
+ }
+ } else {
+ var l []string
+ appendServiceCommandSetHelp(serviceCommands, nil, dc.user.Admin, &l)
+ sendServicePRIVMSG(dc, "available commands: "+strings.Join(l, ", "))
+ }
+ return nil
+}
+
+func newFlagSet() *flag.FlagSet {
+ fs := flag.NewFlagSet("", flag.ContinueOnError)
+ fs.SetOutput(ioutil.Discard)
+ return fs
+}
+
+type stringSliceFlag []string
+
+func (v *stringSliceFlag) String() string {
+ return fmt.Sprint([]string(*v))
+}
+
+func (v *stringSliceFlag) Set(s string) error {
+ *v = append(*v, s)
+ return nil
+}
+
+// stringPtrFlag is a flag value populating a string pointer. This allows to
+// disambiguate between a flag that hasn't been set and a flag that has been
+// set to an empty string.
+type stringPtrFlag struct {
+ ptr **string
+}
+
+func (f stringPtrFlag) String() string {
+ if f.ptr == nil || *f.ptr == nil {
+ return ""
+ }
+ return **f.ptr
+}
+
+func (f stringPtrFlag) Set(s string) error {
+ *f.ptr = &s
+ return nil
+}
+
+type boolPtrFlag struct {
+ ptr **bool
+}
+
+func (f boolPtrFlag) String() string {
+ if f.ptr == nil || *f.ptr == nil {
+ return "<nil>"
+ }
+ return strconv.FormatBool(**f.ptr)
+}
+
+func (f boolPtrFlag) Set(s string) error {
+ v, err := strconv.ParseBool(s)
+ if err != nil {
+ return err
+ }
+ *f.ptr = &v
+ return nil
+}
+
+func getNetworkFromArg(dc *downstreamConn, params []string) (*network, []string, error) {
+ name, params := popArg(params)
+ if name == "" {
+ if dc.network == nil {
+ return nil, params, fmt.Errorf("no network selected, a name argument is required")
+ }
+ return dc.network, params, nil
+ } else {
+ net := dc.user.getNetwork(name)
+ if net == nil {
+ return nil, params, fmt.Errorf("unknown network %q", name)
+ }
+ return net, params, nil
+ }
+}
+
+type networkFlagSet struct {
+ *flag.FlagSet
+ Addr, Name, Nick, Username, Pass, Realname *string
+ Enabled *bool
+ ConnectCommands []string
+}
+
+func newNetworkFlagSet() *networkFlagSet {
+ fs := &networkFlagSet{FlagSet: newFlagSet()}
+ fs.Var(stringPtrFlag{&fs.Addr}, "addr", "")
+ fs.Var(stringPtrFlag{&fs.Name}, "name", "")
+ fs.Var(stringPtrFlag{&fs.Nick}, "nick", "")
+ fs.Var(stringPtrFlag{&fs.Username}, "username", "")
+ fs.Var(stringPtrFlag{&fs.Pass}, "pass", "")
+ fs.Var(stringPtrFlag{&fs.Realname}, "realname", "")
+ fs.Var(boolPtrFlag{&fs.Enabled}, "enabled", "")
+ fs.Var((*stringSliceFlag)(&fs.ConnectCommands), "connect-command", "")
+ return fs
+}
+
+func (fs *networkFlagSet) update(network *Network) error {
+ if fs.Addr != nil {
+ if addrParts := strings.SplitN(*fs.Addr, "://", 2); len(addrParts) == 2 {
+ scheme := addrParts[0]
+ switch scheme {
+ case "ircs", "irc", "unix":
+ default:
+ return fmt.Errorf("unknown scheme %q (supported schemes: ircs, irc, unix)", scheme)
+ }
+ }
+ network.Addr = *fs.Addr
+ }
+ if fs.Name != nil {
+ network.Name = *fs.Name
+ }
+ if fs.Nick != nil {
+ network.Nick = *fs.Nick
+ }
+ if fs.Username != nil {
+ network.Username = *fs.Username
+ }
+ if fs.Pass != nil {
+ network.Pass = *fs.Pass
+ }
+ if fs.Realname != nil {
+ network.Realname = *fs.Realname
+ }
+ if fs.Enabled != nil {
+ network.Enabled = *fs.Enabled
+ }
+ if fs.ConnectCommands != nil {
+ if len(fs.ConnectCommands) == 1 && fs.ConnectCommands[0] == "" {
+ network.ConnectCommands = nil
+ } else {
+ for _, command := range fs.ConnectCommands {
+ _, err := irc.ParseMessage(command)
+ if err != nil {
+ return fmt.Errorf("flag -connect-command must be a valid raw irc command string: %q: %v", command, err)
+ }
+ }
+ network.ConnectCommands = fs.ConnectCommands
+ }
+ }
+ return nil
+}
+
+func handleServiceNetworkCreate(ctx context.Context, dc *downstreamConn, params []string) error {
+ fs := newNetworkFlagSet()
+ if err := fs.Parse(params); err != nil {
+ return err
+ }
+ if fs.Addr == nil {
+ return fmt.Errorf("flag -addr is required")
+ }
+
+ record := &Network{
+ Addr: *fs.Addr,
+ Enabled: true,
+ }
+ if err := fs.update(record); err != nil {
+ return err
+ }
+
+ network, err := dc.user.createNetwork(ctx, record)
+ if err != nil {
+ return fmt.Errorf("could not create network: %v", err)
+ }
+
+ sendServicePRIVMSG(dc, fmt.Sprintf("created network %q", network.GetName()))
+ return nil
+}
+
+func handleServiceNetworkStatus(ctx context.Context, dc *downstreamConn, params []string) error {
+ n := 0
+ for _, net := range dc.user.networks {
+ var statuses []string
+ var details string
+ if uc := net.conn; uc != nil {
+ if dc.nick != uc.nick {
+ statuses = append(statuses, "connected as "+uc.nick)
+ } else {
+ statuses = append(statuses, "connected")
+ }
+ details = fmt.Sprintf("%v channels", uc.channels.Len())
+ } else if !net.Enabled {
+ statuses = append(statuses, "disabled")
+ } else {
+ statuses = append(statuses, "disconnected")
+ if net.lastError != nil {
+ details = net.lastError.Error()
+ }
+ }
+
+ if net == dc.network {
+ statuses = append(statuses, "current")
+ }
+
+ name := net.GetName()
+ if name != net.Addr {
+ name = fmt.Sprintf("%v (%v)", name, net.Addr)
+ }
+
+ s := fmt.Sprintf("%v [%v]", name, strings.Join(statuses, ", "))
+ if details != "" {
+ s += ": " + details
+ }
+ sendServicePRIVMSG(dc, s)
+
+ n++
+ }
+
+ if n == 0 {
+ sendServicePRIVMSG(dc, `No network configured, add one with "network create".`)
+ }
+
+ return nil
+}
+
+func handleServiceNetworkUpdate(ctx context.Context, dc *downstreamConn, params []string) error {
+ net, params, err := getNetworkFromArg(dc, params)
+ if err != nil {
+ return err
+ }
+
+ fs := newNetworkFlagSet()
+ if err := fs.Parse(params); err != nil {
+ return err
+ }
+
+ record := net.Network // copy network record because we'll mutate it
+ if err := fs.update(&record); err != nil {
+ return err
+ }
+
+ network, err := dc.user.updateNetwork(ctx, &record)
+ if err != nil {
+ return fmt.Errorf("could not update network: %v", err)
+ }
+
+ sendServicePRIVMSG(dc, fmt.Sprintf("updated network %q", network.GetName()))
+ return nil
+}
+
+func handleServiceNetworkDelete(ctx context.Context, dc *downstreamConn, params []string) error {
+ net, params, err := getNetworkFromArg(dc, params)
+ if err != nil {
+ return err
+ }
+
+ if err := dc.user.deleteNetwork(ctx, net.ID); err != nil {
+ return err
+ }
+
+ sendServicePRIVMSG(dc, fmt.Sprintf("deleted network %q", net.GetName()))
+ return nil
+}
+
+func handleServiceNetworkQuote(ctx context.Context, dc *downstreamConn, params []string) error {
+ if len(params) != 1 && len(params) != 2 {
+ return fmt.Errorf("expected one or two arguments")
+ }
+
+ raw := params[len(params)-1]
+ params = params[:len(params)-1]
+
+ net, params, err := getNetworkFromArg(dc, params)
+ if err != nil {
+ return err
+ }
+
+ uc := net.conn
+ if uc == nil {
+ return fmt.Errorf("network %q is not currently connected", net.GetName())
+ }
+
+ m, err := irc.ParseMessage(raw)
+ if err != nil {
+ return fmt.Errorf("failed to parse command %q: %v", raw, err)
+ }
+ uc.SendMessage(ctx, m)
+
+ sendServicePRIVMSG(dc, fmt.Sprintf("sent command to %q", net.GetName()))
+ return nil
+}
+
+func sendCertfpFingerprints(dc *downstreamConn, cert []byte) {
+ sha1Sum := sha1.Sum(cert)
+ sendServicePRIVMSG(dc, "SHA-1 fingerprint: "+hex.EncodeToString(sha1Sum[:]))
+ sha256Sum := sha256.Sum256(cert)
+ sendServicePRIVMSG(dc, "SHA-256 fingerprint: "+hex.EncodeToString(sha256Sum[:]))
+ sha512Sum := sha512.Sum512(cert)
+ sendServicePRIVMSG(dc, "SHA-512 fingerprint: "+hex.EncodeToString(sha512Sum[:]))
+}
+
+func getNetworkFromFlag(dc *downstreamConn, name string) (*network, error) {
+ if name == "" {
+ if dc.network == nil {
+ return nil, fmt.Errorf("no network selected, -network is required")
+ }
+ return dc.network, nil
+ } else {
+ net := dc.user.getNetwork(name)
+ if net == nil {
+ return nil, fmt.Errorf("unknown network %q", name)
+ }
+ return net, nil
+ }
+}
+
+func handleServiceCertFPGenerate(ctx context.Context, dc *downstreamConn, params []string) error {
+ fs := newFlagSet()
+ netName := fs.String("network", "", "select a network")
+ keyType := fs.String("key-type", "rsa", "key type to generate (rsa, ecdsa, ed25519)")
+ bits := fs.Int("bits", 3072, "size of key to generate, meaningful only for RSA")
+
+ if err := fs.Parse(params); err != nil {
+ return err
+ }
+
+ if *bits <= 0 || *bits > maxRSABits {
+ return fmt.Errorf("invalid value for -bits")
+ }
+
+ net, err := getNetworkFromFlag(dc, *netName)
+ if err != nil {
+ return err
+ }
+
+ privKey, cert, err := generateCertFP(*keyType, *bits)
+ if err != nil {
+ return err
+ }
+
+ net.SASL.External.CertBlob = cert
+ net.SASL.External.PrivKeyBlob = privKey
+ net.SASL.Mechanism = "EXTERNAL"
+
+ if err := dc.srv.db.StoreNetwork(ctx, dc.user.ID, &net.Network); err != nil {
+ return err
+ }
+
+ sendServicePRIVMSG(dc, "certificate generated")
+ sendCertfpFingerprints(dc, cert)
+ return nil
+}
+
+func handleServiceCertFPFingerprints(ctx context.Context, dc *downstreamConn, params []string) error {
+ fs := newFlagSet()
+ netName := fs.String("network", "", "select a network")
+
+ if err := fs.Parse(params); err != nil {
+ return err
+ }
+
+ net, err := getNetworkFromFlag(dc, *netName)
+ if err != nil {
+ return err
+ }
+
+ if net.SASL.Mechanism != "EXTERNAL" {
+ return fmt.Errorf("CertFP not set up")
+ }
+
+ sendCertfpFingerprints(dc, net.SASL.External.CertBlob)
+ return nil
+}
+
+func handleServiceSASLStatus(ctx context.Context, dc *downstreamConn, params []string) error {
+ fs := newFlagSet()
+ netName := fs.String("network", "", "select a network")
+
+ if err := fs.Parse(params); err != nil {
+ return err
+ }
+
+ net, err := getNetworkFromFlag(dc, *netName)
+ if err != nil {
+ return err
+ }
+
+ switch net.SASL.Mechanism {
+ case "PLAIN":
+ sendServicePRIVMSG(dc, fmt.Sprintf("SASL PLAIN enabled with username %q", net.SASL.Plain.Username))
+ case "EXTERNAL":
+ sendServicePRIVMSG(dc, "SASL EXTERNAL (CertFP) enabled")
+ case "":
+ sendServicePRIVMSG(dc, "SASL is disabled")
+ }
+
+ if uc := net.conn; uc != nil {
+ if uc.account != "" {
+ sendServicePRIVMSG(dc, fmt.Sprintf("Authenticated on upstream network with account %q", uc.account))
+ } else {
+ sendServicePRIVMSG(dc, "Unauthenticated on upstream network")
+ }
+ } else {
+ sendServicePRIVMSG(dc, "Disconnected from upstream network")
+ }
+
+ return nil
+}
+
+func handleServiceSASLSetPlain(ctx context.Context, dc *downstreamConn, params []string) error {
+ fs := newFlagSet()
+ netName := fs.String("network", "", "select a network")
+
+ if err := fs.Parse(params); err != nil {
+ return err
+ }
+
+ if len(fs.Args()) != 2 {
+ return fmt.Errorf("expected exactly 2 arguments")
+ }
+
+ net, err := getNetworkFromFlag(dc, *netName)
+ if err != nil {
+ return err
+ }
+
+ net.SASL.Plain.Username = fs.Arg(0)
+ net.SASL.Plain.Password = fs.Arg(1)
+ net.SASL.Mechanism = "PLAIN"
+
+ if err := dc.srv.db.StoreNetwork(ctx, dc.user.ID, &net.Network); err != nil {
+ return err
+ }
+
+ sendServicePRIVMSG(dc, "credentials saved")
+ return nil
+}
+
+func handleServiceSASLReset(ctx context.Context, dc *downstreamConn, params []string) error {
+ fs := newFlagSet()
+ netName := fs.String("network", "", "select a network")
+
+ if err := fs.Parse(params); err != nil {
+ return err
+ }
+
+ net, err := getNetworkFromFlag(dc, *netName)
+ if err != nil {
+ return err
+ }
+
+ net.SASL.Plain.Username = ""
+ net.SASL.Plain.Password = ""
+ net.SASL.External.CertBlob = nil
+ net.SASL.External.PrivKeyBlob = nil
+ net.SASL.Mechanism = ""
+
+ if err := dc.srv.db.StoreNetwork(ctx, dc.user.ID, &net.Network); err != nil {
+ return err
+ }
+
+ sendServicePRIVMSG(dc, "credentials reset")
+ return nil
+}
+
+func handleUserCreate(ctx context.Context, dc *downstreamConn, params []string) error {
+ fs := newFlagSet()
+ username := fs.String("username", "", "")
+ password := fs.String("password", "", "")
+ realname := fs.String("realname", "", "")
+ admin := fs.Bool("admin", false, "")
+
+ if err := fs.Parse(params); err != nil {
+ return err
+ }
+ if *username == "" {
+ return fmt.Errorf("flag -username is required")
+ }
+ if *password == "" {
+ return fmt.Errorf("flag -password is required")
+ }
+
+ hashed, err := bcrypt.GenerateFromPassword([]byte(*password), bcrypt.DefaultCost)
+ if err != nil {
+ return fmt.Errorf("failed to hash password: %v", err)
+ }
+
+ user := &User{
+ Username: *username,
+ Password: string(hashed),
+ Realname: *realname,
+ Admin: *admin,
+ }
+ if _, err := dc.srv.createUser(ctx, user); err != nil {
+ return fmt.Errorf("could not create user: %v", err)
+ }
+
+ sendServicePRIVMSG(dc, fmt.Sprintf("created user %q", *username))
+ return nil
+}
+
+func popArg(params []string) (string, []string) {
+ if len(params) > 0 && !strings.HasPrefix(params[0], "-") {
+ return params[0], params[1:]
+ }
+ return "", params
+}
+
+func handleUserUpdate(ctx context.Context, dc *downstreamConn, params []string) error {
+ var password, realname *string
+ var admin *bool
+ fs := newFlagSet()
+ fs.Var(stringPtrFlag{&password}, "password", "")
+ fs.Var(stringPtrFlag{&realname}, "realname", "")
+ fs.Var(boolPtrFlag{&admin}, "admin", "")
+
+ username, params := popArg(params)
+ if err := fs.Parse(params); err != nil {
+ return err
+ }
+ if len(fs.Args()) > 0 {
+ return fmt.Errorf("unexpected argument")
+ }
+
+ var hashed *string
+ if password != nil {
+ hashedBytes, err := bcrypt.GenerateFromPassword([]byte(*password), bcrypt.DefaultCost)
+ if err != nil {
+ return fmt.Errorf("failed to hash password: %v", err)
+ }
+ hashedStr := string(hashedBytes)
+ hashed = &hashedStr
+ }
+
+ if username != "" && username != dc.user.Username {
+ if !dc.user.Admin {
+ return fmt.Errorf("you must be an admin to update other users")
+ }
+ if realname != nil {
+ return fmt.Errorf("cannot update -realname of other user")
+ }
+
+ u := dc.srv.getUser(username)
+ if u == nil {
+ return fmt.Errorf("unknown username %q", username)
+ }
+
+ done := make(chan error, 1)
+ event := eventUserUpdate{
+ password: hashed,
+ admin: admin,
+ done: done,
+ }
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ case u.events <- event:
+ }
+ // TODO: send context to the other side
+ if err := <-done; err != nil {
+ return err
+ }
+
+ sendServicePRIVMSG(dc, fmt.Sprintf("updated user %q", username))
+ } else {
+ // copy the user record because we'll mutate it
+ record := dc.user.User
+
+ if hashed != nil {
+ record.Password = *hashed
+ }
+ if realname != nil {
+ record.Realname = *realname
+ }
+ if admin != nil {
+ return fmt.Errorf("cannot update -admin of own user")
+ }
+
+ if err := dc.user.updateUser(ctx, &record); err != nil {
+ return err
+ }
+
+ sendServicePRIVMSG(dc, fmt.Sprintf("updated user %q", dc.user.Username))
+ }
+
+ return nil
+}
+
+func handleUserDelete(ctx context.Context, dc *downstreamConn, params []string) error {
+ if len(params) != 1 {
+ return fmt.Errorf("expected exactly one argument")
+ }
+ username := params[0]
+
+ u := dc.srv.getUser(username)
+ if u == nil {
+ return fmt.Errorf("unknown username %q", username)
+ }
+
+ u.stop()
+
+ if err := dc.srv.db.DeleteUser(ctx, u.ID); err != nil {
+ return fmt.Errorf("failed to delete user: %v", err)
+ }
+
+ sendServicePRIVMSG(dc, fmt.Sprintf("deleted user %q", username))
+ return nil
+}
+
+func handleServiceChannelStatus(ctx context.Context, dc *downstreamConn, params []string) error {
+ var defaultNetworkName string
+ if dc.network != nil {
+ defaultNetworkName = dc.network.GetName()
+ }
+
+ fs := newFlagSet()
+ networkName := fs.String("network", defaultNetworkName, "")
+
+ if err := fs.Parse(params); err != nil {
+ return err
+ }
+
+ n := 0
+
+ sendNetwork := func(net *network) {
+ var channels []*Channel
+ for _, entry := range net.channels.innerMap {
+ channels = append(channels, entry.value.(*Channel))
+ }
+
+ sort.Slice(channels, func(i, j int) bool {
+ return strings.ReplaceAll(channels[i].Name, "#", "") <
+ strings.ReplaceAll(channels[j].Name, "#", "")
+ })
+
+ for _, ch := range channels {
+ var uch *upstreamChannel
+ if net.conn != nil {
+ uch = net.conn.channels.Value(ch.Name)
+ }
+
+ name := ch.Name
+ if *networkName == "" {
+ name += "/" + net.GetName()
+ }
+
+ var status string
+ if uch != nil {
+ status = "joined"
+ } else if net.conn != nil {
+ status = "parted"
+ } else {
+ status = "disconnected"
+ }
+
+ if ch.Detached {
+ status += ", detached"
+ }
+
+ s := fmt.Sprintf("%v [%v]", name, status)
+ sendServicePRIVMSG(dc, s)
+
+ n++
+ }
+ }
+
+ if *networkName == "" {
+ for _, net := range dc.user.networks {
+ sendNetwork(net)
+ }
+ } else {
+ net := dc.user.getNetwork(*networkName)
+ if net == nil {
+ return fmt.Errorf("unknown network %q", *networkName)
+ }
+ sendNetwork(net)
+ }
+
+ if n == 0 {
+ sendServicePRIVMSG(dc, "No channel configured.")
+ }
+
+ return nil
+}
+
+type channelFlagSet struct {
+ *flag.FlagSet
+ RelayDetached, ReattachOn, DetachAfter, DetachOn *string
+}
+
+func newChannelFlagSet() *channelFlagSet {
+ fs := &channelFlagSet{FlagSet: newFlagSet()}
+ fs.Var(stringPtrFlag{&fs.RelayDetached}, "relay-detached", "")
+ fs.Var(stringPtrFlag{&fs.ReattachOn}, "reattach-on", "")
+ fs.Var(stringPtrFlag{&fs.DetachAfter}, "detach-after", "")
+ fs.Var(stringPtrFlag{&fs.DetachOn}, "detach-on", "")
+ return fs
+}
+
+func (fs *channelFlagSet) update(channel *Channel) error {
+ if fs.RelayDetached != nil {
+ filter, err := parseFilter(*fs.RelayDetached)
+ if err != nil {
+ return err
+ }
+ channel.RelayDetached = filter
+ }
+ if fs.ReattachOn != nil {
+ filter, err := parseFilter(*fs.ReattachOn)
+ if err != nil {
+ return err
+ }
+ channel.ReattachOn = filter
+ }
+ if fs.DetachAfter != nil {
+ dur, err := time.ParseDuration(*fs.DetachAfter)
+ if err != nil || dur < 0 {
+ return fmt.Errorf("unknown duration for -detach-after %q (duration format: 0, 300s, 22h30m, ...)", *fs.DetachAfter)
+ }
+ channel.DetachAfter = dur
+ }
+ if fs.DetachOn != nil {
+ filter, err := parseFilter(*fs.DetachOn)
+ if err != nil {
+ return err
+ }
+ channel.DetachOn = filter
+ }
+ return nil
+}
+
+func handleServiceChannelUpdate(ctx context.Context, dc *downstreamConn, params []string) error {
+ if len(params) < 1 {
+ return fmt.Errorf("expected at least one argument")
+ }
+ name := params[0]
+
+ fs := newChannelFlagSet()
+ if err := fs.Parse(params[1:]); err != nil {
+ return err
+ }
+
+ uc, upstreamName, err := dc.unmarshalEntity(name)
+ if err != nil {
+ return fmt.Errorf("unknown channel %q", name)
+ }
+
+ ch := uc.network.channels.Value(upstreamName)
+ if ch == nil {
+ return fmt.Errorf("unknown channel %q", name)
+ }
+
+ if err := fs.update(ch); err != nil {
+ return err
+ }
+
+ uc.updateChannelAutoDetach(upstreamName)
+
+ if err := dc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil {
+ return fmt.Errorf("failed to update channel: %v", err)
+ }
+
+ sendServicePRIVMSG(dc, fmt.Sprintf("updated channel %q", name))
+ return nil
+}
+func handleServiceServerStatus(ctx context.Context, dc *downstreamConn, params []string) error {
+ dbStats, err := dc.user.srv.db.Stats(ctx)
+ if err != nil {
+ return err
+ }
+ serverStats := dc.user.srv.Stats()
+ sendServicePRIVMSG(dc, fmt.Sprintf("%v/%v users, %v downstreams, %v upstreams, %v networks, %v channels", serverStats.Users, dbStats.Users, serverStats.Downstreams, serverStats.Upstreams, dbStats.Networks, dbStats.Channels))
+ return nil
+}
+
+func handleServiceServerNotice(ctx context.Context, dc *downstreamConn, params []string) error {
+ if len(params) != 1 {
+ return fmt.Errorf("expected exactly one argument")
+ }
+ text := params[0]
+
+ dc.logger.Printf("broadcasting bouncer-wide NOTICE: %v", text)
+
+ broadcastMsg := &irc.Message{
+ Prefix: servicePrefix,
+ Command: "NOTICE",
+ Params: []string{"$" + dc.srv.Config().Hostname, text},
+ }
+ var err error
+ sent := 0
+ total := 0
+ dc.srv.forEachUser(func(u *user) {
+ total++
+ select {
+ case <-ctx.Done():
+ err = ctx.Err()
+ case u.events <- eventBroadcast{broadcastMsg}:
+ sent++
+ }
+ })
+
+ dc.logger.Printf("broadcast bouncer-wide NOTICE to %v/%v downstreams", sent, total)
+ sendServicePRIVMSG(dc, fmt.Sprintf("sent to %v/%v downstream connections", sent, total))
+
+ return err
+}
--- /dev/null
+package suika
+
+import (
+ "testing"
+)
+
+func assertSplit(t *testing.T, input string, expected []string) {
+ actual, err := splitWords(input)
+ if err != nil {
+ t.Errorf("%q: %v", input, err)
+ return
+ }
+ if len(actual) != len(expected) {
+ t.Errorf("%q: expected %d words, got %d\nexpected: %v\ngot: %v", input, len(expected), len(actual), expected, actual)
+ return
+ }
+ for i := 0; i < len(actual); i++ {
+ if actual[i] != expected[i] {
+ t.Errorf("%q: expected word #%d to be %q, got %q\nexpected: %v\ngot: %v", input, i, expected[i], actual[i], expected, actual)
+ }
+ }
+}
+
+func TestSplit(t *testing.T) {
+ assertSplit(t, " ch 'up' #suika 'relay'-det\"ache\"d message ", []string{
+ "ch",
+ "up",
+ "#suika",
+ "relay-detached",
+ "message",
+ })
+ assertSplit(t, "net update \\\"free\\\"node -pass 'political \"stance\" desu!' -realname '' -nick lee", []string{
+ "net",
+ "update",
+ "\"free\"node",
+ "-pass",
+ "political \"stance\" desu!",
+ "-realname",
+ "",
+ "-nick",
+ "lee",
+ })
+ assertSplit(t, "Omedeto,\\ Yui! ''", []string{
+ "Omedeto, Yui!",
+ "",
+ })
+
+ if _, err := splitWords("end of 'file"); err == nil {
+ t.Errorf("expected error on unterminated single quote")
+ }
+ if _, err := splitWords("end of backquote \\"); err == nil {
+ t.Errorf("expected error on unterminated backquote sequence")
+ }
+}
--- /dev/null
+CREATE TABLE IF NOT EXISTS "User" (
+ id SERIAL PRIMARY KEY,
+ username VARCHAR(255) NOT NULL UNIQUE,
+ password VARCHAR(255),
+ admin BOOLEAN NOT NULL DEFAULT FALSE,
+ realname VARCHAR(255)
+);
+
+CREATE TYPE sasl_mechanism AS ENUM ('PLAIN', 'EXTERNAL');
+
+CREATE TABLE IF NOT EXISTS "Network" (
+ id SERIAL PRIMARY KEY,
+ name VARCHAR(255),
+ "user" INTEGER NOT NULL REFERENCES "User"(id) ON DELETE CASCADE,
+ addr VARCHAR(255) NOT NULL,
+ nick VARCHAR(255),
+ username VARCHAR(255),
+ realname VARCHAR(255),
+ pass VARCHAR(255),
+ connect_commands VARCHAR(1023),
+ sasl_mechanism sasl_mechanism,
+ sasl_plain_username VARCHAR(255),
+ sasl_plain_password VARCHAR(255),
+ sasl_external_cert BYTEA,
+ sasl_external_key BYTEA,
+ enabled BOOLEAN NOT NULL DEFAULT TRUE,
+ UNIQUE("user", addr, nick),
+ UNIQUE("user", name)
+);
+CREATE TABLE IF NOT EXISTS "Channel" (
+ id SERIAL PRIMARY KEY,
+ network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
+ name VARCHAR(255) NOT NULL,
+ key VARCHAR(255),
+ detached BOOLEAN NOT NULL DEFAULT FALSE,
+ detached_internal_msgid VARCHAR(255),
+ relay_detached INTEGER NOT NULL DEFAULT 0,
+ reattach_on INTEGER NOT NULL DEFAULT 0,
+ detach_after INTEGER NOT NULL DEFAULT 0,
+ detach_on INTEGER NOT NULL DEFAULT 0,
+ UNIQUE(network, name)
+);
+CREATE TABLE IF NOT EXISTS "DeliveryReceipt" (
+ id SERIAL PRIMARY KEY,
+ network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
+ target VARCHAR(255) NOT NULL,
+ client VARCHAR(255) NOT NULL DEFAULT '',
+ internal_msgid VARCHAR(255) NOT NULL,
+ UNIQUE(network, target, client)
+);
+CREATE TABLE IF NOT EXISTS "ReadReceipt" (
+ id SERIAL PRIMARY KEY,
+ network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
+ target VARCHAR(255) NOT NULL,
+ timestamp TIMESTAMP WITH TIME ZONE NOT NULL,
+ UNIQUE(network, target)
+);
+
--- /dev/null
+CREATE TABLE IF NOT EXISTS User (
+ id INTEGER PRIMARY KEY,
+ username TEXT NOT NULL UNIQUE,
+ password TEXT,
+ admin INTEGER NOT NULL DEFAULT 0,
+ realname TEXT
+);
+CREATE TABLE IF NOT EXISTS Network (
+ id INTEGER PRIMARY KEY,
+ name TEXT,
+ user INTEGER NOT NULL,
+ addr TEXT NOT NULL,
+ nick TEXT,
+ username TEXT,
+ realname TEXT,
+ pass TEXT,
+ connect_commands TEXT,
+ sasl_mechanism TEXT,
+ sasl_plain_username TEXT,
+ sasl_plain_password TEXT,
+ sasl_external_cert BLOB,
+ sasl_external_key BLOB,
+ enabled INTEGER NOT NULL DEFAULT 1,
+ FOREIGN KEY(user) REFERENCES User(id),
+ UNIQUE(user, addr, nick),
+ UNIQUE(user, name)
+);
+CREATE TABLE IF NOT EXISTS Channel (
+ id INTEGER PRIMARY KEY,
+ network INTEGER NOT NULL,
+ name TEXT NOT NULL,
+ key TEXT,
+ detached INTEGER NOT NULL DEFAULT 0,
+ detached_internal_msgid TEXT,
+ relay_detached INTEGER NOT NULL DEFAULT 0,
+ reattach_on INTEGER NOT NULL DEFAULT 0,
+ detach_after INTEGER NOT NULL DEFAULT 0,
+ detach_on INTEGER NOT NULL DEFAULT 0,
+ FOREIGN KEY(network) REFERENCES Network(id),
+ UNIQUE(network, name)
+);
+
+CREATE TABLE IF NOT EXISTS DeliveryReceipt (
+ id INTEGER PRIMARY KEY,
+ network INTEGER NOT NULL,
+ target TEXT NOT NULL,
+ client TEXT,
+ internal_msgid TEXT NOT NULL,
+ FOREIGN KEY(network) REFERENCES Network(id),
+ UNIQUE(network, target, client)
+);
+
+CREATE TABLE IF NOT EXISTS ReadReceipt (
+ id INTEGER PRIMARY KEY,
+ network INTEGER NOT NULL,
+ target TEXT NOT NULL,
+ timestamp TEXT NOT NULL,
+ FOREIGN KEY(network) REFERENCES Network(id),
+ UNIQUE(network, target)
+);
+
--- /dev/null
+package suika
+
+import (
+ "context"
+ "crypto"
+ "crypto/sha256"
+ "crypto/tls"
+ "crypto/x509"
+ "encoding/base64"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/emersion/go-sasl"
+ "gopkg.in/irc.v3"
+)
+
+// permanentUpstreamCaps is the static list of upstream capabilities always
+// requested when supported.
+var permanentUpstreamCaps = map[string]bool{
+ "account-notify": true,
+ "account-tag": true,
+ "away-notify": true,
+ "batch": true,
+ "extended-join": true,
+ "invite-notify": true,
+ "labeled-response": true,
+ "message-tags": true,
+ "multi-prefix": true,
+ "sasl": true,
+ "server-time": true,
+ "setname": true,
+
+ "draft/account-registration": true,
+ "draft/extended-monitor": true,
+}
+
+type registrationError struct {
+ *irc.Message
+}
+
+func (err registrationError) Error() string {
+ return fmt.Sprintf("registration error (%v): %v", err.Command, err.Reason())
+}
+
+func (err registrationError) Reason() string {
+ if len(err.Params) > 0 {
+ return err.Params[len(err.Params)-1]
+ }
+ return err.Command
+}
+
+func (err registrationError) Temporary() bool {
+ // Only return false if we're 100% sure that fixing the error requires a
+ // network configuration change
+ switch err.Command {
+ case irc.ERR_PASSWDMISMATCH, irc.ERR_ERRONEUSNICKNAME:
+ return false
+ case "FAIL":
+ return err.Params[1] != "ACCOUNT_REQUIRED"
+ default:
+ return true
+ }
+}
+
+type upstreamChannel struct {
+ Name string
+ conn *upstreamConn
+ Topic string
+ TopicWho *irc.Prefix
+ TopicTime time.Time
+ Status channelStatus
+ modes channelModes
+ creationTime string
+ Members membershipsCasemapMap
+ complete bool
+ detachTimer *time.Timer
+}
+
+func (uc *upstreamChannel) updateAutoDetach(dur time.Duration) {
+ if uc.detachTimer != nil {
+ uc.detachTimer.Stop()
+ uc.detachTimer = nil
+ }
+
+ if dur == 0 {
+ return
+ }
+
+ uc.detachTimer = time.AfterFunc(dur, func() {
+ uc.conn.network.user.events <- eventChannelDetach{
+ uc: uc.conn,
+ name: uc.Name,
+ }
+ })
+}
+
+type pendingUpstreamCommand struct {
+ downstreamID uint64
+ msg *irc.Message
+}
+
+type upstreamConn struct {
+ conn
+
+ network *network
+ user *user
+
+ serverName string
+ availableUserModes string
+ availableChannelModes map[byte]channelModeType
+ availableChannelTypes string
+ availableMemberships []membership
+ isupport map[string]*string
+
+ registered bool
+ nick string
+ nickCM string
+ username string
+ realname string
+ modes userModes
+ channels upstreamChannelCasemapMap
+ supportedCaps map[string]string
+ caps map[string]bool
+ batches map[string]batch
+ away bool
+ account string
+ nextLabelID uint64
+ monitored monitorCasemapMap
+
+ saslClient sasl.Client
+ saslStarted bool
+
+ casemapIsSet bool
+
+ // Queue of commands in progress, indexed by type. The first entry has been
+ // sent to the server and is awaiting reply. The following entries have not
+ // been sent yet.
+ pendingCmds map[string][]pendingUpstreamCommand
+
+ gotMotd bool
+}
+
+func connectToUpstream(ctx context.Context, network *network) (*upstreamConn, error) {
+ logger := &prefixLogger{network.user.logger, fmt.Sprintf("upstream %q: ", network.GetName())}
+
+ dialer := net.Dialer{Timeout: connectTimeout}
+
+ u, err := network.URL()
+ if err != nil {
+ return nil, err
+ }
+
+ var netConn net.Conn
+ switch u.Scheme {
+ case "ircs":
+ addr := u.Host
+ host, _, err := net.SplitHostPort(u.Host)
+ if err != nil {
+ host = u.Host
+ addr = u.Host + ":6697"
+ }
+
+ dialer.LocalAddr, err = network.user.localTCPAddrForHost(ctx, host)
+ if err != nil {
+ return nil, fmt.Errorf("failed to pick local IP for remote host %q: %v", host, err)
+ }
+
+ logger.Printf("connecting to TLS server at address %q", addr)
+
+ tlsConfig := &tls.Config{ServerName: host, NextProtos: []string{"irc"}}
+ if network.SASL.Mechanism == "EXTERNAL" {
+ if network.SASL.External.CertBlob == nil {
+ return nil, fmt.Errorf("missing certificate for authentication")
+ }
+ if network.SASL.External.PrivKeyBlob == nil {
+ return nil, fmt.Errorf("missing private key for authentication")
+ }
+ key, err := x509.ParsePKCS8PrivateKey(network.SASL.External.PrivKeyBlob)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse private key: %v", err)
+ }
+ tlsConfig.Certificates = []tls.Certificate{
+ {
+ Certificate: [][]byte{network.SASL.External.CertBlob},
+ PrivateKey: key.(crypto.PrivateKey),
+ },
+ }
+ logger.Printf("using TLS client certificate %x", sha256.Sum256(network.SASL.External.CertBlob))
+ }
+
+ netConn, err = dialer.DialContext(ctx, "tcp", addr)
+ if err != nil {
+ return nil, fmt.Errorf("failed to dial %q: %v", addr, err)
+ }
+
+ // Don't do the TLS handshake immediately, because we need to register
+ // the new connection with identd ASAP.
+ netConn = tls.Client(netConn, tlsConfig)
+ case "irc":
+ addr := u.Host
+ host, _, err := net.SplitHostPort(addr)
+ if err != nil {
+ host = u.Host
+ addr = u.Host + ":6667"
+ }
+
+ dialer.LocalAddr, err = network.user.localTCPAddrForHost(ctx, host)
+ if err != nil {
+ return nil, fmt.Errorf("failed to pick local IP for remote host %q: %v", host, err)
+ }
+
+ logger.Printf("connecting to plain-text server at address %q", addr)
+ netConn, err = dialer.DialContext(ctx, "tcp", addr)
+ if err != nil {
+ return nil, fmt.Errorf("failed to dial %q: %v", addr, err)
+ }
+ case "irc+unix", "unix":
+ logger.Printf("connecting to Unix socket at path %q", u.Path)
+ netConn, err = dialer.DialContext(ctx, "unix", u.Path)
+ if err != nil {
+ return nil, fmt.Errorf("failed to connect to Unix socket %q: %v", u.Path, err)
+ }
+ default:
+ return nil, fmt.Errorf("failed to dial %q: unknown scheme: %v", network.Addr, u.Scheme)
+ }
+
+ options := connOptions{
+ Logger: logger,
+ RateLimitDelay: upstreamMessageDelay,
+ RateLimitBurst: upstreamMessageBurst,
+ }
+
+ uc := &upstreamConn{
+ conn: *newConn(network.user.srv, newNetIRCConn(netConn), &options),
+ network: network,
+ user: network.user,
+ channels: upstreamChannelCasemapMap{newCasemapMap(0)},
+ supportedCaps: make(map[string]string),
+ caps: make(map[string]bool),
+ batches: make(map[string]batch),
+ availableChannelTypes: stdChannelTypes,
+ availableChannelModes: stdChannelModes,
+ availableMemberships: stdMemberships,
+ isupport: make(map[string]*string),
+ pendingCmds: make(map[string][]pendingUpstreamCommand),
+ monitored: monitorCasemapMap{newCasemapMap(0)},
+ }
+ return uc, nil
+}
+
+func (uc *upstreamConn) forEachDownstream(f func(*downstreamConn)) {
+ uc.network.forEachDownstream(f)
+}
+
+func (uc *upstreamConn) forEachDownstreamByID(id uint64, f func(*downstreamConn)) {
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ if id != 0 && id != dc.id {
+ return
+ }
+ f(dc)
+ })
+}
+
+func (uc *upstreamConn) downstreamByID(id uint64) *downstreamConn {
+ for _, dc := range uc.user.downstreamConns {
+ if dc.id == id {
+ return dc
+ }
+ }
+ return nil
+}
+
+func (uc *upstreamConn) getChannel(name string) (*upstreamChannel, error) {
+ ch := uc.channels.Value(name)
+ if ch == nil {
+ return nil, fmt.Errorf("unknown channel %q", name)
+ }
+ return ch, nil
+}
+
+func (uc *upstreamConn) isChannel(entity string) bool {
+ return strings.ContainsRune(uc.availableChannelTypes, rune(entity[0]))
+}
+
+func (uc *upstreamConn) isOurNick(nick string) bool {
+ return uc.nickCM == uc.network.casemap(nick)
+}
+
+func (uc *upstreamConn) abortPendingCommands() {
+ for _, l := range uc.pendingCmds {
+ for _, pendingCmd := range l {
+ dc := uc.downstreamByID(pendingCmd.downstreamID)
+ if dc == nil {
+ continue
+ }
+
+ switch pendingCmd.msg.Command {
+ case "LIST":
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_LISTEND,
+ Params: []string{dc.nick, "Command aborted"},
+ })
+ case "WHO":
+ mask := "*"
+ if len(pendingCmd.msg.Params) > 0 {
+ mask = pendingCmd.msg.Params[0]
+ }
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_ENDOFWHO,
+ Params: []string{dc.nick, mask, "Command aborted"},
+ })
+ case "AUTHENTICATE":
+ dc.endSASL(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_SASLABORTED,
+ Params: []string{dc.nick, "SASL authentication aborted"},
+ })
+ case "REGISTER", "VERIFY":
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "FAIL",
+ Params: []string{pendingCmd.msg.Command, "TEMPORARILY_UNAVAILABLE", pendingCmd.msg.Params[0], "Command aborted"},
+ })
+ default:
+ panic(fmt.Errorf("Unsupported pending command %q", pendingCmd.msg.Command))
+ }
+ }
+ }
+
+ uc.pendingCmds = make(map[string][]pendingUpstreamCommand)
+}
+
+func (uc *upstreamConn) sendNextPendingCommand(cmd string) {
+ if len(uc.pendingCmds[cmd]) == 0 {
+ return
+ }
+ uc.SendMessage(context.TODO(), uc.pendingCmds[cmd][0].msg)
+}
+
+func (uc *upstreamConn) enqueueCommand(dc *downstreamConn, msg *irc.Message) {
+ switch msg.Command {
+ case "LIST", "WHO", "AUTHENTICATE", "REGISTER", "VERIFY":
+ // Supported
+ default:
+ panic(fmt.Errorf("Unsupported pending command %q", msg.Command))
+ }
+
+ uc.pendingCmds[msg.Command] = append(uc.pendingCmds[msg.Command], pendingUpstreamCommand{
+ downstreamID: dc.id,
+ msg: msg,
+ })
+
+ if len(uc.pendingCmds[msg.Command]) == 1 {
+ uc.sendNextPendingCommand(msg.Command)
+ }
+}
+
+func (uc *upstreamConn) currentPendingCommand(cmd string) (*downstreamConn, *irc.Message) {
+ if len(uc.pendingCmds[cmd]) == 0 {
+ return nil, nil
+ }
+
+ pendingCmd := uc.pendingCmds[cmd][0]
+ return uc.downstreamByID(pendingCmd.downstreamID), pendingCmd.msg
+}
+
+func (uc *upstreamConn) dequeueCommand(cmd string) (*downstreamConn, *irc.Message) {
+ dc, msg := uc.currentPendingCommand(cmd)
+
+ if len(uc.pendingCmds[cmd]) > 0 {
+ copy(uc.pendingCmds[cmd], uc.pendingCmds[cmd][1:])
+ uc.pendingCmds[cmd] = uc.pendingCmds[cmd][:len(uc.pendingCmds[cmd])-1]
+ }
+
+ uc.sendNextPendingCommand(cmd)
+
+ return dc, msg
+}
+
+func (uc *upstreamConn) cancelPendingCommandsByDownstreamID(downstreamID uint64) {
+ for cmd := range uc.pendingCmds {
+ // We can't cancel the currently running command stored in
+ // uc.pendingCmds[cmd][0]
+ for i := len(uc.pendingCmds[cmd]) - 1; i >= 1; i-- {
+ if uc.pendingCmds[cmd][i].downstreamID == downstreamID {
+ uc.pendingCmds[cmd] = append(uc.pendingCmds[cmd][:i], uc.pendingCmds[cmd][i+1:]...)
+ }
+ }
+ }
+}
+
+func (uc *upstreamConn) parseMembershipPrefix(s string) (ms *memberships, nick string) {
+ memberships := make(memberships, 0, 4)
+ i := 0
+ for _, m := range uc.availableMemberships {
+ if i >= len(s) {
+ break
+ }
+ if s[i] == m.Prefix {
+ memberships = append(memberships, m)
+ i++
+ }
+ }
+ return &memberships, s[i:]
+}
+
+func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) error {
+ var label string
+ if l, ok := msg.GetTag("label"); ok {
+ label = l
+ delete(msg.Tags, "label")
+ }
+
+ var msgBatch *batch
+ if batchName, ok := msg.GetTag("batch"); ok {
+ b, ok := uc.batches[batchName]
+ if !ok {
+ return fmt.Errorf("unexpected batch reference: batch was not defined: %q", batchName)
+ }
+ msgBatch = &b
+ if label == "" {
+ label = msgBatch.Label
+ }
+ delete(msg.Tags, "batch")
+ }
+
+ var downstreamID uint64 = 0
+ if label != "" {
+ var labelOffset uint64
+ n, err := fmt.Sscanf(label, "sd-%d-%d", &downstreamID, &labelOffset)
+ if err == nil && n < 2 {
+ err = errors.New("not enough arguments")
+ }
+ if err != nil {
+ return fmt.Errorf("unexpected message label: invalid downstream reference for label %q: %v", label, err)
+ }
+ }
+
+ if _, ok := msg.Tags["time"]; !ok {
+ msg.Tags["time"] = irc.TagValue(formatServerTime(time.Now()))
+ }
+
+ switch msg.Command {
+ case "PING":
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "PONG",
+ Params: msg.Params,
+ })
+ return nil
+ case "NOTICE", "PRIVMSG", "TAGMSG":
+ if msg.Prefix == nil {
+ return fmt.Errorf("expected a prefix")
+ }
+
+ var entity, text string
+ if msg.Command != "TAGMSG" {
+ if err := parseMessageParams(msg, &entity, &text); err != nil {
+ return err
+ }
+ } else {
+ if err := parseMessageParams(msg, &entity); err != nil {
+ return err
+ }
+ }
+
+ if msg.Prefix.Name == serviceNick {
+ uc.logger.Printf("skipping %v from suika's service: %v", msg.Command, msg)
+ break
+ }
+ if entity == serviceNick {
+ uc.logger.Printf("skipping %v to suika's service: %v", msg.Command, msg)
+ break
+ }
+
+ if msg.Prefix.User == "" && msg.Prefix.Host == "" { // server message
+ uc.produce("", msg, nil)
+ } else { // regular user message
+ target := entity
+ if uc.isOurNick(target) {
+ target = msg.Prefix.Name
+ }
+
+ ch := uc.network.channels.Value(target)
+ if ch != nil && msg.Command != "TAGMSG" {
+ if ch.Detached {
+ uc.handleDetachedMessage(ctx, ch, msg)
+ }
+
+ highlight := uc.network.isHighlight(msg)
+ if ch.DetachOn == FilterMessage || ch.DetachOn == FilterDefault || (ch.DetachOn == FilterHighlight && highlight) {
+ uc.updateChannelAutoDetach(target)
+ }
+ }
+
+ uc.produce(target, msg, nil)
+ }
+ case "CAP":
+ var subCmd string
+ if err := parseMessageParams(msg, nil, &subCmd); err != nil {
+ return err
+ }
+ subCmd = strings.ToUpper(subCmd)
+ subParams := msg.Params[2:]
+ switch subCmd {
+ case "LS":
+ if len(subParams) < 1 {
+ return newNeedMoreParamsError(msg.Command)
+ }
+ caps := subParams[len(subParams)-1]
+ more := len(subParams) >= 2 && msg.Params[len(subParams)-2] == "*"
+
+ uc.handleSupportedCaps(caps)
+
+ if more {
+ break // wait to receive all capabilities
+ }
+
+ uc.requestCaps()
+
+ if uc.requestSASL() {
+ break // we'll send CAP END after authentication is completed
+ }
+
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "CAP",
+ Params: []string{"END"},
+ })
+ case "ACK", "NAK":
+ if len(subParams) < 1 {
+ return newNeedMoreParamsError(msg.Command)
+ }
+ caps := strings.Fields(subParams[0])
+
+ for _, name := range caps {
+ if err := uc.handleCapAck(ctx, strings.ToLower(name), subCmd == "ACK"); err != nil {
+ return err
+ }
+ }
+
+ if uc.registered {
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.updateSupportedCaps()
+ })
+ }
+ case "NEW":
+ if len(subParams) < 1 {
+ return newNeedMoreParamsError(msg.Command)
+ }
+ uc.handleSupportedCaps(subParams[0])
+ uc.requestCaps()
+ case "DEL":
+ if len(subParams) < 1 {
+ return newNeedMoreParamsError(msg.Command)
+ }
+ caps := strings.Fields(subParams[0])
+
+ for _, c := range caps {
+ delete(uc.supportedCaps, c)
+ delete(uc.caps, c)
+ }
+
+ if uc.registered {
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.updateSupportedCaps()
+ })
+ }
+ default:
+ uc.logger.Printf("unhandled message: %v", msg)
+ }
+ case "AUTHENTICATE":
+ if uc.saslClient == nil {
+ return fmt.Errorf("received unexpected AUTHENTICATE message")
+ }
+
+ // TODO: if a challenge is 400 bytes long, buffer it
+ var challengeStr string
+ if err := parseMessageParams(msg, &challengeStr); err != nil {
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "AUTHENTICATE",
+ Params: []string{"*"},
+ })
+ return err
+ }
+
+ var challenge []byte
+ if challengeStr != "+" {
+ var err error
+ challenge, err = base64.StdEncoding.DecodeString(challengeStr)
+ if err != nil {
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "AUTHENTICATE",
+ Params: []string{"*"},
+ })
+ return err
+ }
+ }
+
+ var resp []byte
+ var err error
+ if !uc.saslStarted {
+ _, resp, err = uc.saslClient.Start()
+ uc.saslStarted = true
+ } else {
+ resp, err = uc.saslClient.Next(challenge)
+ }
+ if err != nil {
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "AUTHENTICATE",
+ Params: []string{"*"},
+ })
+ return err
+ }
+
+ // <= instead of < because we need to send a final empty response if
+ // the last chunk is exactly 400 bytes long
+ for i := 0; i <= len(resp); i += maxSASLLength {
+ j := i + maxSASLLength
+ if j > len(resp) {
+ j = len(resp)
+ }
+
+ chunk := resp[i:j]
+
+ var respStr = "+"
+ if len(chunk) != 0 {
+ respStr = base64.StdEncoding.EncodeToString(chunk)
+ }
+
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "AUTHENTICATE",
+ Params: []string{respStr},
+ })
+ }
+ case irc.RPL_LOGGEDIN:
+ if err := parseMessageParams(msg, nil, nil, &uc.account); err != nil {
+ return err
+ }
+ uc.logger.Printf("logged in with account %q", uc.account)
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.updateAccount()
+ })
+ case irc.RPL_LOGGEDOUT:
+ uc.account = ""
+ uc.logger.Printf("logged out")
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.updateAccount()
+ })
+ case irc.ERR_NICKLOCKED, irc.RPL_SASLSUCCESS, irc.ERR_SASLFAIL, irc.ERR_SASLTOOLONG, irc.ERR_SASLABORTED:
+ var info string
+ if err := parseMessageParams(msg, nil, &info); err != nil {
+ return err
+ }
+ switch msg.Command {
+ case irc.ERR_NICKLOCKED:
+ uc.logger.Printf("invalid nick used with SASL authentication: %v", info)
+ case irc.ERR_SASLFAIL:
+ uc.logger.Printf("SASL authentication failed: %v", info)
+ case irc.ERR_SASLTOOLONG:
+ uc.logger.Printf("SASL message too long: %v", info)
+ }
+
+ uc.saslClient = nil
+ uc.saslStarted = false
+
+ if dc, _ := uc.dequeueCommand("AUTHENTICATE"); dc != nil && dc.sasl != nil {
+ if msg.Command == irc.RPL_SASLSUCCESS {
+ uc.network.autoSaveSASLPlain(ctx, dc.sasl.plainUsername, dc.sasl.plainPassword)
+ }
+
+ dc.endSASL(msg)
+ }
+
+ if !uc.registered {
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "CAP",
+ Params: []string{"END"},
+ })
+ }
+ case "REGISTER", "VERIFY":
+ if dc, cmd := uc.dequeueCommand(msg.Command); dc != nil {
+ if msg.Command == "REGISTER" {
+ var account, password string
+ if err := parseMessageParams(msg, nil, &account); err != nil {
+ return err
+ }
+ if err := parseMessageParams(cmd, nil, nil, &password); err != nil {
+ return err
+ }
+ uc.network.autoSaveSASLPlain(ctx, account, password)
+ }
+
+ dc.SendMessage(msg)
+ }
+ case irc.RPL_WELCOME:
+ if err := parseMessageParams(msg, &uc.nick); err != nil {
+ return err
+ }
+
+ uc.registered = true
+ uc.nickCM = uc.network.casemap(uc.nick)
+ uc.logger.Printf("connection registered with nick %q", uc.nick)
+
+ if uc.network.channels.Len() > 0 {
+ var channels, keys []string
+ for _, entry := range uc.network.channels.innerMap {
+ ch := entry.value.(*Channel)
+ channels = append(channels, ch.Name)
+ keys = append(keys, ch.Key)
+ }
+
+ for _, msg := range join(channels, keys) {
+ uc.SendMessage(ctx, msg)
+ }
+ }
+ case irc.RPL_MYINFO:
+ if err := parseMessageParams(msg, nil, &uc.serverName, nil, &uc.availableUserModes, nil); err != nil {
+ return err
+ }
+ case irc.RPL_ISUPPORT:
+ if err := parseMessageParams(msg, nil, nil); err != nil {
+ return err
+ }
+
+ var downstreamIsupport []string
+ for _, token := range msg.Params[1 : len(msg.Params)-1] {
+ parameter := token
+ var negate, hasValue bool
+ var value string
+ if strings.HasPrefix(token, "-") {
+ negate = true
+ token = token[1:]
+ } else if i := strings.IndexByte(token, '='); i >= 0 {
+ parameter = token[:i]
+ value = token[i+1:]
+ hasValue = true
+ }
+
+ if hasValue {
+ uc.isupport[parameter] = &value
+ } else if !negate {
+ uc.isupport[parameter] = nil
+ } else {
+ delete(uc.isupport, parameter)
+ }
+
+ var err error
+ switch parameter {
+ case "CASEMAPPING":
+ casemap, ok := parseCasemappingToken(value)
+ if !ok {
+ casemap = casemapRFC1459
+ }
+ uc.network.updateCasemapping(casemap)
+ uc.nickCM = uc.network.casemap(uc.nick)
+ uc.casemapIsSet = true
+ case "CHANMODES":
+ if !negate {
+ err = uc.handleChanModes(value)
+ } else {
+ uc.availableChannelModes = stdChannelModes
+ }
+ case "CHANTYPES":
+ if !negate {
+ uc.availableChannelTypes = value
+ } else {
+ uc.availableChannelTypes = stdChannelTypes
+ }
+ case "PREFIX":
+ if !negate {
+ err = uc.handleMemberships(value)
+ } else {
+ uc.availableMemberships = stdMemberships
+ }
+ }
+ if err != nil {
+ return err
+ }
+
+ if passthroughIsupport[parameter] {
+ downstreamIsupport = append(downstreamIsupport, token)
+ }
+ }
+
+ uc.updateMonitor()
+
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ if dc.network == nil {
+ return
+ }
+ msgs := generateIsupport(dc.srv.prefix(), dc.nick, downstreamIsupport)
+ for _, msg := range msgs {
+ dc.SendMessage(msg)
+ }
+ })
+ case irc.ERR_NOMOTD, irc.RPL_ENDOFMOTD:
+ if !uc.casemapIsSet {
+ // upstream did not send any CASEMAPPING token, thus
+ // we assume it implements the old RFCs with rfc1459.
+ uc.casemapIsSet = true
+ uc.network.updateCasemapping(casemapRFC1459)
+ uc.nickCM = uc.network.casemap(uc.nick)
+ }
+
+ if !uc.gotMotd {
+ // Ignore the initial MOTD upon connection, but forward
+ // subsequent MOTD messages downstream
+ uc.gotMotd = true
+ return nil
+ }
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ dc.SendMessage(&irc.Message{
+ Prefix: uc.srv.prefix(),
+ Command: msg.Command,
+ Params: msg.Params,
+ })
+ })
+ case "BATCH":
+ var tag string
+ if err := parseMessageParams(msg, &tag); err != nil {
+ return err
+ }
+
+ if strings.HasPrefix(tag, "+") {
+ tag = tag[1:]
+ if _, ok := uc.batches[tag]; ok {
+ return fmt.Errorf("unexpected BATCH reference tag: batch was already defined: %q", tag)
+ }
+ var batchType string
+ if err := parseMessageParams(msg, nil, &batchType); err != nil {
+ return err
+ }
+ label := label
+ if label == "" && msgBatch != nil {
+ label = msgBatch.Label
+ }
+ uc.batches[tag] = batch{
+ Type: batchType,
+ Params: msg.Params[2:],
+ Outer: msgBatch,
+ Label: label,
+ }
+ } else if strings.HasPrefix(tag, "-") {
+ tag = tag[1:]
+ if _, ok := uc.batches[tag]; !ok {
+ return fmt.Errorf("unknown BATCH reference tag: %q", tag)
+ }
+ delete(uc.batches, tag)
+ } else {
+ return fmt.Errorf("unexpected BATCH reference tag: missing +/- prefix: %q", tag)
+ }
+ case "NICK":
+ if msg.Prefix == nil {
+ return fmt.Errorf("expected a prefix")
+ }
+
+ var newNick string
+ if err := parseMessageParams(msg, &newNick); err != nil {
+ return err
+ }
+
+ me := false
+ if uc.isOurNick(msg.Prefix.Name) {
+ uc.logger.Printf("changed nick from %q to %q", uc.nick, newNick)
+ me = true
+ uc.nick = newNick
+ uc.nickCM = uc.network.casemap(uc.nick)
+ }
+
+ for _, entry := range uc.channels.innerMap {
+ ch := entry.value.(*upstreamChannel)
+ memberships := ch.Members.Value(msg.Prefix.Name)
+ if memberships != nil {
+ ch.Members.Delete(msg.Prefix.Name)
+ ch.Members.SetValue(newNick, memberships)
+ uc.appendLog(ch.Name, msg)
+ }
+ }
+
+ if !me {
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.SendMessage(dc.marshalMessage(msg, uc.network))
+ })
+ } else {
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.updateNick()
+ })
+ uc.updateMonitor()
+ }
+ case "SETNAME":
+ if msg.Prefix == nil {
+ return fmt.Errorf("expected a prefix")
+ }
+
+ var newRealname string
+ if err := parseMessageParams(msg, &newRealname); err != nil {
+ return err
+ }
+
+ // TODO: consider appending this message to logs
+
+ if uc.isOurNick(msg.Prefix.Name) {
+ uc.logger.Printf("changed realname from %q to %q", uc.realname, newRealname)
+ uc.realname = newRealname
+
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.updateRealname()
+ })
+ } else {
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.SendMessage(dc.marshalMessage(msg, uc.network))
+ })
+ }
+ case "JOIN":
+ if msg.Prefix == nil {
+ return fmt.Errorf("expected a prefix")
+ }
+
+ var channels string
+ if err := parseMessageParams(msg, &channels); err != nil {
+ return err
+ }
+
+ for _, ch := range strings.Split(channels, ",") {
+ if uc.isOurNick(msg.Prefix.Name) {
+ uc.logger.Printf("joined channel %q", ch)
+ members := membershipsCasemapMap{newCasemapMap(0)}
+ members.casemap = uc.network.casemap
+ uc.channels.SetValue(ch, &upstreamChannel{
+ Name: ch,
+ conn: uc,
+ Members: members,
+ })
+ uc.updateChannelAutoDetach(ch)
+
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "MODE",
+ Params: []string{ch},
+ })
+ } else {
+ ch, err := uc.getChannel(ch)
+ if err != nil {
+ return err
+ }
+ ch.Members.SetValue(msg.Prefix.Name, &memberships{})
+ }
+
+ chMsg := msg.Copy()
+ chMsg.Params[0] = ch
+ uc.produce(ch, chMsg, nil)
+ }
+ case "PART":
+ if msg.Prefix == nil {
+ return fmt.Errorf("expected a prefix")
+ }
+
+ var channels string
+ if err := parseMessageParams(msg, &channels); err != nil {
+ return err
+ }
+
+ for _, ch := range strings.Split(channels, ",") {
+ if uc.isOurNick(msg.Prefix.Name) {
+ uc.logger.Printf("parted channel %q", ch)
+ uch := uc.channels.Value(ch)
+ if uch != nil {
+ uc.channels.Delete(ch)
+ uch.updateAutoDetach(0)
+ }
+ } else {
+ ch, err := uc.getChannel(ch)
+ if err != nil {
+ return err
+ }
+ ch.Members.Delete(msg.Prefix.Name)
+ }
+
+ chMsg := msg.Copy()
+ chMsg.Params[0] = ch
+ uc.produce(ch, chMsg, nil)
+ }
+ case "KICK":
+ if msg.Prefix == nil {
+ return fmt.Errorf("expected a prefix")
+ }
+
+ var channel, user string
+ if err := parseMessageParams(msg, &channel, &user); err != nil {
+ return err
+ }
+
+ if uc.isOurNick(user) {
+ uc.logger.Printf("kicked from channel %q by %s", channel, msg.Prefix.Name)
+ uc.channels.Delete(channel)
+ } else {
+ ch, err := uc.getChannel(channel)
+ if err != nil {
+ return err
+ }
+ ch.Members.Delete(user)
+ }
+
+ uc.produce(channel, msg, nil)
+ case "QUIT":
+ if msg.Prefix == nil {
+ return fmt.Errorf("expected a prefix")
+ }
+
+ if uc.isOurNick(msg.Prefix.Name) {
+ uc.logger.Printf("quit")
+ }
+
+ for _, entry := range uc.channels.innerMap {
+ ch := entry.value.(*upstreamChannel)
+ if ch.Members.Has(msg.Prefix.Name) {
+ ch.Members.Delete(msg.Prefix.Name)
+
+ uc.appendLog(ch.Name, msg)
+ }
+ }
+
+ if msg.Prefix.Name != uc.nick {
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.SendMessage(dc.marshalMessage(msg, uc.network))
+ })
+ }
+ case irc.RPL_TOPIC, irc.RPL_NOTOPIC:
+ var name, topic string
+ if err := parseMessageParams(msg, nil, &name, &topic); err != nil {
+ return err
+ }
+ ch, err := uc.getChannel(name)
+ if err != nil {
+ return err
+ }
+ if msg.Command == irc.RPL_TOPIC {
+ ch.Topic = topic
+ } else {
+ ch.Topic = ""
+ }
+ case "TOPIC":
+ if msg.Prefix == nil {
+ return fmt.Errorf("expected a prefix")
+ }
+
+ var name string
+ if err := parseMessageParams(msg, &name); err != nil {
+ return err
+ }
+ ch, err := uc.getChannel(name)
+ if err != nil {
+ return err
+ }
+ if len(msg.Params) > 1 {
+ ch.Topic = msg.Params[1]
+ ch.TopicWho = msg.Prefix.Copy()
+ ch.TopicTime = time.Now() // TODO use msg.Tags["time"]
+ } else {
+ ch.Topic = ""
+ }
+ uc.produce(ch.Name, msg, nil)
+ case "MODE":
+ var name, modeStr string
+ if err := parseMessageParams(msg, &name, &modeStr); err != nil {
+ return err
+ }
+
+ if !uc.isChannel(name) { // user mode change
+ if name != uc.nick {
+ return fmt.Errorf("received MODE message for unknown nick %q", name)
+ }
+
+ if err := uc.modes.Apply(modeStr); err != nil {
+ return err
+ }
+
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ if dc.upstream() == nil {
+ return
+ }
+
+ dc.SendMessage(msg)
+ })
+ } else { // channel mode change
+ ch, err := uc.getChannel(name)
+ if err != nil {
+ return err
+ }
+
+ needMarshaling, err := applyChannelModes(ch, modeStr, msg.Params[2:])
+ if err != nil {
+ return err
+ }
+
+ uc.appendLog(ch.Name, msg)
+
+ c := uc.network.channels.Value(name)
+ if c == nil || !c.Detached {
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ params := make([]string, len(msg.Params))
+ params[0] = dc.marshalEntity(uc.network, name)
+ params[1] = modeStr
+
+ copy(params[2:], msg.Params[2:])
+ for i, modeParam := range params[2:] {
+ if _, ok := needMarshaling[i]; ok {
+ params[2+i] = dc.marshalEntity(uc.network, modeParam)
+ }
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.marshalUserPrefix(uc.network, msg.Prefix),
+ Command: "MODE",
+ Params: params,
+ })
+ })
+ }
+ }
+ case irc.RPL_UMODEIS:
+ if err := parseMessageParams(msg, nil); err != nil {
+ return err
+ }
+ modeStr := ""
+ if len(msg.Params) > 1 {
+ modeStr = msg.Params[1]
+ }
+
+ uc.modes = ""
+ if err := uc.modes.Apply(modeStr); err != nil {
+ return err
+ }
+
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ if dc.upstream() == nil {
+ return
+ }
+
+ dc.SendMessage(msg)
+ })
+ case irc.RPL_CHANNELMODEIS:
+ var channel string
+ if err := parseMessageParams(msg, nil, &channel); err != nil {
+ return err
+ }
+ modeStr := ""
+ if len(msg.Params) > 2 {
+ modeStr = msg.Params[2]
+ }
+
+ ch, err := uc.getChannel(channel)
+ if err != nil {
+ return err
+ }
+
+ firstMode := ch.modes == nil
+ ch.modes = make(map[byte]string)
+ if _, err := applyChannelModes(ch, modeStr, msg.Params[3:]); err != nil {
+ return err
+ }
+
+ c := uc.network.channels.Value(channel)
+ if firstMode && (c == nil || !c.Detached) {
+ modeStr, modeParams := ch.modes.Format()
+
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ params := []string{dc.nick, dc.marshalEntity(uc.network, channel), modeStr}
+ params = append(params, modeParams...)
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_CHANNELMODEIS,
+ Params: params,
+ })
+ })
+ }
+ case rpl_creationtime:
+ var channel, creationTime string
+ if err := parseMessageParams(msg, nil, &channel, &creationTime); err != nil {
+ return err
+ }
+
+ ch, err := uc.getChannel(channel)
+ if err != nil {
+ return err
+ }
+
+ firstCreationTime := ch.creationTime == ""
+ ch.creationTime = creationTime
+
+ c := uc.network.channels.Value(channel)
+ if firstCreationTime && (c == nil || !c.Detached) {
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: rpl_creationtime,
+ Params: []string{dc.nick, dc.marshalEntity(uc.network, ch.Name), creationTime},
+ })
+ })
+ }
+ case rpl_topicwhotime:
+ var channel, who, timeStr string
+ if err := parseMessageParams(msg, nil, &channel, &who, &timeStr); err != nil {
+ return err
+ }
+
+ ch, err := uc.getChannel(channel)
+ if err != nil {
+ return err
+ }
+
+ firstTopicWhoTime := ch.TopicWho == nil
+ ch.TopicWho = irc.ParsePrefix(who)
+ sec, err := strconv.ParseInt(timeStr, 10, 64)
+ if err != nil {
+ return fmt.Errorf("failed to parse topic time: %v", err)
+ }
+ ch.TopicTime = time.Unix(sec, 0)
+
+ c := uc.network.channels.Value(channel)
+ if firstTopicWhoTime && (c == nil || !c.Detached) {
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ topicWho := dc.marshalUserPrefix(uc.network, ch.TopicWho)
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: rpl_topicwhotime,
+ Params: []string{
+ dc.nick,
+ dc.marshalEntity(uc.network, ch.Name),
+ topicWho.String(),
+ timeStr,
+ },
+ })
+ })
+ }
+ case irc.RPL_LIST:
+ var channel, clients, topic string
+ if err := parseMessageParams(msg, nil, &channel, &clients, &topic); err != nil {
+ return err
+ }
+
+ dc, cmd := uc.currentPendingCommand("LIST")
+ if cmd == nil {
+ return fmt.Errorf("unexpected RPL_LIST: no matching pending LIST")
+ } else if dc == nil {
+ return nil
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_LIST,
+ Params: []string{dc.nick, dc.marshalEntity(uc.network, channel), clients, topic},
+ })
+ case irc.RPL_LISTEND:
+ dc, cmd := uc.dequeueCommand("LIST")
+ if cmd == nil {
+ return fmt.Errorf("unexpected RPL_LISTEND: no matching pending LIST")
+ } else if dc == nil {
+ return nil
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_LISTEND,
+ Params: []string{dc.nick, "End of /LIST"},
+ })
+ case irc.RPL_NAMREPLY:
+ var name, statusStr, members string
+ if err := parseMessageParams(msg, nil, &statusStr, &name, &members); err != nil {
+ return err
+ }
+
+ ch := uc.channels.Value(name)
+ if ch == nil {
+ // NAMES on a channel we have not joined, forward to downstream
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ channel := dc.marshalEntity(uc.network, name)
+ members := splitSpace(members)
+ for i, member := range members {
+ memberships, nick := uc.parseMembershipPrefix(member)
+ members[i] = memberships.Format(dc) + dc.marshalEntity(uc.network, nick)
+ }
+ memberStr := strings.Join(members, " ")
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_NAMREPLY,
+ Params: []string{dc.nick, statusStr, channel, memberStr},
+ })
+ })
+ return nil
+ }
+
+ status, err := parseChannelStatus(statusStr)
+ if err != nil {
+ return err
+ }
+ ch.Status = status
+
+ for _, s := range splitSpace(members) {
+ memberships, nick := uc.parseMembershipPrefix(s)
+ ch.Members.SetValue(nick, memberships)
+ }
+ case irc.RPL_ENDOFNAMES:
+ var name string
+ if err := parseMessageParams(msg, nil, &name); err != nil {
+ return err
+ }
+
+ ch := uc.channels.Value(name)
+ if ch == nil {
+ // NAMES on a channel we have not joined, forward to downstream
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ channel := dc.marshalEntity(uc.network, name)
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_ENDOFNAMES,
+ Params: []string{dc.nick, channel, "End of /NAMES list"},
+ })
+ })
+ return nil
+ }
+
+ if ch.complete {
+ return fmt.Errorf("received unexpected RPL_ENDOFNAMES")
+ }
+ ch.complete = true
+
+ c := uc.network.channels.Value(name)
+ if c == nil || !c.Detached {
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ forwardChannel(ctx, dc, ch)
+ })
+ }
+ case irc.RPL_WHOREPLY:
+ var channel, username, host, server, nick, flags, trailing string
+ if err := parseMessageParams(msg, nil, &channel, &username, &host, &server, &nick, &flags, &trailing); err != nil {
+ return err
+ }
+
+ dc, cmd := uc.currentPendingCommand("WHO")
+ if cmd == nil {
+ return fmt.Errorf("unexpected RPL_WHOREPLY: no matching pending WHO")
+ } else if dc == nil {
+ return nil
+ }
+
+ if channel != "*" {
+ channel = dc.marshalEntity(uc.network, channel)
+ }
+ nick = dc.marshalEntity(uc.network, nick)
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_WHOREPLY,
+ Params: []string{dc.nick, channel, username, host, server, nick, flags, trailing},
+ })
+ case rpl_whospcrpl:
+ dc, cmd := uc.currentPendingCommand("WHO")
+ if cmd == nil {
+ return fmt.Errorf("unexpected RPL_WHOSPCRPL: no matching pending WHO")
+ } else if dc == nil {
+ return nil
+ }
+
+ // Only supported in single-upstream mode, so forward as-is
+ dc.SendMessage(msg)
+ case irc.RPL_ENDOFWHO:
+ var name string
+ if err := parseMessageParams(msg, nil, &name); err != nil {
+ return err
+ }
+
+ dc, cmd := uc.dequeueCommand("WHO")
+ if cmd == nil {
+ return fmt.Errorf("unexpected RPL_ENDOFWHO: no matching pending WHO")
+ } else if dc == nil {
+ return nil
+ }
+
+ mask := "*"
+ if len(cmd.Params) > 0 {
+ mask = cmd.Params[0]
+ }
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_ENDOFWHO,
+ Params: []string{dc.nick, mask, "End of /WHO list"},
+ })
+ case irc.RPL_WHOISUSER:
+ var nick, username, host, realname string
+ if err := parseMessageParams(msg, nil, &nick, &username, &host, nil, &realname); err != nil {
+ return err
+ }
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ nick := dc.marshalEntity(uc.network, nick)
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_WHOISUSER,
+ Params: []string{dc.nick, nick, username, host, "*", realname},
+ })
+ })
+ case irc.RPL_WHOISSERVER:
+ var nick, server, serverInfo string
+ if err := parseMessageParams(msg, nil, &nick, &server, &serverInfo); err != nil {
+ return err
+ }
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ nick := dc.marshalEntity(uc.network, nick)
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_WHOISSERVER,
+ Params: []string{dc.nick, nick, server, serverInfo},
+ })
+ })
+ case irc.RPL_WHOISOPERATOR:
+ var nick string
+ if err := parseMessageParams(msg, nil, &nick); err != nil {
+ return err
+ }
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ nick := dc.marshalEntity(uc.network, nick)
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_WHOISOPERATOR,
+ Params: []string{dc.nick, nick, "is an IRC operator"},
+ })
+ })
+ case irc.RPL_WHOISIDLE:
+ var nick string
+ if err := parseMessageParams(msg, nil, &nick, nil); err != nil {
+ return err
+ }
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ nick := dc.marshalEntity(uc.network, nick)
+ params := []string{dc.nick, nick}
+ params = append(params, msg.Params[2:]...)
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_WHOISIDLE,
+ Params: params,
+ })
+ })
+ case irc.RPL_WHOISCHANNELS:
+ var nick, channelList string
+ if err := parseMessageParams(msg, nil, &nick, &channelList); err != nil {
+ return err
+ }
+ channels := splitSpace(channelList)
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ nick := dc.marshalEntity(uc.network, nick)
+ channelList := make([]string, len(channels))
+ for i, channel := range channels {
+ prefix, channel := uc.parseMembershipPrefix(channel)
+ channel = dc.marshalEntity(uc.network, channel)
+ channelList[i] = prefix.Format(dc) + channel
+ }
+ channels := strings.Join(channelList, " ")
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_WHOISCHANNELS,
+ Params: []string{dc.nick, nick, channels},
+ })
+ })
+ case irc.RPL_ENDOFWHOIS:
+ var nick string
+ if err := parseMessageParams(msg, nil, &nick); err != nil {
+ return err
+ }
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ nick := dc.marshalEntity(uc.network, nick)
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_ENDOFWHOIS,
+ Params: []string{dc.nick, nick, "End of /WHOIS list"},
+ })
+ })
+ case "INVITE":
+ var nick, channel string
+ if err := parseMessageParams(msg, &nick, &channel); err != nil {
+ return err
+ }
+
+ weAreInvited := uc.isOurNick(nick)
+
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ if !weAreInvited && !dc.caps["invite-notify"] {
+ return
+ }
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.marshalUserPrefix(uc.network, msg.Prefix),
+ Command: "INVITE",
+ Params: []string{dc.marshalEntity(uc.network, nick), dc.marshalEntity(uc.network, channel)},
+ })
+ })
+ case irc.RPL_INVITING:
+ var nick, channel string
+ if err := parseMessageParams(msg, nil, &nick, &channel); err != nil {
+ return err
+ }
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_INVITING,
+ Params: []string{dc.nick, dc.marshalEntity(uc.network, nick), dc.marshalEntity(uc.network, channel)},
+ })
+ })
+ case irc.RPL_MONONLINE, irc.RPL_MONOFFLINE:
+ var targetsStr string
+ if err := parseMessageParams(msg, nil, &targetsStr); err != nil {
+ return err
+ }
+ targets := strings.Split(targetsStr, ",")
+
+ online := msg.Command == irc.RPL_MONONLINE
+ for _, target := range targets {
+ prefix := irc.ParsePrefix(target)
+ uc.monitored.SetValue(prefix.Name, online)
+ }
+
+ // Check if the nick we want is now free
+ wantNick := GetNick(&uc.user.User, &uc.network.Network)
+ wantNickCM := uc.network.casemap(wantNick)
+ if !online && uc.nickCM != wantNickCM {
+ found := false
+ for _, target := range targets {
+ prefix := irc.ParsePrefix(target)
+ if uc.network.casemap(prefix.Name) == wantNickCM {
+ found = true
+ break
+ }
+ }
+ if found {
+ uc.logger.Printf("desired nick %q is now available", wantNick)
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "NICK",
+ Params: []string{wantNick},
+ })
+ }
+ }
+
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ for _, target := range targets {
+ prefix := irc.ParsePrefix(target)
+ if dc.monitored.Has(prefix.Name) {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: msg.Command,
+ Params: []string{dc.nick, target},
+ })
+ }
+ }
+ })
+ case irc.ERR_MONLISTFULL:
+ var limit, targetsStr string
+ if err := parseMessageParams(msg, nil, &limit, &targetsStr); err != nil {
+ return err
+ }
+
+ targets := strings.Split(targetsStr, ",")
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ for _, target := range targets {
+ if dc.monitored.Has(target) {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: msg.Command,
+ Params: []string{dc.nick, limit, target},
+ })
+ }
+ }
+ })
+ case irc.RPL_AWAY:
+ var nick, reason string
+ if err := parseMessageParams(msg, nil, &nick, &reason); err != nil {
+ return err
+ }
+
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_AWAY,
+ Params: []string{dc.nick, dc.marshalEntity(uc.network, nick), reason},
+ })
+ })
+ case "AWAY", "ACCOUNT":
+ if msg.Prefix == nil {
+ return fmt.Errorf("expected a prefix")
+ }
+
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.marshalUserPrefix(uc.network, msg.Prefix),
+ Command: msg.Command,
+ Params: msg.Params,
+ })
+ })
+ case irc.RPL_BANLIST, irc.RPL_INVITELIST, irc.RPL_EXCEPTLIST:
+ var channel, mask string
+ if err := parseMessageParams(msg, nil, &channel, &mask); err != nil {
+ return err
+ }
+ var addNick, addTime string
+ if len(msg.Params) >= 5 {
+ addNick = msg.Params[3]
+ addTime = msg.Params[4]
+ }
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ channel := dc.marshalEntity(uc.network, channel)
+
+ var params []string
+ if addNick != "" && addTime != "" {
+ addNick := dc.marshalEntity(uc.network, addNick)
+ params = []string{dc.nick, channel, mask, addNick, addTime}
+ } else {
+ params = []string{dc.nick, channel, mask}
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: msg.Command,
+ Params: params,
+ })
+ })
+ case irc.RPL_ENDOFBANLIST, irc.RPL_ENDOFINVITELIST, irc.RPL_ENDOFEXCEPTLIST:
+ var channel, trailing string
+ if err := parseMessageParams(msg, nil, &channel, &trailing); err != nil {
+ return err
+ }
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ upstreamChannel := dc.marshalEntity(uc.network, channel)
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: msg.Command,
+ Params: []string{dc.nick, upstreamChannel, trailing},
+ })
+ })
+ case irc.ERR_UNKNOWNCOMMAND, irc.RPL_TRYAGAIN:
+ var command, reason string
+ if err := parseMessageParams(msg, nil, &command, &reason); err != nil {
+ return err
+ }
+
+ if dc, _ := uc.dequeueCommand(command); dc != nil && downstreamID == 0 {
+ downstreamID = dc.id
+ }
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ dc.SendMessage(&irc.Message{
+ Prefix: uc.srv.prefix(),
+ Command: msg.Command,
+ Params: []string{dc.nick, command, reason},
+ })
+ })
+ case "FAIL":
+ var command, code string
+ if err := parseMessageParams(msg, &command, &code); err != nil {
+ return err
+ }
+
+ if !uc.registered && command == "*" && code == "ACCOUNT_REQUIRED" {
+ return registrationError{msg}
+ }
+
+ if dc, _ := uc.dequeueCommand(command); dc != nil && downstreamID == 0 {
+ downstreamID = dc.id
+ }
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ dc.SendMessage(msg)
+ })
+ case "ACK":
+ // Ignore
+ case irc.RPL_NOWAWAY, irc.RPL_UNAWAY:
+ // Ignore
+ case irc.RPL_YOURHOST, irc.RPL_CREATED:
+ // Ignore
+ case irc.RPL_LUSERCLIENT, irc.RPL_LUSEROP, irc.RPL_LUSERUNKNOWN, irc.RPL_LUSERCHANNELS, irc.RPL_LUSERME:
+ fallthrough
+ case irc.RPL_STATSVLINE, rpl_statsping, irc.RPL_STATSBLINE, irc.RPL_STATSDLINE:
+ fallthrough
+ case rpl_localusers, rpl_globalusers:
+ fallthrough
+ case irc.RPL_MOTDSTART, irc.RPL_MOTD:
+ // Ignore these messages if they're part of the initial registration
+ // message burst. Forward them if the user explicitly asked for them.
+ if !uc.gotMotd {
+ return nil
+ }
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ dc.SendMessage(&irc.Message{
+ Prefix: uc.srv.prefix(),
+ Command: msg.Command,
+ Params: msg.Params,
+ })
+ })
+ case irc.RPL_LISTSTART:
+ // Ignore
+ case "ERROR":
+ var text string
+ if err := parseMessageParams(msg, &text); err != nil {
+ return err
+ }
+ return fmt.Errorf("fatal server error: %v", text)
+ case irc.ERR_NICKNAMEINUSE:
+ // At this point, we haven't received ISUPPORT so we don't know the
+ // maximum nickname length or whether the server supports MONITOR. Many
+ // servers have NICKLEN=30 so let's just use that.
+ if !uc.registered && len(uc.nick)+1 < 30 {
+ uc.nick = uc.nick + "_"
+ uc.nickCM = uc.network.casemap(uc.nick)
+ uc.logger.Printf("desired nick is not available, falling back to %q", uc.nick)
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "NICK",
+ Params: []string{uc.nick},
+ })
+ return nil
+ }
+ fallthrough
+ case irc.ERR_PASSWDMISMATCH, irc.ERR_ERRONEUSNICKNAME, irc.ERR_NICKCOLLISION, irc.ERR_UNAVAILRESOURCE, irc.ERR_NOPERMFORHOST, irc.ERR_YOUREBANNEDCREEP:
+ if !uc.registered {
+ return registrationError{msg}
+ }
+ fallthrough
+ default:
+ uc.logger.Printf("unhandled message: %v", msg)
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ // best effort marshaling for unknown messages, replies and errors:
+ // most numerics start with the user nick, marshal it if that's the case
+ // otherwise, conservately keep the params without marshaling
+ params := msg.Params
+ if _, err := strconv.Atoi(msg.Command); err == nil { // numeric
+ if len(msg.Params) > 0 && isOurNick(uc.network, msg.Params[0]) {
+ params[0] = dc.nick
+ }
+ }
+ dc.SendMessage(&irc.Message{
+ Prefix: uc.srv.prefix(),
+ Command: msg.Command,
+ Params: params,
+ })
+ })
+ }
+ return nil
+}
+
+func (uc *upstreamConn) handleDetachedMessage(ctx context.Context, ch *Channel, msg *irc.Message) {
+ if uc.network.detachedMessageNeedsRelay(ch, msg) {
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.relayDetachedMessage(uc.network, msg)
+ })
+ }
+ if ch.ReattachOn == FilterMessage || (ch.ReattachOn == FilterHighlight && uc.network.isHighlight(msg)) {
+ uc.network.attach(ctx, ch)
+ if err := uc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil {
+ uc.logger.Printf("failed to update channel %q: %v", ch.Name, err)
+ }
+ }
+}
+
+func (uc *upstreamConn) handleChanModes(s string) error {
+ parts := strings.SplitN(s, ",", 5)
+ if len(parts) < 4 {
+ return fmt.Errorf("malformed ISUPPORT CHANMODES value: %v", s)
+ }
+ modes := make(map[byte]channelModeType)
+ for i, mt := range []channelModeType{modeTypeA, modeTypeB, modeTypeC, modeTypeD} {
+ for j := 0; j < len(parts[i]); j++ {
+ mode := parts[i][j]
+ modes[mode] = mt
+ }
+ }
+ uc.availableChannelModes = modes
+ return nil
+}
+
+func (uc *upstreamConn) handleMemberships(s string) error {
+ if s == "" {
+ uc.availableMemberships = nil
+ return nil
+ }
+
+ if s[0] != '(' {
+ return fmt.Errorf("malformed ISUPPORT PREFIX value: %v", s)
+ }
+ sep := strings.IndexByte(s, ')')
+ if sep < 0 || len(s) != sep*2 {
+ return fmt.Errorf("malformed ISUPPORT PREFIX value: %v", s)
+ }
+ memberships := make([]membership, len(s)/2-1)
+ for i := range memberships {
+ memberships[i] = membership{
+ Mode: s[i+1],
+ Prefix: s[sep+i+1],
+ }
+ }
+ uc.availableMemberships = memberships
+ return nil
+}
+
+func (uc *upstreamConn) handleSupportedCaps(capsStr string) {
+ caps := strings.Fields(capsStr)
+ for _, s := range caps {
+ kv := strings.SplitN(s, "=", 2)
+ k := strings.ToLower(kv[0])
+ var v string
+ if len(kv) == 2 {
+ v = kv[1]
+ }
+ uc.supportedCaps[k] = v
+ }
+}
+
+func (uc *upstreamConn) requestCaps() {
+ var requestCaps []string
+ for c := range permanentUpstreamCaps {
+ if _, ok := uc.supportedCaps[c]; ok && !uc.caps[c] {
+ requestCaps = append(requestCaps, c)
+ }
+ }
+
+ if len(requestCaps) == 0 {
+ return
+ }
+
+ uc.SendMessage(context.TODO(), &irc.Message{
+ Command: "CAP",
+ Params: []string{"REQ", strings.Join(requestCaps, " ")},
+ })
+}
+
+func (uc *upstreamConn) supportsSASL(mech string) bool {
+ v, ok := uc.supportedCaps["sasl"]
+ if !ok {
+ return false
+ }
+
+ if v == "" {
+ return true
+ }
+
+ mechanisms := strings.Split(v, ",")
+ for _, mech := range mechanisms {
+ if strings.EqualFold(mech, mech) {
+ return true
+ }
+ }
+ return false
+}
+
+func (uc *upstreamConn) requestSASL() bool {
+ if uc.network.SASL.Mechanism == "" {
+ return false
+ }
+ return uc.supportsSASL(uc.network.SASL.Mechanism)
+}
+
+func (uc *upstreamConn) handleCapAck(ctx context.Context, name string, ok bool) error {
+ uc.caps[name] = ok
+
+ switch name {
+ case "sasl":
+ if !uc.requestSASL() {
+ return nil
+ }
+ if !ok {
+ uc.logger.Printf("server refused to acknowledge the SASL capability")
+ return nil
+ }
+
+ auth := &uc.network.SASL
+ switch auth.Mechanism {
+ case "PLAIN":
+ uc.logger.Printf("starting SASL PLAIN authentication with username %q", auth.Plain.Username)
+ uc.saslClient = sasl.NewPlainClient("", auth.Plain.Username, auth.Plain.Password)
+ case "EXTERNAL":
+ uc.logger.Printf("starting SASL EXTERNAL authentication")
+ uc.saslClient = sasl.NewExternalClient("")
+ default:
+ return fmt.Errorf("unsupported SASL mechanism %q", name)
+ }
+
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "AUTHENTICATE",
+ Params: []string{auth.Mechanism},
+ })
+ default:
+ if permanentUpstreamCaps[name] {
+ break
+ }
+ uc.logger.Printf("received CAP ACK/NAK for a cap we don't support: %v", name)
+ }
+ return nil
+}
+
+func splitSpace(s string) []string {
+ return strings.FieldsFunc(s, func(r rune) bool {
+ return r == ' '
+ })
+}
+
+func (uc *upstreamConn) register(ctx context.Context) {
+ uc.nick = GetNick(&uc.user.User, &uc.network.Network)
+ uc.nickCM = uc.network.casemap(uc.nick)
+ uc.username = GetUsername(&uc.user.User, &uc.network.Network)
+ uc.realname = GetRealname(&uc.user.User, &uc.network.Network)
+
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "CAP",
+ Params: []string{"LS", "302"},
+ })
+
+ if uc.network.Pass != "" {
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "PASS",
+ Params: []string{uc.network.Pass},
+ })
+ }
+
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "NICK",
+ Params: []string{uc.nick},
+ })
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "USER",
+ Params: []string{uc.username, "0", "*", uc.realname},
+ })
+}
+
+func (uc *upstreamConn) ReadMessage() (*irc.Message, error) {
+ msg, err := uc.conn.ReadMessage()
+ if err != nil {
+ return nil, err
+ }
+ return msg, nil
+}
+
+func (uc *upstreamConn) runUntilRegistered(ctx context.Context) error {
+ for !uc.registered {
+ msg, err := uc.ReadMessage()
+ if err != nil {
+ return fmt.Errorf("failed to read message: %v", err)
+ }
+
+ if err := uc.handleMessage(ctx, msg); err != nil {
+ if _, ok := err.(registrationError); ok {
+ return err
+ } else {
+ msg.Tags = nil // prevent message tags from cluttering logs
+ return fmt.Errorf("failed to handle message %q: %v", msg, err)
+ }
+ }
+ }
+
+ for _, command := range uc.network.ConnectCommands {
+ m, err := irc.ParseMessage(command)
+ if err != nil {
+ uc.logger.Printf("failed to parse connect command %q: %v", command, err)
+ } else {
+ uc.SendMessage(ctx, m)
+ }
+ }
+
+ return nil
+}
+
+func (uc *upstreamConn) readMessages(ch chan<- event) error {
+ for {
+ msg, err := uc.ReadMessage()
+ if errors.Is(err, io.EOF) {
+ break
+ } else if err != nil {
+ return fmt.Errorf("failed to read IRC command: %v", err)
+ }
+
+ ch <- eventUpstreamMessage{msg, uc}
+ }
+
+ return nil
+}
+
+func (uc *upstreamConn) SendMessage(ctx context.Context, msg *irc.Message) {
+ if !uc.caps["message-tags"] {
+ msg = msg.Copy()
+ msg.Tags = nil
+ }
+
+ uc.conn.SendMessage(ctx, msg)
+}
+
+func (uc *upstreamConn) SendMessageLabeled(ctx context.Context, downstreamID uint64, msg *irc.Message) {
+ if uc.caps["labeled-response"] {
+ if msg.Tags == nil {
+ msg.Tags = make(map[string]irc.TagValue)
+ }
+ msg.Tags["label"] = irc.TagValue(fmt.Sprintf("sd-%d-%d", downstreamID, uc.nextLabelID))
+ uc.nextLabelID++
+ }
+ uc.SendMessage(ctx, msg)
+}
+
+// appendLog appends a message to the log file.
+//
+// The internal message ID is returned. If the message isn't recorded in the
+// log file, an empty string is returned.
+func (uc *upstreamConn) appendLog(entity string, msg *irc.Message) (msgID string) {
+ if uc.user.msgStore == nil {
+ return ""
+ }
+
+ // Don't store messages with a server mask target
+ if strings.HasPrefix(entity, "$") {
+ return ""
+ }
+
+ entityCM := uc.network.casemap(entity)
+ if entityCM == "nickserv" {
+ // The messages sent/received from NickServ may contain
+ // security-related information (like passwords). Don't store these.
+ return ""
+ }
+
+ if !uc.network.delivered.HasTarget(entity) {
+ // This is the first message we receive from this target. Save the last
+ // message ID in delivery receipts, so that we can send the new message
+ // in the backlog if an offline client reconnects.
+ lastID, err := uc.user.msgStore.LastMsgID(&uc.network.Network, entityCM, time.Now())
+ if err != nil {
+ uc.logger.Printf("failed to log message: failed to get last message ID: %v", err)
+ return ""
+ }
+
+ uc.network.delivered.ForEachClient(func(clientName string) {
+ uc.network.delivered.StoreID(entity, clientName, lastID)
+ })
+ }
+
+ msgID, err := uc.user.msgStore.Append(&uc.network.Network, entityCM, msg)
+ if err != nil {
+ uc.logger.Printf("failed to append message to store: %v", err)
+ return ""
+ }
+
+ return msgID
+}
+
+// produce appends a message to the logs and forwards it to connected downstream
+// connections.
+//
+// If origin is not nil and origin doesn't support echo-message, the message is
+// forwarded to all connections except origin.
+func (uc *upstreamConn) produce(target string, msg *irc.Message, origin *downstreamConn) {
+ var msgID string
+ if target != "" {
+ msgID = uc.appendLog(target, msg)
+ }
+
+ // Don't forward messages if it's a detached channel
+ ch := uc.network.channels.Value(target)
+ detached := ch != nil && ch.Detached
+
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ if !detached && (dc != origin || dc.caps["echo-message"]) {
+ dc.sendMessageWithID(dc.marshalMessage(msg, uc.network), msgID)
+ } else {
+ dc.advanceMessageWithID(msg, msgID)
+ }
+ })
+}
+
+func (uc *upstreamConn) updateAway() {
+ ctx := context.TODO()
+
+ away := true
+ uc.forEachDownstream(func(*downstreamConn) {
+ away = false
+ })
+ if away == uc.away {
+ return
+ }
+ if away {
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "AWAY",
+ Params: []string{"Auto away"},
+ })
+ } else {
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "AWAY",
+ })
+ }
+ uc.away = away
+}
+
+func (uc *upstreamConn) updateChannelAutoDetach(name string) {
+ uch := uc.channels.Value(name)
+ if uch == nil {
+ return
+ }
+ ch := uc.network.channels.Value(name)
+ if ch == nil || ch.Detached {
+ return
+ }
+ uch.updateAutoDetach(ch.DetachAfter)
+}
+
+func (uc *upstreamConn) updateMonitor() {
+ if _, ok := uc.isupport["MONITOR"]; !ok {
+ return
+ }
+
+ ctx := context.TODO()
+
+ add := make(map[string]struct{})
+ var addList []string
+ seen := make(map[string]struct{})
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ for targetCM := range dc.monitored.innerMap {
+ if !uc.monitored.Has(targetCM) {
+ if _, ok := add[targetCM]; !ok {
+ addList = append(addList, targetCM)
+ add[targetCM] = struct{}{}
+ }
+ } else {
+ seen[targetCM] = struct{}{}
+ }
+ }
+ })
+
+ wantNick := GetNick(&uc.user.User, &uc.network.Network)
+ wantNickCM := uc.network.casemap(wantNick)
+ if _, ok := add[wantNickCM]; !ok && !uc.monitored.Has(wantNick) && !uc.isOurNick(wantNick) {
+ addList = append(addList, wantNickCM)
+ add[wantNickCM] = struct{}{}
+ }
+
+ removeAll := true
+ var removeList []string
+ for targetCM, entry := range uc.monitored.innerMap {
+ if _, ok := seen[targetCM]; ok {
+ removeAll = false
+ } else {
+ removeList = append(removeList, entry.originalKey)
+ }
+ }
+
+ // TODO: better handle the case where len(uc.monitored) + len(addList)
+ // exceeds the limit, probably by immediately sending ERR_MONLISTFULL?
+
+ if removeAll && len(addList) == 0 && len(removeList) > 0 {
+ // Optimization when the last MONITOR-aware downstream disconnects
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "MONITOR",
+ Params: []string{"C"},
+ })
+ } else {
+ msgs := generateMonitor("-", removeList)
+ msgs = append(msgs, generateMonitor("+", addList)...)
+ for _, msg := range msgs {
+ uc.SendMessage(ctx, msg)
+ }
+ }
+
+ for _, target := range removeList {
+ uc.monitored.Delete(target)
+ }
+}
--- /dev/null
+package suika
+
+import (
+ "context"
+ "crypto/sha256"
+ "encoding/binary"
+ "encoding/hex"
+ "fmt"
+ "math/big"
+ "net"
+ "sort"
+ "strings"
+ "time"
+
+ "gopkg.in/irc.v3"
+)
+
+type event interface{}
+
+type eventUpstreamMessage struct {
+ msg *irc.Message
+ uc *upstreamConn
+}
+
+type eventUpstreamConnectionError struct {
+ net *network
+ err error
+}
+
+type eventUpstreamConnected struct {
+ uc *upstreamConn
+}
+
+type eventUpstreamDisconnected struct {
+ uc *upstreamConn
+}
+
+type eventUpstreamError struct {
+ uc *upstreamConn
+ err error
+}
+
+type eventDownstreamMessage struct {
+ msg *irc.Message
+ dc *downstreamConn
+}
+
+type eventDownstreamConnected struct {
+ dc *downstreamConn
+}
+
+type eventDownstreamDisconnected struct {
+ dc *downstreamConn
+}
+
+type eventChannelDetach struct {
+ uc *upstreamConn
+ name string
+}
+
+type eventBroadcast struct {
+ msg *irc.Message
+}
+
+type eventStop struct{}
+
+type eventUserUpdate struct {
+ password *string
+ admin *bool
+ done chan error
+}
+
+type deliveredClientMap map[string]string // client name -> msg ID
+
+type deliveredStore struct {
+ m deliveredCasemapMap
+}
+
+func newDeliveredStore() deliveredStore {
+ return deliveredStore{deliveredCasemapMap{newCasemapMap(0)}}
+}
+
+func (ds deliveredStore) HasTarget(target string) bool {
+ return ds.m.Value(target) != nil
+}
+
+func (ds deliveredStore) LoadID(target, clientName string) string {
+ clients := ds.m.Value(target)
+ if clients == nil {
+ return ""
+ }
+ return clients[clientName]
+}
+
+func (ds deliveredStore) StoreID(target, clientName, msgID string) {
+ clients := ds.m.Value(target)
+ if clients == nil {
+ clients = make(deliveredClientMap)
+ ds.m.SetValue(target, clients)
+ }
+ clients[clientName] = msgID
+}
+
+func (ds deliveredStore) ForEachTarget(f func(target string)) {
+ for _, entry := range ds.m.innerMap {
+ f(entry.originalKey)
+ }
+}
+
+func (ds deliveredStore) ForEachClient(f func(clientName string)) {
+ clients := make(map[string]struct{})
+ for _, entry := range ds.m.innerMap {
+ delivered := entry.value.(deliveredClientMap)
+ for clientName := range delivered {
+ clients[clientName] = struct{}{}
+ }
+ }
+
+ for clientName := range clients {
+ f(clientName)
+ }
+}
+
+type network struct {
+ Network
+ user *user
+ logger Logger
+ stopped chan struct{}
+
+ conn *upstreamConn
+ channels channelCasemapMap
+ delivered deliveredStore
+ lastError error
+ casemap casemapping
+}
+
+func newNetwork(user *user, record *Network, channels []Channel) *network {
+ logger := &prefixLogger{user.logger, fmt.Sprintf("network %q: ", record.GetName())}
+
+ m := channelCasemapMap{newCasemapMap(0)}
+ for _, ch := range channels {
+ ch := ch
+ m.SetValue(ch.Name, &ch)
+ }
+
+ return &network{
+ Network: *record,
+ user: user,
+ logger: logger,
+ stopped: make(chan struct{}),
+ channels: m,
+ delivered: newDeliveredStore(),
+ casemap: casemapRFC1459,
+ }
+}
+
+func (net *network) forEachDownstream(f func(*downstreamConn)) {
+ net.user.forEachDownstream(func(dc *downstreamConn) {
+ if dc.network == nil && !dc.isMultiUpstream {
+ return
+ }
+ if dc.network != nil && dc.network != net {
+ return
+ }
+ f(dc)
+ })
+}
+
+func (net *network) isStopped() bool {
+ select {
+ case <-net.stopped:
+ return true
+ default:
+ return false
+ }
+}
+
+func userIdent(u *User) string {
+ // The ident is a string we will send to upstream servers in clear-text.
+ // For privacy reasons, make sure it doesn't expose any meaningful user
+ // metadata. We just use the base64-encoded hashed ID, so that people don't
+ // start relying on the string being an integer or following a pattern.
+ var b [64]byte
+ binary.LittleEndian.PutUint64(b[:], uint64(u.ID))
+ h := sha256.Sum256(b[:])
+ return hex.EncodeToString(h[:16])
+}
+
+func (net *network) run() {
+ if !net.Enabled {
+ return
+ }
+
+ var lastTry time.Time
+ backoff := newBackoffer(retryConnectMinDelay, retryConnectMaxDelay, retryConnectJitter)
+ for {
+ if net.isStopped() {
+ return
+ }
+
+ delay := backoff.Next() - time.Now().Sub(lastTry)
+ if delay > 0 {
+ net.logger.Printf("waiting %v before trying to reconnect to %q", delay.Truncate(time.Second), net.Addr)
+ time.Sleep(delay)
+ }
+ lastTry = time.Now()
+
+
+ uc, err := connectToUpstream(context.TODO(), net)
+ if err != nil {
+ net.logger.Printf("failed to connect to upstream server %q: %v", net.Addr, err)
+ net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("failed to connect: %v", err)}
+ continue
+ }
+
+ uc.register(context.TODO())
+ if err := uc.runUntilRegistered(context.TODO()); err != nil {
+ text := err.Error()
+ temp := true
+ if regErr, ok := err.(registrationError); ok {
+ text = regErr.Reason()
+ temp = regErr.Temporary()
+ }
+ uc.logger.Printf("failed to register: %v", text)
+ net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("failed to register: %v", text)}
+ uc.Close()
+ if !temp {
+ return
+ }
+ continue
+ }
+
+ // TODO: this is racy with net.stopped. If the network is stopped
+ // before the user goroutine receives eventUpstreamConnected, the
+ // connection won't be closed.
+ net.user.events <- eventUpstreamConnected{uc}
+ if err := uc.readMessages(net.user.events); err != nil {
+ uc.logger.Printf("failed to handle messages: %v", err)
+ net.user.events <- eventUpstreamError{uc, fmt.Errorf("failed to handle messages: %v", err)}
+ }
+ uc.Close()
+ net.user.events <- eventUpstreamDisconnected{uc}
+
+ backoff.Reset()
+ }
+}
+
+func (net *network) stop() {
+ if !net.isStopped() {
+ close(net.stopped)
+ }
+
+ if net.conn != nil {
+ net.conn.Close()
+ }
+}
+
+func (net *network) detach(ch *Channel) {
+ if ch.Detached {
+ return
+ }
+
+ net.logger.Printf("detaching channel %q", ch.Name)
+
+ ch.Detached = true
+
+ if net.user.msgStore != nil {
+ nameCM := net.casemap(ch.Name)
+ lastID, err := net.user.msgStore.LastMsgID(&net.Network, nameCM, time.Now())
+ if err != nil {
+ net.logger.Printf("failed to get last message ID for channel %q: %v", ch.Name, err)
+ }
+ ch.DetachedInternalMsgID = lastID
+ }
+
+ if net.conn != nil {
+ uch := net.conn.channels.Value(ch.Name)
+ if uch != nil {
+ uch.updateAutoDetach(0)
+ }
+ }
+
+ net.forEachDownstream(func(dc *downstreamConn) {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.prefix(),
+ Command: "PART",
+ Params: []string{dc.marshalEntity(net, ch.Name), "Detach"},
+ })
+ })
+}
+
+func (net *network) attach(ctx context.Context, ch *Channel) {
+ if !ch.Detached {
+ return
+ }
+
+ net.logger.Printf("attaching channel %q", ch.Name)
+
+ detachedMsgID := ch.DetachedInternalMsgID
+ ch.Detached = false
+ ch.DetachedInternalMsgID = ""
+
+ var uch *upstreamChannel
+ if net.conn != nil {
+ uch = net.conn.channels.Value(ch.Name)
+
+ net.conn.updateChannelAutoDetach(ch.Name)
+ }
+
+ net.forEachDownstream(func(dc *downstreamConn) {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.prefix(),
+ Command: "JOIN",
+ Params: []string{dc.marshalEntity(net, ch.Name)},
+ })
+
+ if uch != nil {
+ forwardChannel(ctx, dc, uch)
+ }
+
+ if detachedMsgID != "" {
+ dc.sendTargetBacklog(ctx, net, ch.Name, detachedMsgID)
+ }
+ })
+}
+
+func (net *network) deleteChannel(ctx context.Context, name string) error {
+ ch := net.channels.Value(name)
+ if ch == nil {
+ return fmt.Errorf("unknown channel %q", name)
+ }
+ if net.conn != nil {
+ uch := net.conn.channels.Value(ch.Name)
+ if uch != nil {
+ uch.updateAutoDetach(0)
+ }
+ }
+
+ if err := net.user.srv.db.DeleteChannel(ctx, ch.ID); err != nil {
+ return err
+ }
+ net.channels.Delete(name)
+ return nil
+}
+
+func (net *network) updateCasemapping(newCasemap casemapping) {
+ net.casemap = newCasemap
+ net.channels.SetCasemapping(newCasemap)
+ net.delivered.m.SetCasemapping(newCasemap)
+ if uc := net.conn; uc != nil {
+ uc.channels.SetCasemapping(newCasemap)
+ for _, entry := range uc.channels.innerMap {
+ uch := entry.value.(*upstreamChannel)
+ uch.Members.SetCasemapping(newCasemap)
+ }
+ uc.monitored.SetCasemapping(newCasemap)
+ }
+ net.forEachDownstream(func(dc *downstreamConn) {
+ dc.monitored.SetCasemapping(newCasemap)
+ })
+}
+
+func (net *network) storeClientDeliveryReceipts(ctx context.Context, clientName string) {
+ if !net.user.hasPersistentMsgStore() {
+ return
+ }
+
+ var receipts []DeliveryReceipt
+ net.delivered.ForEachTarget(func(target string) {
+ msgID := net.delivered.LoadID(target, clientName)
+ if msgID == "" {
+ return
+ }
+ receipts = append(receipts, DeliveryReceipt{
+ Target: target,
+ InternalMsgID: msgID,
+ })
+ })
+
+ if err := net.user.srv.db.StoreClientDeliveryReceipts(ctx, net.ID, clientName, receipts); err != nil {
+ net.logger.Printf("failed to store delivery receipts for client %q: %v", clientName, err)
+ }
+}
+
+func (net *network) isHighlight(msg *irc.Message) bool {
+ if msg.Command != "PRIVMSG" && msg.Command != "NOTICE" {
+ return false
+ }
+
+ text := msg.Params[1]
+
+ nick := net.Nick
+ if net.conn != nil {
+ nick = net.conn.nick
+ }
+
+ // TODO: use case-mapping aware comparison here
+ return msg.Prefix.Name != nick && isHighlight(text, nick)
+}
+
+func (net *network) detachedMessageNeedsRelay(ch *Channel, msg *irc.Message) bool {
+ highlight := net.isHighlight(msg)
+ return ch.RelayDetached == FilterMessage || ((ch.RelayDetached == FilterHighlight || ch.RelayDetached == FilterDefault) && highlight)
+}
+
+func (net *network) autoSaveSASLPlain(ctx context.Context, username, password string) {
+ // User may have e.g. EXTERNAL mechanism configured. We do not want to
+ // automatically erase the key pair or any other credentials.
+ if net.SASL.Mechanism != "" && net.SASL.Mechanism != "PLAIN" {
+ return
+ }
+
+ net.logger.Printf("auto-saving SASL PLAIN credentials with username %q", username)
+ net.SASL.Mechanism = "PLAIN"
+ net.SASL.Plain.Username = username
+ net.SASL.Plain.Password = password
+ if err := net.user.srv.db.StoreNetwork(ctx, net.user.ID, &net.Network); err != nil {
+ net.logger.Printf("failed to save SASL PLAIN credentials: %v", err)
+ }
+}
+
+type user struct {
+ User
+ srv *Server
+ logger Logger
+
+ events chan event
+ done chan struct{}
+
+ networks []*network
+ downstreamConns []*downstreamConn
+ msgStore messageStore
+}
+
+func newUser(srv *Server, record *User) *user {
+ logger := &prefixLogger{srv.Logger, fmt.Sprintf("user %q: ", record.Username)}
+
+ var msgStore messageStore
+ if logPath := srv.Config().LogPath; logPath != "" {
+ msgStore = newFSMessageStore(logPath, record)
+ } else {
+ msgStore = newMemoryMessageStore()
+ }
+
+ return &user{
+ User: *record,
+ srv: srv,
+ logger: logger,
+ events: make(chan event, 64),
+ done: make(chan struct{}),
+ msgStore: msgStore,
+ }
+}
+
+func (u *user) forEachUpstream(f func(uc *upstreamConn)) {
+ for _, network := range u.networks {
+ if network.conn == nil {
+ continue
+ }
+ f(network.conn)
+ }
+}
+
+func (u *user) forEachDownstream(f func(dc *downstreamConn)) {
+ for _, dc := range u.downstreamConns {
+ f(dc)
+ }
+}
+
+func (u *user) getNetwork(name string) *network {
+ for _, network := range u.networks {
+ if network.Addr == name {
+ return network
+ }
+ if network.Name != "" && network.Name == name {
+ return network
+ }
+ }
+ return nil
+}
+
+func (u *user) getNetworkByID(id int64) *network {
+ for _, net := range u.networks {
+ if net.ID == id {
+ return net
+ }
+ }
+ return nil
+}
+
+func (u *user) run() {
+ defer func() {
+ if u.msgStore != nil {
+ if err := u.msgStore.Close(); err != nil {
+ u.logger.Printf("failed to close message store for user %q: %v", u.Username, err)
+ }
+ }
+ close(u.done)
+ }()
+
+ networks, err := u.srv.db.ListNetworks(context.TODO(), u.ID)
+ if err != nil {
+ u.logger.Printf("failed to list networks for user %q: %v", u.Username, err)
+ return
+ }
+
+ sort.Slice(networks, func(i, j int) bool {
+ return networks[i].ID < networks[j].ID
+ })
+
+ for _, record := range networks {
+ record := record
+ channels, err := u.srv.db.ListChannels(context.TODO(), record.ID)
+ if err != nil {
+ u.logger.Printf("failed to list channels for user %q, network %q: %v", u.Username, record.GetName(), err)
+ continue
+ }
+
+ network := newNetwork(u, &record, channels)
+ u.networks = append(u.networks, network)
+
+ if u.hasPersistentMsgStore() {
+ receipts, err := u.srv.db.ListDeliveryReceipts(context.TODO(), record.ID)
+ if err != nil {
+ u.logger.Printf("failed to load delivery receipts for user %q, network %q: %v", u.Username, network.GetName(), err)
+ return
+ }
+
+ for _, rcpt := range receipts {
+ network.delivered.StoreID(rcpt.Target, rcpt.Client, rcpt.InternalMsgID)
+ }
+ }
+
+ go network.run()
+ }
+
+ for e := range u.events {
+ switch e := e.(type) {
+ case eventUpstreamConnected:
+ uc := e.uc
+
+ uc.network.conn = uc
+
+ uc.updateAway()
+ uc.updateMonitor()
+
+ netIDStr := fmt.Sprintf("%v", uc.network.ID)
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.updateSupportedCaps()
+
+ if !dc.caps["soju.im/bouncer-networks"] {
+ sendServiceNOTICE(dc, fmt.Sprintf("connected to %s", uc.network.GetName()))
+ }
+
+ dc.updateNick()
+ dc.updateRealname()
+ dc.updateAccount()
+ })
+ u.forEachDownstream(func(dc *downstreamConn) {
+ if dc.caps["soju.im/bouncer-networks-notify"] {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "BOUNCER",
+ Params: []string{"NETWORK", netIDStr, "state=connected"},
+ })
+ }
+ })
+ uc.network.lastError = nil
+ case eventUpstreamDisconnected:
+ u.handleUpstreamDisconnected(e.uc)
+ case eventUpstreamConnectionError:
+ net := e.net
+
+ stopped := false
+ select {
+ case <-net.stopped:
+ stopped = true
+ default:
+ }
+
+ if !stopped && (net.lastError == nil || net.lastError.Error() != e.err.Error()) {
+ net.forEachDownstream(func(dc *downstreamConn) {
+ sendServiceNOTICE(dc, fmt.Sprintf("failed connecting/registering to %s: %v", net.GetName(), e.err))
+ })
+ }
+ net.lastError = e.err
+ case eventUpstreamError:
+ uc := e.uc
+
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s: %v", uc.network.GetName(), e.err))
+ })
+ uc.network.lastError = e.err
+ case eventUpstreamMessage:
+ msg, uc := e.msg, e.uc
+ if uc.isClosed() {
+ uc.logger.Printf("ignoring message on closed connection: %v", msg)
+ break
+ }
+ if err := uc.handleMessage(context.TODO(), msg); err != nil {
+ uc.logger.Printf("failed to handle message %q: %v", msg, err)
+ }
+ case eventChannelDetach:
+ uc, name := e.uc, e.name
+ c := uc.network.channels.Value(name)
+ if c == nil || c.Detached {
+ continue
+ }
+ uc.network.detach(c)
+ if err := uc.srv.db.StoreChannel(context.TODO(), uc.network.ID, c); err != nil {
+ u.logger.Printf("failed to store updated detached channel %q: %v", c.Name, err)
+ }
+ case eventDownstreamConnected:
+ dc := e.dc
+
+ if dc.network != nil {
+ dc.monitored.SetCasemapping(dc.network.casemap)
+ }
+
+ if err := dc.welcome(context.TODO()); err != nil {
+ dc.logger.Printf("failed to handle new registered connection: %v", err)
+ break
+ }
+
+ u.downstreamConns = append(u.downstreamConns, dc)
+
+ dc.forEachNetwork(func(network *network) {
+ if network.lastError != nil {
+ sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s: %v", network.GetName(), network.lastError))
+ }
+ })
+
+ u.forEachUpstream(func(uc *upstreamConn) {
+ uc.updateAway()
+ })
+ case eventDownstreamDisconnected:
+ dc := e.dc
+
+ for i := range u.downstreamConns {
+ if u.downstreamConns[i] == dc {
+ u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...)
+ break
+ }
+ }
+
+ dc.forEachNetwork(func(net *network) {
+ net.storeClientDeliveryReceipts(context.TODO(), dc.clientName)
+ })
+
+ u.forEachUpstream(func(uc *upstreamConn) {
+ uc.cancelPendingCommandsByDownstreamID(dc.id)
+ uc.updateAway()
+ uc.updateMonitor()
+ })
+ case eventDownstreamMessage:
+ msg, dc := e.msg, e.dc
+ if dc.isClosed() {
+ dc.logger.Printf("ignoring message on closed connection: %v", msg)
+ break
+ }
+ err := dc.handleMessage(context.TODO(), msg)
+ if ircErr, ok := err.(ircError); ok {
+ ircErr.Message.Prefix = dc.srv.prefix()
+ dc.SendMessage(ircErr.Message)
+ } else if err != nil {
+ dc.logger.Printf("failed to handle message %q: %v", msg, err)
+ dc.Close()
+ }
+ case eventBroadcast:
+ msg := e.msg
+ u.forEachDownstream(func(dc *downstreamConn) {
+ dc.SendMessage(msg)
+ })
+ case eventUserUpdate:
+ // copy the user record because we'll mutate it
+ record := u.User
+
+ if e.password != nil {
+ record.Password = *e.password
+ }
+ if e.admin != nil {
+ record.Admin = *e.admin
+ }
+
+ e.done <- u.updateUser(context.TODO(), &record)
+
+ // If the password was updated, kill all downstream connections to
+ // force them to re-authenticate with the new credentials.
+ if e.password != nil {
+ u.forEachDownstream(func(dc *downstreamConn) {
+ dc.Close()
+ })
+ }
+ case eventStop:
+ u.forEachDownstream(func(dc *downstreamConn) {
+ dc.Close()
+ })
+ for _, n := range u.networks {
+ n.stop()
+
+ n.delivered.ForEachClient(func(clientName string) {
+ n.storeClientDeliveryReceipts(context.TODO(), clientName)
+ })
+ }
+ return
+ default:
+ panic(fmt.Sprintf("received unknown event type: %T", e))
+ }
+ }
+}
+
+func (u *user) handleUpstreamDisconnected(uc *upstreamConn) {
+ uc.network.conn = nil
+
+ uc.abortPendingCommands()
+
+ for _, entry := range uc.channels.innerMap {
+ uch := entry.value.(*upstreamChannel)
+ uch.updateAutoDetach(0)
+ }
+
+ netIDStr := fmt.Sprintf("%v", uc.network.ID)
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.updateSupportedCaps()
+ })
+
+ // If the network has been removed, don't send a state change notification
+ found := false
+ for _, net := range u.networks {
+ if net == uc.network {
+ found = true
+ break
+ }
+ }
+ if !found {
+ return
+ }
+
+ u.forEachDownstream(func(dc *downstreamConn) {
+ if dc.caps["soju.im/bouncer-networks-notify"] {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "BOUNCER",
+ Params: []string{"NETWORK", netIDStr, "state=disconnected"},
+ })
+ }
+ })
+
+ if uc.network.lastError == nil {
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ if !dc.caps["soju.im/bouncer-networks"] {
+ sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s", uc.network.GetName()))
+ }
+ })
+ }
+}
+
+func (u *user) addNetwork(network *network) {
+ u.networks = append(u.networks, network)
+
+ sort.Slice(u.networks, func(i, j int) bool {
+ return u.networks[i].ID < u.networks[j].ID
+ })
+
+ go network.run()
+}
+
+func (u *user) removeNetwork(network *network) {
+ network.stop()
+
+ u.forEachDownstream(func(dc *downstreamConn) {
+ if dc.network != nil && dc.network == network {
+ dc.Close()
+ }
+ })
+
+ for i, net := range u.networks {
+ if net == network {
+ u.networks = append(u.networks[:i], u.networks[i+1:]...)
+ return
+ }
+ }
+
+ panic("tried to remove a non-existing network")
+}
+
+func (u *user) checkNetwork(record *Network) error {
+ url, err := record.URL()
+ if err != nil {
+ return err
+ }
+ if url.User != nil {
+ return fmt.Errorf("%v:// URL must not have username and password information", url.Scheme)
+ }
+ if url.RawQuery != "" {
+ return fmt.Errorf("%v:// URL must not have query values", url.Scheme)
+ }
+ if url.Fragment != "" {
+ return fmt.Errorf("%v:// URL must not have a fragment", url.Scheme)
+ }
+ switch url.Scheme {
+ case "ircs", "irc":
+ if url.Host == "" {
+ return fmt.Errorf("%v:// URL must have a host", url.Scheme)
+ }
+ if url.Path != "" {
+ return fmt.Errorf("%v:// URL must not have a path", url.Scheme)
+ }
+ case "irc+unix", "unix":
+ if url.Host != "" {
+ return fmt.Errorf("%v:// URL must not have a host", url.Scheme)
+ }
+ if url.Path == "" {
+ return fmt.Errorf("%v:// URL must have a path", url.Scheme)
+ }
+ default:
+ return fmt.Errorf("unknown URL scheme %q", url.Scheme)
+ }
+
+ if record.GetName() == "" {
+ return fmt.Errorf("network name cannot be empty")
+ }
+ if strings.HasPrefix(record.GetName(), "-") {
+ // Can be mixed up with flags when sending commands to the service
+ return fmt.Errorf("network name cannot start with a dash character")
+ }
+
+ for _, net := range u.networks {
+ if net.GetName() == record.GetName() && net.ID != record.ID {
+ return fmt.Errorf("a network with the name %q already exists", record.GetName())
+ }
+ }
+
+ return nil
+}
+
+func (u *user) createNetwork(ctx context.Context, record *Network) (*network, error) {
+ if record.ID != 0 {
+ panic("tried creating an already-existing network")
+ }
+
+ if err := u.checkNetwork(record); err != nil {
+ return nil, err
+ }
+
+ if max := u.srv.Config().MaxUserNetworks; max >= 0 && len(u.networks) >= max {
+ return nil, fmt.Errorf("maximum number of networks reached")
+ }
+
+ network := newNetwork(u, record, nil)
+ err := u.srv.db.StoreNetwork(ctx, u.ID, &network.Network)
+ if err != nil {
+ return nil, err
+ }
+
+ u.addNetwork(network)
+
+ idStr := fmt.Sprintf("%v", network.ID)
+ attrs := getNetworkAttrs(network)
+ u.forEachDownstream(func(dc *downstreamConn) {
+ if dc.caps["soju.im/bouncer-networks-notify"] {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "BOUNCER",
+ Params: []string{"NETWORK", idStr, attrs.String()},
+ })
+ }
+ })
+
+ return network, nil
+}
+
+func (u *user) updateNetwork(ctx context.Context, record *Network) (*network, error) {
+ if record.ID == 0 {
+ panic("tried updating a new network")
+ }
+
+ // If the realname is reset to the default, just wipe the per-network
+ // setting
+ if record.Realname == u.Realname {
+ record.Realname = ""
+ }
+
+ if err := u.checkNetwork(record); err != nil {
+ return nil, err
+ }
+
+ network := u.getNetworkByID(record.ID)
+ if network == nil {
+ panic("tried updating a non-existing network")
+ }
+
+ if err := u.srv.db.StoreNetwork(ctx, u.ID, record); err != nil {
+ return nil, err
+ }
+
+ // Most network changes require us to re-connect to the upstream server
+
+ channels := make([]Channel, 0, network.channels.Len())
+ for _, entry := range network.channels.innerMap {
+ ch := entry.value.(*Channel)
+ channels = append(channels, *ch)
+ }
+
+ updatedNetwork := newNetwork(u, record, channels)
+
+ // If we're currently connected, disconnect and perform the necessary
+ // bookkeeping
+ if network.conn != nil {
+ network.stop()
+ // Note: this will set network.conn to nil
+ u.handleUpstreamDisconnected(network.conn)
+ }
+
+ // Patch downstream connections to use our fresh updated network
+ u.forEachDownstream(func(dc *downstreamConn) {
+ if dc.network != nil && dc.network == network {
+ dc.network = updatedNetwork
+ }
+ })
+
+ // We need to remove the network after patching downstream connections,
+ // otherwise they'll get closed
+ u.removeNetwork(network)
+
+ // The filesystem message store needs to be notified whenever the network
+ // is renamed
+ fsMsgStore, isFS := u.msgStore.(*fsMessageStore)
+ if isFS && updatedNetwork.GetName() != network.GetName() {
+ if err := fsMsgStore.RenameNetwork(&network.Network, &updatedNetwork.Network); err != nil {
+ network.logger.Printf("failed to update FS message store network name to %q: %v", updatedNetwork.GetName(), err)
+ }
+ }
+
+ // This will re-connect to the upstream server
+ u.addNetwork(updatedNetwork)
+
+ // TODO: only broadcast attributes that have changed
+ idStr := fmt.Sprintf("%v", updatedNetwork.ID)
+ attrs := getNetworkAttrs(updatedNetwork)
+ u.forEachDownstream(func(dc *downstreamConn) {
+ if dc.caps["soju.im/bouncer-networks-notify"] {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "BOUNCER",
+ Params: []string{"NETWORK", idStr, attrs.String()},
+ })
+ }
+ })
+
+ return updatedNetwork, nil
+}
+
+func (u *user) deleteNetwork(ctx context.Context, id int64) error {
+ network := u.getNetworkByID(id)
+ if network == nil {
+ panic("tried deleting a non-existing network")
+ }
+
+ if err := u.srv.db.DeleteNetwork(ctx, network.ID); err != nil {
+ return err
+ }
+
+ u.removeNetwork(network)
+
+ idStr := fmt.Sprintf("%v", network.ID)
+ u.forEachDownstream(func(dc *downstreamConn) {
+ if dc.caps["soju.im/bouncer-networks-notify"] {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "BOUNCER",
+ Params: []string{"NETWORK", idStr, "*"},
+ })
+ }
+ })
+
+ return nil
+}
+
+func (u *user) updateUser(ctx context.Context, record *User) error {
+ if u.ID != record.ID {
+ panic("ID mismatch when updating user")
+ }
+
+ realnameUpdated := u.Realname != record.Realname
+ if err := u.srv.db.StoreUser(ctx, record); err != nil {
+ return fmt.Errorf("failed to update user %q: %v", u.Username, err)
+ }
+ u.User = *record
+
+ if realnameUpdated {
+ // Re-connect to networks which use the default realname
+ var needUpdate []Network
+ for _, net := range u.networks {
+ if net.Realname == "" {
+ needUpdate = append(needUpdate, net.Network)
+ }
+ }
+
+ var netErr error
+ for _, net := range needUpdate {
+ if _, err := u.updateNetwork(ctx, &net); err != nil {
+ netErr = err
+ }
+ }
+ if netErr != nil {
+ return netErr
+ }
+ }
+
+ return nil
+}
+
+func (u *user) stop() {
+ u.events <- eventStop{}
+ <-u.done
+}
+
+func (u *user) hasPersistentMsgStore() bool {
+ if u.msgStore == nil {
+ return false
+ }
+ _, isMem := u.msgStore.(*memoryMessageStore)
+ return !isMem
+}
+
+// localAddrForHost returns the local address to use when connecting to host.
+// A nil address is returned when the OS should automatically pick one.
+func (u *user) localTCPAddrForHost(ctx context.Context, host string) (*net.TCPAddr, error) {
+ upstreamUserIPs := u.srv.Config().UpstreamUserIPs
+ if len(upstreamUserIPs) == 0 {
+ return nil, nil
+ }
+
+ ips, err := net.DefaultResolver.LookupIP(ctx, "ip", host)
+ if err != nil {
+ return nil, err
+ }
+
+ wantIPv6 := false
+ for _, ip := range ips {
+ if ip.To4() == nil {
+ wantIPv6 = true
+ break
+ }
+ }
+
+ var ipNet *net.IPNet
+ for _, in := range upstreamUserIPs {
+ if wantIPv6 == (in.IP.To4() == nil) {
+ ipNet = in
+ break
+ }
+ }
+ if ipNet == nil {
+ return nil, nil
+ }
+
+ var ipInt big.Int
+ ipInt.SetBytes(ipNet.IP)
+ ipInt.Add(&ipInt, big.NewInt(u.ID+1))
+ ip := net.IP(ipInt.Bytes())
+ if !ipNet.Contains(ip) {
+ return nil, fmt.Errorf("IP network %v too small", ipNet)
+ }
+
+ return &net.TCPAddr{IP: ip}, nil
+}
--- /dev/null
+package suika
+
+import (
+ "fmt"
+ "runtime/debug"
+ "strings"
+)
+
+const (
+ defaultVersion = "0.0.0"
+ defaultCommit = "HEAD"
+ defaultBuild = "0000-01-01:00:00+00:00"
+)
+
+var (
+ // Version is the tagged release version in the form <major>.<minor>.<patch>
+ // following semantic versioning and is overwritten by the build system.
+ Version = defaultVersion
+
+ // Commit is the commit sha of the build (normally from Git) and is overwritten
+ // by the build system.
+ Commit = defaultCommit
+
+ // Build is the date and time of the build as an RFC3339 formatted string
+ // and is overwritten by the build system.
+ Build = defaultBuild
+)
+
+// FullVersion display the full version and build
+func FullVersion() string {
+ var sb strings.Builder
+
+ isDefault := Version == defaultVersion && Commit == defaultCommit && Build == defaultBuild
+
+ if !isDefault {
+ sb.WriteString(fmt.Sprintf("%s@%s %s", Version, Commit, Build))
+ }
+
+ if info, ok := debug.ReadBuildInfo(); ok {
+ if isDefault {
+ sb.WriteString(fmt.Sprintf(" %s", info.Main.Version))
+ }
+ sb.WriteString(fmt.Sprintf(" %s", info.GoVersion))
+ if info.Main.Sum != "" {
+ sb.WriteString(fmt.Sprintf(" %s", info.Main.Sum))
+ }
+ }
+
+ return sb.String()
+}
--- /dev/null
+vendor
+/suika
+/suikadb
+/suika-znc-import
+/suika.db
--- /dev/null
+ GNU AFFERO GENERAL PUBLIC LICENSE
+ Version 3, 19 November 2007
+
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+ Preamble
+
+ The GNU Affero General Public License is a free, copyleft license for
+software and other kinds of works, specifically designed to ensure
+cooperation with the community in the case of network server software.
+
+ The licenses for most software and other practical works are designed
+to take away your freedom to share and change the works. By contrast,
+our General Public Licenses are intended to guarantee your freedom to
+share and change all versions of a program--to make sure it remains free
+software for all its users.
+
+ When we speak of free software, we are referring to freedom, not
+price. Our General Public Licenses are designed to make sure that you
+have the freedom to distribute copies of free software (and charge for
+them if you wish), that you receive source code or can get it if you
+want it, that you can change the software or use pieces of it in new
+free programs, and that you know you can do these things.
+
+ Developers that use our General Public Licenses protect your rights
+with two steps: (1) assert copyright on the software, and (2) offer
+you this License which gives you legal permission to copy, distribute
+and/or modify the software.
+
+ A secondary benefit of defending all users' freedom is that
+improvements made in alternate versions of the program, if they
+receive widespread use, become available for other developers to
+incorporate. Many developers of free software are heartened and
+encouraged by the resulting cooperation. However, in the case of
+software used on network servers, this result may fail to come about.
+The GNU General Public License permits making a modified version and
+letting the public access it on a server without ever releasing its
+source code to the public.
+
+ The GNU Affero General Public License is designed specifically to
+ensure that, in such cases, the modified source code becomes available
+to the community. It requires the operator of a network server to
+provide the source code of the modified version running there to the
+users of that server. Therefore, public use of a modified version, on
+a publicly accessible server, gives the public access to the source
+code of the modified version.
+
+ An older license, called the Affero General Public License and
+published by Affero, was designed to accomplish similar goals. This is
+a different license, not a version of the Affero GPL, but Affero has
+released a new version of the Affero GPL which permits relicensing under
+this license.
+
+ The precise terms and conditions for copying, distribution and
+modification follow.
+
+ TERMS AND CONDITIONS
+
+ 0. Definitions.
+
+ "This License" refers to version 3 of the GNU Affero General Public License.
+
+ "Copyright" also means copyright-like laws that apply to other kinds of
+works, such as semiconductor masks.
+
+ "The Program" refers to any copyrightable work licensed under this
+License. Each licensee is addressed as "you". "Licensees" and
+"recipients" may be individuals or organizations.
+
+ To "modify" a work means to copy from or adapt all or part of the work
+in a fashion requiring copyright permission, other than the making of an
+exact copy. The resulting work is called a "modified version" of the
+earlier work or a work "based on" the earlier work.
+
+ A "covered work" means either the unmodified Program or a work based
+on the Program.
+
+ To "propagate" a work means to do anything with it that, without
+permission, would make you directly or secondarily liable for
+infringement under applicable copyright law, except executing it on a
+computer or modifying a private copy. Propagation includes copying,
+distribution (with or without modification), making available to the
+public, and in some countries other activities as well.
+
+ To "convey" a work means any kind of propagation that enables other
+parties to make or receive copies. Mere interaction with a user through
+a computer network, with no transfer of a copy, is not conveying.
+
+ An interactive user interface displays "Appropriate Legal Notices"
+to the extent that it includes a convenient and prominently visible
+feature that (1) displays an appropriate copyright notice, and (2)
+tells the user that there is no warranty for the work (except to the
+extent that warranties are provided), that licensees may convey the
+work under this License, and how to view a copy of this License. If
+the interface presents a list of user commands or options, such as a
+menu, a prominent item in the list meets this criterion.
+
+ 1. Source Code.
+
+ The "source code" for a work means the preferred form of the work
+for making modifications to it. "Object code" means any non-source
+form of a work.
+
+ A "Standard Interface" means an interface that either is an official
+standard defined by a recognized standards body, or, in the case of
+interfaces specified for a particular programming language, one that
+is widely used among developers working in that language.
+
+ The "System Libraries" of an executable work include anything, other
+than the work as a whole, that (a) is included in the normal form of
+packaging a Major Component, but which is not part of that Major
+Component, and (b) serves only to enable use of the work with that
+Major Component, or to implement a Standard Interface for which an
+implementation is available to the public in source code form. A
+"Major Component", in this context, means a major essential component
+(kernel, window system, and so on) of the specific operating system
+(if any) on which the executable work runs, or a compiler used to
+produce the work, or an object code interpreter used to run it.
+
+ The "Corresponding Source" for a work in object code form means all
+the source code needed to generate, install, and (for an executable
+work) run the object code and to modify the work, including scripts to
+control those activities. However, it does not include the work's
+System Libraries, or general-purpose tools or generally available free
+programs which are used unmodified in performing those activities but
+which are not part of the work. For example, Corresponding Source
+includes interface definition files associated with source files for
+the work, and the source code for shared libraries and dynamically
+linked subprograms that the work is specifically designed to require,
+such as by intimate data communication or control flow between those
+subprograms and other parts of the work.
+
+ The Corresponding Source need not include anything that users
+can regenerate automatically from other parts of the Corresponding
+Source.
+
+ The Corresponding Source for a work in source code form is that
+same work.
+
+ 2. Basic Permissions.
+
+ All rights granted under this License are granted for the term of
+copyright on the Program, and are irrevocable provided the stated
+conditions are met. This License explicitly affirms your unlimited
+permission to run the unmodified Program. The output from running a
+covered work is covered by this License only if the output, given its
+content, constitutes a covered work. This License acknowledges your
+rights of fair use or other equivalent, as provided by copyright law.
+
+ You may make, run and propagate covered works that you do not
+convey, without conditions so long as your license otherwise remains
+in force. You may convey covered works to others for the sole purpose
+of having them make modifications exclusively for you, or provide you
+with facilities for running those works, provided that you comply with
+the terms of this License in conveying all material for which you do
+not control copyright. Those thus making or running the covered works
+for you must do so exclusively on your behalf, under your direction
+and control, on terms that prohibit them from making any copies of
+your copyrighted material outside their relationship with you.
+
+ Conveying under any other circumstances is permitted solely under
+the conditions stated below. Sublicensing is not allowed; section 10
+makes it unnecessary.
+
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
+
+ No covered work shall be deemed part of an effective technological
+measure under any applicable law fulfilling obligations under article
+11 of the WIPO copyright treaty adopted on 20 December 1996, or
+similar laws prohibiting or restricting circumvention of such
+measures.
+
+ When you convey a covered work, you waive any legal power to forbid
+circumvention of technological measures to the extent such circumvention
+is effected by exercising rights under this License with respect to
+the covered work, and you disclaim any intention to limit operation or
+modification of the work as a means of enforcing, against the work's
+users, your or third parties' legal rights to forbid circumvention of
+technological measures.
+
+ 4. Conveying Verbatim Copies.
+
+ You may convey verbatim copies of the Program's source code as you
+receive it, in any medium, provided that you conspicuously and
+appropriately publish on each copy an appropriate copyright notice;
+keep intact all notices stating that this License and any
+non-permissive terms added in accord with section 7 apply to the code;
+keep intact all notices of the absence of any warranty; and give all
+recipients a copy of this License along with the Program.
+
+ You may charge any price or no price for each copy that you convey,
+and you may offer support or warranty protection for a fee.
+
+ 5. Conveying Modified Source Versions.
+
+ You may convey a work based on the Program, or the modifications to
+produce it from the Program, in the form of source code under the
+terms of section 4, provided that you also meet all of these conditions:
+
+ a) The work must carry prominent notices stating that you modified
+ it, and giving a relevant date.
+
+ b) The work must carry prominent notices stating that it is
+ released under this License and any conditions added under section
+ 7. This requirement modifies the requirement in section 4 to
+ "keep intact all notices".
+
+ c) You must license the entire work, as a whole, under this
+ License to anyone who comes into possession of a copy. This
+ License will therefore apply, along with any applicable section 7
+ additional terms, to the whole of the work, and all its parts,
+ regardless of how they are packaged. This License gives no
+ permission to license the work in any other way, but it does not
+ invalidate such permission if you have separately received it.
+
+ d) If the work has interactive user interfaces, each must display
+ Appropriate Legal Notices; however, if the Program has interactive
+ interfaces that do not display Appropriate Legal Notices, your
+ work need not make them do so.
+
+ A compilation of a covered work with other separate and independent
+works, which are not by their nature extensions of the covered work,
+and which are not combined with it such as to form a larger program,
+in or on a volume of a storage or distribution medium, is called an
+"aggregate" if the compilation and its resulting copyright are not
+used to limit the access or legal rights of the compilation's users
+beyond what the individual works permit. Inclusion of a covered work
+in an aggregate does not cause this License to apply to the other
+parts of the aggregate.
+
+ 6. Conveying Non-Source Forms.
+
+ You may convey a covered work in object code form under the terms
+of sections 4 and 5, provided that you also convey the
+machine-readable Corresponding Source under the terms of this License,
+in one of these ways:
+
+ a) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by the
+ Corresponding Source fixed on a durable physical medium
+ customarily used for software interchange.
+
+ b) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by a
+ written offer, valid for at least three years and valid for as
+ long as you offer spare parts or customer support for that product
+ model, to give anyone who possesses the object code either (1) a
+ copy of the Corresponding Source for all the software in the
+ product that is covered by this License, on a durable physical
+ medium customarily used for software interchange, for a price no
+ more than your reasonable cost of physically performing this
+ conveying of source, or (2) access to copy the
+ Corresponding Source from a network server at no charge.
+
+ c) Convey individual copies of the object code with a copy of the
+ written offer to provide the Corresponding Source. This
+ alternative is allowed only occasionally and noncommercially, and
+ only if you received the object code with such an offer, in accord
+ with subsection 6b.
+
+ d) Convey the object code by offering access from a designated
+ place (gratis or for a charge), and offer equivalent access to the
+ Corresponding Source in the same way through the same place at no
+ further charge. You need not require recipients to copy the
+ Corresponding Source along with the object code. If the place to
+ copy the object code is a network server, the Corresponding Source
+ may be on a different server (operated by you or a third party)
+ that supports equivalent copying facilities, provided you maintain
+ clear directions next to the object code saying where to find the
+ Corresponding Source. Regardless of what server hosts the
+ Corresponding Source, you remain obligated to ensure that it is
+ available for as long as needed to satisfy these requirements.
+
+ e) Convey the object code using peer-to-peer transmission, provided
+ you inform other peers where the object code and Corresponding
+ Source of the work are being offered to the general public at no
+ charge under subsection 6d.
+
+ A separable portion of the object code, whose source code is excluded
+from the Corresponding Source as a System Library, need not be
+included in conveying the object code work.
+
+ A "User Product" is either (1) a "consumer product", which means any
+tangible personal property which is normally used for personal, family,
+or household purposes, or (2) anything designed or sold for incorporation
+into a dwelling. In determining whether a product is a consumer product,
+doubtful cases shall be resolved in favor of coverage. For a particular
+product received by a particular user, "normally used" refers to a
+typical or common use of that class of product, regardless of the status
+of the particular user or of the way in which the particular user
+actually uses, or expects or is expected to use, the product. A product
+is a consumer product regardless of whether the product has substantial
+commercial, industrial or non-consumer uses, unless such uses represent
+the only significant mode of use of the product.
+
+ "Installation Information" for a User Product means any methods,
+procedures, authorization keys, or other information required to install
+and execute modified versions of a covered work in that User Product from
+a modified version of its Corresponding Source. The information must
+suffice to ensure that the continued functioning of the modified object
+code is in no case prevented or interfered with solely because
+modification has been made.
+
+ If you convey an object code work under this section in, or with, or
+specifically for use in, a User Product, and the conveying occurs as
+part of a transaction in which the right of possession and use of the
+User Product is transferred to the recipient in perpetuity or for a
+fixed term (regardless of how the transaction is characterized), the
+Corresponding Source conveyed under this section must be accompanied
+by the Installation Information. But this requirement does not apply
+if neither you nor any third party retains the ability to install
+modified object code on the User Product (for example, the work has
+been installed in ROM).
+
+ The requirement to provide Installation Information does not include a
+requirement to continue to provide support service, warranty, or updates
+for a work that has been modified or installed by the recipient, or for
+the User Product in which it has been modified or installed. Access to a
+network may be denied when the modification itself materially and
+adversely affects the operation of the network or violates the rules and
+protocols for communication across the network.
+
+ Corresponding Source conveyed, and Installation Information provided,
+in accord with this section must be in a format that is publicly
+documented (and with an implementation available to the public in
+source code form), and must require no special password or key for
+unpacking, reading or copying.
+
+ 7. Additional Terms.
+
+ "Additional permissions" are terms that supplement the terms of this
+License by making exceptions from one or more of its conditions.
+Additional permissions that are applicable to the entire Program shall
+be treated as though they were included in this License, to the extent
+that they are valid under applicable law. If additional permissions
+apply only to part of the Program, that part may be used separately
+under those permissions, but the entire Program remains governed by
+this License without regard to the additional permissions.
+
+ When you convey a copy of a covered work, you may at your option
+remove any additional permissions from that copy, or from any part of
+it. (Additional permissions may be written to require their own
+removal in certain cases when you modify the work.) You may place
+additional permissions on material, added by you to a covered work,
+for which you have or can give appropriate copyright permission.
+
+ Notwithstanding any other provision of this License, for material you
+add to a covered work, you may (if authorized by the copyright holders of
+that material) supplement the terms of this License with terms:
+
+ a) Disclaiming warranty or limiting liability differently from the
+ terms of sections 15 and 16 of this License; or
+
+ b) Requiring preservation of specified reasonable legal notices or
+ author attributions in that material or in the Appropriate Legal
+ Notices displayed by works containing it; or
+
+ c) Prohibiting misrepresentation of the origin of that material, or
+ requiring that modified versions of such material be marked in
+ reasonable ways as different from the original version; or
+
+ d) Limiting the use for publicity purposes of names of licensors or
+ authors of the material; or
+
+ e) Declining to grant rights under trademark law for use of some
+ trade names, trademarks, or service marks; or
+
+ f) Requiring indemnification of licensors and authors of that
+ material by anyone who conveys the material (or modified versions of
+ it) with contractual assumptions of liability to the recipient, for
+ any liability that these contractual assumptions directly impose on
+ those licensors and authors.
+
+ All other non-permissive additional terms are considered "further
+restrictions" within the meaning of section 10. If the Program as you
+received it, or any part of it, contains a notice stating that it is
+governed by this License along with a term that is a further
+restriction, you may remove that term. If a license document contains
+a further restriction but permits relicensing or conveying under this
+License, you may add to a covered work material governed by the terms
+of that license document, provided that the further restriction does
+not survive such relicensing or conveying.
+
+ If you add terms to a covered work in accord with this section, you
+must place, in the relevant source files, a statement of the
+additional terms that apply to those files, or a notice indicating
+where to find the applicable terms.
+
+ Additional terms, permissive or non-permissive, may be stated in the
+form of a separately written license, or stated as exceptions;
+the above requirements apply either way.
+
+ 8. Termination.
+
+ You may not propagate or modify a covered work except as expressly
+provided under this License. Any attempt otherwise to propagate or
+modify it is void, and will automatically terminate your rights under
+this License (including any patent licenses granted under the third
+paragraph of section 11).
+
+ However, if you cease all violation of this License, then your
+license from a particular copyright holder is reinstated (a)
+provisionally, unless and until the copyright holder explicitly and
+finally terminates your license, and (b) permanently, if the copyright
+holder fails to notify you of the violation by some reasonable means
+prior to 60 days after the cessation.
+
+ Moreover, your license from a particular copyright holder is
+reinstated permanently if the copyright holder notifies you of the
+violation by some reasonable means, this is the first time you have
+received notice of violation of this License (for any work) from that
+copyright holder, and you cure the violation prior to 30 days after
+your receipt of the notice.
+
+ Termination of your rights under this section does not terminate the
+licenses of parties who have received copies or rights from you under
+this License. If your rights have been terminated and not permanently
+reinstated, you do not qualify to receive new licenses for the same
+material under section 10.
+
+ 9. Acceptance Not Required for Having Copies.
+
+ You are not required to accept this License in order to receive or
+run a copy of the Program. Ancillary propagation of a covered work
+occurring solely as a consequence of using peer-to-peer transmission
+to receive a copy likewise does not require acceptance. However,
+nothing other than this License grants you permission to propagate or
+modify any covered work. These actions infringe copyright if you do
+not accept this License. Therefore, by modifying or propagating a
+covered work, you indicate your acceptance of this License to do so.
+
+ 10. Automatic Licensing of Downstream Recipients.
+
+ Each time you convey a covered work, the recipient automatically
+receives a license from the original licensors, to run, modify and
+propagate that work, subject to this License. You are not responsible
+for enforcing compliance by third parties with this License.
+
+ An "entity transaction" is a transaction transferring control of an
+organization, or substantially all assets of one, or subdividing an
+organization, or merging organizations. If propagation of a covered
+work results from an entity transaction, each party to that
+transaction who receives a copy of the work also receives whatever
+licenses to the work the party's predecessor in interest had or could
+give under the previous paragraph, plus a right to possession of the
+Corresponding Source of the work from the predecessor in interest, if
+the predecessor has it or can get it with reasonable efforts.
+
+ You may not impose any further restrictions on the exercise of the
+rights granted or affirmed under this License. For example, you may
+not impose a license fee, royalty, or other charge for exercise of
+rights granted under this License, and you may not initiate litigation
+(including a cross-claim or counterclaim in a lawsuit) alleging that
+any patent claim is infringed by making, using, selling, offering for
+sale, or importing the Program or any portion of it.
+
+ 11. Patents.
+
+ A "contributor" is a copyright holder who authorizes use under this
+License of the Program or a work on which the Program is based. The
+work thus licensed is called the contributor's "contributor version".
+
+ A contributor's "essential patent claims" are all patent claims
+owned or controlled by the contributor, whether already acquired or
+hereafter acquired, that would be infringed by some manner, permitted
+by this License, of making, using, or selling its contributor version,
+but do not include claims that would be infringed only as a
+consequence of further modification of the contributor version. For
+purposes of this definition, "control" includes the right to grant
+patent sublicenses in a manner consistent with the requirements of
+this License.
+
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
+patent license under the contributor's essential patent claims, to
+make, use, sell, offer for sale, import and otherwise run, modify and
+propagate the contents of its contributor version.
+
+ In the following three paragraphs, a "patent license" is any express
+agreement or commitment, however denominated, not to enforce a patent
+(such as an express permission to practice a patent or covenant not to
+sue for patent infringement). To "grant" such a patent license to a
+party means to make such an agreement or commitment not to enforce a
+patent against the party.
+
+ If you convey a covered work, knowingly relying on a patent license,
+and the Corresponding Source of the work is not available for anyone
+to copy, free of charge and under the terms of this License, through a
+publicly available network server or other readily accessible means,
+then you must either (1) cause the Corresponding Source to be so
+available, or (2) arrange to deprive yourself of the benefit of the
+patent license for this particular work, or (3) arrange, in a manner
+consistent with the requirements of this License, to extend the patent
+license to downstream recipients. "Knowingly relying" means you have
+actual knowledge that, but for the patent license, your conveying the
+covered work in a country, or your recipient's use of the covered work
+in a country, would infringe one or more identifiable patents in that
+country that you have reason to believe are valid.
+
+ If, pursuant to or in connection with a single transaction or
+arrangement, you convey, or propagate by procuring conveyance of, a
+covered work, and grant a patent license to some of the parties
+receiving the covered work authorizing them to use, propagate, modify
+or convey a specific copy of the covered work, then the patent license
+you grant is automatically extended to all recipients of the covered
+work and works based on it.
+
+ A patent license is "discriminatory" if it does not include within
+the scope of its coverage, prohibits the exercise of, or is
+conditioned on the non-exercise of one or more of the rights that are
+specifically granted under this License. You may not convey a covered
+work if you are a party to an arrangement with a third party that is
+in the business of distributing software, under which you make payment
+to the third party based on the extent of your activity of conveying
+the work, and under which the third party grants, to any of the
+parties who would receive the covered work from you, a discriminatory
+patent license (a) in connection with copies of the covered work
+conveyed by you (or copies made from those copies), or (b) primarily
+for and in connection with specific products or compilations that
+contain the covered work, unless you entered into that arrangement,
+or that patent license was granted, prior to 28 March 2007.
+
+ Nothing in this License shall be construed as excluding or limiting
+any implied license or other defenses to infringement that may
+otherwise be available to you under applicable patent law.
+
+ 12. No Surrender of Others' Freedom.
+
+ If conditions are imposed on you (whether by court order, agreement or
+otherwise) that contradict the conditions of this License, they do not
+excuse you from the conditions of this License. If you cannot convey a
+covered work so as to satisfy simultaneously your obligations under this
+License and any other pertinent obligations, then as a consequence you may
+not convey it at all. For example, if you agree to terms that obligate you
+to collect a royalty for further conveying from those to whom you convey
+the Program, the only way you could satisfy both those terms and this
+License would be to refrain entirely from conveying the Program.
+
+ 13. Remote Network Interaction; Use with the GNU General Public License.
+
+ Notwithstanding any other provision of this License, if you modify the
+Program, your modified version must prominently offer all users
+interacting with it remotely through a computer network (if your version
+supports such interaction) an opportunity to receive the Corresponding
+Source of your version by providing access to the Corresponding Source
+from a network server at no charge, through some standard or customary
+means of facilitating copying of software. This Corresponding Source
+shall include the Corresponding Source for any work covered by version 3
+of the GNU General Public License that is incorporated pursuant to the
+following paragraph.
+
+ Notwithstanding any other provision of this License, you have
+permission to link or combine any covered work with a work licensed
+under version 3 of the GNU General Public License into a single
+combined work, and to convey the resulting work. The terms of this
+License will continue to apply to the part which is the covered work,
+but the work with which it is combined will remain governed by version
+3 of the GNU General Public License.
+
+ 14. Revised Versions of this License.
+
+ The Free Software Foundation may publish revised and/or new versions of
+the GNU Affero General Public License from time to time. Such new versions
+will be similar in spirit to the present version, but may differ in detail to
+address new problems or concerns.
+
+ Each version is given a distinguishing version number. If the
+Program specifies that a certain numbered version of the GNU Affero General
+Public License "or any later version" applies to it, you have the
+option of following the terms and conditions either of that numbered
+version or of any later version published by the Free Software
+Foundation. If the Program does not specify a version number of the
+GNU Affero General Public License, you may choose any version ever published
+by the Free Software Foundation.
+
+ If the Program specifies that a proxy can decide which future
+versions of the GNU Affero General Public License can be used, that proxy's
+public statement of acceptance of a version permanently authorizes you
+to choose that version for the Program.
+
+ Later license versions may give you additional or different
+permissions. However, no additional obligations are imposed on any
+author or copyright holder as a result of your choosing to follow a
+later version.
+
+ 15. Disclaimer of Warranty.
+
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
+APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
+HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
+OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
+THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
+IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
+ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
+
+ 16. Limitation of Liability.
+
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
+WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
+THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
+GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
+USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
+DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
+PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
+EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
+SUCH DAMAGES.
+
+ 17. Interpretation of Sections 15 and 16.
+
+ If the disclaimer of warranty and limitation of liability provided
+above cannot be given local legal effect according to their terms,
+reviewing courts shall apply local law that most closely approximates
+an absolute waiver of all civil liability in connection with the
+Program, unless a warranty or assumption of liability accompanies a
+copy of the Program in return for a fee.
+
+ END OF TERMS AND CONDITIONS
+
+ How to Apply These Terms to Your New Programs
+
+ If you develop a new program, and you want it to be of the greatest
+possible use to the public, the best way to achieve this is to make it
+free software which everyone can redistribute and change under these terms.
+
+ To do so, attach the following notices to the program. It is safest
+to attach them to the start of each source file to most effectively
+state the exclusion of warranty; and each file should have at least
+the "copyright" line and a pointer to where the full notice is found.
+
+ <one line to give the program's name and a brief idea of what it does.>
+ Copyright (C) <year> <name of author>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published
+ by the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+Also add information on how to contact you by electronic and paper mail.
+
+ If your software can interact with users remotely through a computer
+network, you should also make sure that it provides a way for users to
+get its source. For example, if your program is a web application, its
+interface could display a "Source" link that leads users to an archive
+of the code. There are many ways you could offer source, and different
+solutions will be better for different programs; see section 13 for the
+specific requirements.
+
+ You should also get your employer (if you work as a programmer) or school,
+if any, to sign a "copyright disclaimer" for the program, if necessary.
+For more information on this, and how to apply and follow the GNU AGPL, see
+<https://www.gnu.org/licenses/>.
--- /dev/null
+GO ?= go
+RM ?= rm
+GOFLAGS ?= -v -ldflags "-w -X `go list`.Version=${VERSION} -X `go list`.Commit=${COMMIT} -X `go list`.Build=${BUILD}" -mod=vendor
+PREFIX ?= /usr/local
+BINDIR ?= bin
+MANDIR ?= share/man
+MKDIR ?= mkdir
+CP ?= cp
+SYSCONFDIR ?= /etc
+ASCIIDOCTOR ?= asciidoctor
+
+VERSION = `git describe --abbrev=0 --tags 2>/dev/null || echo "$VERSION"`
+COMMIT = `git rev-parse --short HEAD || echo "$COMMIT"`
+BRANCH = `git rev-parse --abbrev-ref HEAD`
+BUILD = `git show -s --pretty=format:%cI`
+
+GOARCH ?= amd64
+GOOS ?= linux
+
+all: build
+
+build: vendor
+ ${GO} build ${GOFLAGS} ./cmd/suika
+ ${GO} build ${GOFLAGS} ./cmd/suikadb
+ ${GO} build ${GOFLAGS} ./cmd/suika-znc-import
+clean:
+ ${RM} -f suika suikadb suika-znc-import
+install:
+ ${MKDIR} -p ${DESTDIR}${PREFIX}/${BINDIR}
+ ${MKDIR} -p ${DESTDIR}${PREFIX}/${MANDIR}/man1
+ ${MKDIR} -p ${DESTDIR}${PREFIX}/${MANDIR}/man5
+ ${MKDIR} -p ${DESTDIR}${PREFIX}/${MANDIR}/man7
+ ${MKDIR} -p ${DESTDIR}${SYSCONFDIR}/suika
+ ${MKDIR} -p ${DESTDIR}/var/lib/suika
+ ${CP} -f suika suikadb suika-znc-import ${DESTDIR}${PREFIX}/${BINDIR}
+ ${CP} -f doc/suika.1 ${DESTDIR}${PREFIX}/${MANDIR}/man1
+ ${CP} -f doc/suikadb.1 ${DESTDIR}${PREFIX}/${MANDIR}/man1
+ ${CP} -f doc/suika-znc-import.1 ${DESTDIR}/${MANDIR}/man1
+ ${CP} -f doc/suika-config.5 ${DESTDIR}${PREFIX}/${MANDIR}/man5
+ [ -f ${DESTDIR}${SYSCONFDIR}/suika/config ] || ${CP} -f config.in ${DESTDIR}${SYSCONFDIR}/suika/config
+test:
+ go test
+vendor:
+ go mod vendor
+.PHONY: build clean install
--- /dev/null
+# suika
+
+[![Go Documentation](https://godocs.io/marisa.chaotic.ninja/suika?status.svg)](https://godocs.io/marisa.chaotic.ninja/suika)
+
+A user-friendly IRC bouncer. Hard-fork of the 0.3 series of [soju](https://soju.im), named after [Suika Ibuki](https://en.touhouwiki.net/wiki/Suika_Ibuki) from [Touhou 7.5: Immaterial and Missing Power](https://en.touhouwiki.net/wiki/Immaterial_and_Missing_Power)
+
+- Multi-user
+- Support multiple clients for a single user, with proper backlog
+ synchronization
+- Support connecting to multiple upstream servers via a single IRC connection
+ to the bouncer
+
+## Building and installing
+
+Dependencies:
+
+- Go
+- BSD or GNU make
+
+For end users, a `Makefile` is provided:
+
+ make
+ doas make install
+
+For development, you can use `go run ./cmd/suika` as usual.
+
+## License
+AGPLv3, see [LICENSE](LICENSE).
+
+* Copyright (C) 2020 The soju Contributors
+* Copyright (C) 2023-present Izuru Yakumo
+
+The code for `version.go` is stolen verbatim from one of [@prologic](https://git.mills.io/prologic)'s projects. It's probably under MIT
--- /dev/null
+package suika
+
+import (
+ "context"
+ "fmt"
+ "strconv"
+ "strings"
+
+ "gopkg.in/irc.v3"
+)
+
+func forwardChannel(ctx context.Context, dc *downstreamConn, ch *upstreamChannel) {
+ if !ch.complete {
+ panic("Tried to forward a partial channel")
+ }
+
+ // RPL_NOTOPIC shouldn't be sent on JOIN
+ if ch.Topic != "" {
+ sendTopic(dc, ch)
+ }
+
+ if dc.caps["soju.im/read"] {
+ channelCM := ch.conn.network.casemap(ch.Name)
+ r, err := dc.srv.db.GetReadReceipt(ctx, ch.conn.network.ID, channelCM)
+ if err != nil {
+ dc.logger.Printf("failed to get the read receipt for %q: %v", ch.Name, err)
+ } else {
+ timestampStr := "*"
+ if r != nil {
+ timestampStr = fmt.Sprintf("timestamp=%s", formatServerTime(r.Timestamp))
+ }
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.prefix(),
+ Command: "READ",
+ Params: []string{dc.marshalEntity(ch.conn.network, ch.Name), timestampStr},
+ })
+ }
+ }
+
+ sendNames(dc, ch)
+}
+
+func sendTopic(dc *downstreamConn, ch *upstreamChannel) {
+ downstreamName := dc.marshalEntity(ch.conn.network, ch.Name)
+
+ if ch.Topic != "" {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_TOPIC,
+ Params: []string{dc.nick, downstreamName, ch.Topic},
+ })
+ if ch.TopicWho != nil {
+ topicWho := dc.marshalUserPrefix(ch.conn.network, ch.TopicWho)
+ topicTime := strconv.FormatInt(ch.TopicTime.Unix(), 10)
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: rpl_topicwhotime,
+ Params: []string{dc.nick, downstreamName, topicWho.String(), topicTime},
+ })
+ }
+ } else {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_NOTOPIC,
+ Params: []string{dc.nick, downstreamName, "No topic is set"},
+ })
+ }
+}
+
+func sendNames(dc *downstreamConn, ch *upstreamChannel) {
+ downstreamName := dc.marshalEntity(ch.conn.network, ch.Name)
+
+ emptyNameReply := &irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_NAMREPLY,
+ Params: []string{dc.nick, string(ch.Status), downstreamName, ""},
+ }
+ maxLength := maxMessageLength - len(emptyNameReply.String())
+
+ var buf strings.Builder
+ for _, entry := range ch.Members.innerMap {
+ nick := entry.originalKey
+ memberships := entry.value.(*memberships)
+ s := memberships.Format(dc) + dc.marshalEntity(ch.conn.network, nick)
+
+ n := buf.Len() + 1 + len(s)
+ if buf.Len() != 0 && n > maxLength {
+ // There's not enough space for the next space + nick.
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_NAMREPLY,
+ Params: []string{dc.nick, string(ch.Status), downstreamName, buf.String()},
+ })
+ buf.Reset()
+ }
+
+ if buf.Len() != 0 {
+ buf.WriteByte(' ')
+ }
+ buf.WriteString(s)
+ }
+
+ if buf.Len() != 0 {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_NAMREPLY,
+ Params: []string{dc.nick, string(ch.Status), downstreamName, buf.String()},
+ })
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_ENDOFNAMES,
+ Params: []string{dc.nick, downstreamName, "End of /NAMES list"},
+ })
+}
--- /dev/null
+package suika
+
+import (
+ "crypto"
+ "crypto/ecdsa"
+ "crypto/ed25519"
+ "crypto/elliptic"
+ "crypto/rand"
+ "crypto/rsa"
+ "crypto/x509"
+ "crypto/x509/pkix"
+ "math/big"
+ "time"
+)
+
+func generateCertFP(keyType string, bits int) (privKeyBytes, certBytes []byte, err error) {
+ var (
+ privKey crypto.PrivateKey
+ pubKey crypto.PublicKey
+ )
+ switch keyType {
+ case "rsa":
+ key, err := rsa.GenerateKey(rand.Reader, bits)
+ if err != nil {
+ return nil, nil, err
+ }
+ privKey = key
+ pubKey = key.Public()
+ case "ecdsa":
+ key, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
+ if err != nil {
+ return nil, nil, err
+ }
+ privKey = key
+ pubKey = key.Public()
+ case "ed25519":
+ var err error
+ pubKey, privKey, err = ed25519.GenerateKey(rand.Reader)
+ if err != nil {
+ return nil, nil, err
+ }
+ }
+
+ // Using PKCS#8 allows easier extension for new key types.
+ privKeyBytes, err = x509.MarshalPKCS8PrivateKey(privKey)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ notBefore := time.Now()
+ // Lets make a fair assumption nobody will use the same cert for more than 20 years...
+ notAfter := notBefore.Add(24 * time.Hour * 365 * 20)
+ serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
+ serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
+ if err != nil {
+ return nil, nil, err
+ }
+ cert := &x509.Certificate{
+ SerialNumber: serialNumber,
+ Subject: pkix.Name{CommonName: "suika auto-generated certificate"},
+ NotBefore: notBefore,
+ NotAfter: notAfter,
+ KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
+ ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
+ }
+ certBytes, err = x509.CreateCertificate(rand.Reader, cert, cert, pubKey, privKey)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ return privKeyBytes, certBytes, nil
+}
--- /dev/null
+package main
+
+import (
+ "bufio"
+ "context"
+ "flag"
+ "fmt"
+ "io"
+ "log"
+ "net/url"
+ "os"
+ "strings"
+ "unicode"
+
+ "marisa.chaotic.ninja/suika"
+ "marisa.chaotic.ninja/suika/config"
+)
+
+const usage = `usage: suika-znc-import [options...] <znc config path>
+
+Imports configuration from a ZNC file. Users and networks are merged if they
+already exist in the suika database. ZNC settings overwrite existing suika
+settings.
+
+Options:
+
+ -help Show this help message
+ -config <path> Path to suika config file
+ -user <username> Limit import to username (may be specified multiple times)
+ -network <name> Limit import to network (may be specified multiple times)
+`
+
+func init() {
+ flag.Usage = func() {
+ fmt.Fprintf(flag.CommandLine.Output(), usage)
+ }
+}
+
+func main() {
+ var configPath string
+ users := make(map[string]bool)
+ networks := make(map[string]bool)
+ flag.StringVar(&configPath, "config", "", "path to configuration file")
+ flag.Var((*stringSetFlag)(&users), "user", "")
+ flag.Var((*stringSetFlag)(&networks), "network", "")
+ flag.Parse()
+
+ zncPath := flag.Arg(0)
+ if zncPath == "" {
+ flag.Usage()
+ os.Exit(1)
+ }
+
+ var cfg *config.Server
+ if configPath != "" {
+ var err error
+ cfg, err = config.Load(configPath)
+ if err != nil {
+ log.Fatalf("failed to load config file: %v", err)
+ }
+ } else {
+ cfg = config.Defaults()
+ }
+
+ ctx := context.Background()
+
+ db, err := suika.OpenDB(cfg.SQLDriver, cfg.SQLSource)
+ if err != nil {
+ log.Fatalf("failed to open database: %v", err)
+ }
+ defer db.Close()
+
+ f, err := os.Open(zncPath)
+ if err != nil {
+ log.Fatalf("failed to open ZNC configuration file: %v", err)
+ }
+ defer f.Close()
+
+ zp := zncParser{bufio.NewReader(f), 1}
+ root, err := zp.sectionBody("", "")
+ if err != nil {
+ log.Fatalf("failed to parse %q: line %v: %v", zncPath, zp.line, err)
+ }
+
+ l, err := db.ListUsers(ctx)
+ if err != nil {
+ log.Fatalf("failed to list users in DB: %v", err)
+ }
+ existingUsers := make(map[string]*suika.User, len(l))
+ for i, u := range l {
+ existingUsers[u.Username] = &l[i]
+ }
+
+ usersCreated := 0
+ usersImported := 0
+ networksImported := 0
+ channelsImported := 0
+ root.ForEach("User", func(section *zncSection) {
+ username := section.Name
+ if len(users) > 0 && !users[username] {
+ return
+ }
+ usersImported++
+
+ u, ok := existingUsers[username]
+ if ok {
+ log.Printf("user %q: updating existing user", username)
+ } else {
+ // "!!" is an invalid crypt format, thus disables password auth
+ u = &suika.User{Username: username, Password: "!!"}
+ usersCreated++
+ log.Printf("user %q: creating new user", username)
+ }
+
+ u.Admin = section.Values.Get("Admin") == "true"
+
+ if err := db.StoreUser(ctx, u); err != nil {
+ log.Fatalf("failed to store user %q: %v", username, err)
+ }
+ userID := u.ID
+
+ l, err := db.ListNetworks(ctx, userID)
+ if err != nil {
+ log.Fatalf("failed to list networks for user %q: %v", username, err)
+ }
+ existingNetworks := make(map[string]*suika.Network, len(l))
+ for i, n := range l {
+ existingNetworks[n.GetName()] = &l[i]
+ }
+
+ nick := section.Values.Get("Nick")
+ realname := section.Values.Get("RealName")
+ ident := section.Values.Get("Ident")
+
+ section.ForEach("Network", func(section *zncSection) {
+ netName := section.Name
+ if len(networks) > 0 && !networks[netName] {
+ return
+ }
+ networksImported++
+
+ logPrefix := fmt.Sprintf("user %q: network %q: ", username, netName)
+ logger := log.New(os.Stderr, logPrefix, log.LstdFlags|log.Lmsgprefix)
+
+ netNick := section.Values.Get("Nick")
+ if netNick == "" {
+ netNick = nick
+ }
+ netRealname := section.Values.Get("RealName")
+ if netRealname == "" {
+ netRealname = realname
+ }
+ netIdent := section.Values.Get("Ident")
+ if netIdent == "" {
+ netIdent = ident
+ }
+
+ for _, name := range section.Values["LoadModule"] {
+ switch name {
+ case "sasl":
+ logger.Printf("warning: SASL credentials not imported")
+ case "nickserv":
+ logger.Printf("warning: NickServ credentials not imported")
+ case "perform":
+ logger.Printf("warning: \"perform\" plugin commands not imported")
+ }
+ }
+
+ u, pass, err := importNetworkServer(section.Values.Get("Server"))
+ if err != nil {
+ logger.Fatalf("failed to import server %q: %v", section.Values.Get("Server"), err)
+ }
+
+ n, ok := existingNetworks[netName]
+ if ok {
+ logger.Printf("updating existing network")
+ } else {
+ n = &suika.Network{Name: netName}
+ logger.Printf("creating new network")
+ }
+
+ n.Addr = u.String()
+ n.Nick = netNick
+ n.Username = netIdent
+ n.Realname = netRealname
+ n.Pass = pass
+ n.Enabled = section.Values.Get("IRCConnectEnabled") != "false"
+
+ if err := db.StoreNetwork(ctx, userID, n); err != nil {
+ logger.Fatalf("failed to store network: %v", err)
+ }
+
+ l, err := db.ListChannels(ctx, n.ID)
+ if err != nil {
+ logger.Fatalf("failed to list channels: %v", err)
+ }
+ existingChannels := make(map[string]*suika.Channel, len(l))
+ for i, ch := range l {
+ existingChannels[ch.Name] = &l[i]
+ }
+
+ section.ForEach("Chan", func(section *zncSection) {
+ chName := section.Name
+
+ if section.Values.Get("Disabled") == "true" {
+ logger.Printf("skipping import of disabled channel %q", chName)
+ return
+ }
+
+ channelsImported++
+
+ ch, ok := existingChannels[chName]
+ if ok {
+ logger.Printf("channel %q: updating existing channel", chName)
+ } else {
+ ch = &suika.Channel{Name: chName}
+ logger.Printf("channel %q: creating new channel", chName)
+ }
+
+ ch.Key = section.Values.Get("Key")
+ ch.Detached = section.Values.Get("Detached") == "true"
+
+ if err := db.StoreChannel(ctx, n.ID, ch); err != nil {
+ logger.Printf("channel %q: failed to store channel: %v", chName, err)
+ }
+ })
+ })
+ })
+
+ if err := db.Close(); err != nil {
+ log.Printf("failed to close database: %v", err)
+ }
+
+ if usersCreated > 0 {
+ log.Printf("warning: user passwords haven't been imported, please set them with `suikactl change-password <username>`")
+ }
+
+ log.Printf("imported %v users, %v networks and %v channels", usersImported, networksImported, channelsImported)
+}
+
+func importNetworkServer(s string) (u *url.URL, pass string, err error) {
+ parts := strings.Fields(s)
+ if len(parts) < 2 {
+ return nil, "", fmt.Errorf("expected space-separated host and port")
+ }
+
+ scheme := "irc"
+ host := parts[0]
+ port := parts[1]
+ if strings.HasPrefix(port, "+") {
+ port = port[1:]
+ scheme = "ircs"
+ }
+
+ if len(parts) > 2 {
+ pass = parts[2]
+ }
+
+ u = &url.URL{
+ Scheme: scheme,
+ Host: host + ":" + port,
+ }
+ return u, pass, nil
+}
+
+type zncSection struct {
+ Type string
+ Name string
+ Values zncValues
+ Children []zncSection
+}
+
+func (s *zncSection) ForEach(typ string, f func(*zncSection)) {
+ for _, section := range s.Children {
+ if section.Type == typ {
+ f(§ion)
+ }
+ }
+}
+
+type zncValues map[string][]string
+
+func (zv zncValues) Get(k string) string {
+ if len(zv[k]) == 0 {
+ return ""
+ }
+ return zv[k][0]
+}
+
+type zncParser struct {
+ br *bufio.Reader
+ line int
+}
+
+func (zp *zncParser) readByte() (byte, error) {
+ b, err := zp.br.ReadByte()
+ if b == '\n' {
+ zp.line++
+ }
+ return b, err
+}
+
+func (zp *zncParser) readRune() (rune, int, error) {
+ r, n, err := zp.br.ReadRune()
+ if r == '\n' {
+ zp.line++
+ }
+ return r, n, err
+}
+
+func (zp *zncParser) sectionBody(typ, name string) (*zncSection, error) {
+ section := &zncSection{Type: typ, Name: name, Values: make(zncValues)}
+
+Loop:
+ for {
+ if err := zp.skipSpace(); err != nil {
+ return nil, err
+ }
+
+ b, err := zp.br.Peek(2)
+ if err == io.EOF {
+ break
+ } else if err != nil {
+ return nil, err
+ }
+
+ switch b[0] {
+ case '<':
+ if b[1] == '/' {
+ break Loop
+ } else {
+ childType, childName, err := zp.sectionHeader()
+ if err != nil {
+ return nil, err
+ }
+ child, err := zp.sectionBody(childType, childName)
+ if err != nil {
+ return nil, err
+ }
+ if footerType, err := zp.sectionFooter(); err != nil {
+ return nil, err
+ } else if footerType != childType {
+ return nil, fmt.Errorf("invalid section footer: expected type %q, got %q", childType, footerType)
+ }
+ section.Children = append(section.Children, *child)
+ }
+ case '/':
+ if b[1] == '/' {
+ if err := zp.skipComment(); err != nil {
+ return nil, err
+ }
+ break
+ }
+ fallthrough
+ default:
+ k, v, err := zp.keyValuePair()
+ if err != nil {
+ return nil, err
+ }
+ section.Values[k] = append(section.Values[k], v)
+ }
+ }
+
+ return section, nil
+}
+
+func (zp *zncParser) skipSpace() error {
+ for {
+ r, _, err := zp.readRune()
+ if err == io.EOF {
+ return nil
+ } else if err != nil {
+ return err
+ }
+
+ if !unicode.IsSpace(r) {
+ zp.br.UnreadRune()
+ return nil
+ }
+ }
+}
+
+func (zp *zncParser) skipComment() error {
+ if err := zp.expectRune('/'); err != nil {
+ return err
+ }
+ if err := zp.expectRune('/'); err != nil {
+ return err
+ }
+
+ for {
+ b, err := zp.readByte()
+ if err == io.EOF {
+ return nil
+ } else if err != nil {
+ return err
+ }
+
+ if b == '\n' {
+ return nil
+ }
+ }
+}
+
+func (zp *zncParser) sectionHeader() (string, string, error) {
+ if err := zp.expectRune('<'); err != nil {
+ return "", "", err
+ }
+ typ, err := zp.readWord(' ')
+ if err != nil {
+ return "", "", err
+ }
+ name, err := zp.readWord('>')
+ return typ, name, err
+}
+
+func (zp *zncParser) sectionFooter() (string, error) {
+ if err := zp.expectRune('<'); err != nil {
+ return "", err
+ }
+ if err := zp.expectRune('/'); err != nil {
+ return "", err
+ }
+ return zp.readWord('>')
+}
+
+func (zp *zncParser) keyValuePair() (string, string, error) {
+ k, err := zp.readWord('=')
+ if err != nil {
+ return "", "", err
+ }
+ v, err := zp.readWord('\n')
+ return strings.TrimSpace(k), strings.TrimSpace(v), err
+}
+
+func (zp *zncParser) expectRune(expected rune) error {
+ r, _, err := zp.readRune()
+ if err != nil {
+ return err
+ } else if r != expected {
+ return fmt.Errorf("expected %q, got %q", expected, r)
+ }
+ return nil
+}
+
+func (zp *zncParser) readWord(delim byte) (string, error) {
+ var sb strings.Builder
+ for {
+ b, err := zp.readByte()
+ if err != nil {
+ return "", err
+ }
+
+ if b == delim {
+ return sb.String(), nil
+ }
+ if b == '\n' {
+ return "", fmt.Errorf("expected %q before newline", delim)
+ }
+
+ sb.WriteByte(b)
+ }
+}
+
+type stringSetFlag map[string]bool
+
+func (v *stringSetFlag) String() string {
+ return fmt.Sprint(map[string]bool(*v))
+}
+
+func (v *stringSetFlag) Set(s string) error {
+ (*v)[s] = true
+ return nil
+}
--- /dev/null
+package main
+
+import (
+ "context"
+ "crypto/tls"
+ "flag"
+ "fmt"
+ "log"
+ "net"
+ "net/url"
+ "os"
+ "os/signal"
+ "strings"
+ "sync/atomic"
+ "syscall"
+ "time"
+
+ "marisa.chaotic.ninja/suika"
+ "marisa.chaotic.ninja/suika/config"
+)
+
+// TCP keep-alive interval for downstream TCP connections
+const downstreamKeepAlive = 1 * time.Hour
+
+type stringSliceFlag []string
+
+func (v *stringSliceFlag) String() string {
+ return fmt.Sprint([]string(*v))
+}
+
+func (v *stringSliceFlag) Set(s string) error {
+ *v = append(*v, s)
+ return nil
+}
+
+func bumpOpenedFileLimit() error {
+ var rlimit syscall.Rlimit
+ if err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rlimit); err != nil {
+ return fmt.Errorf("failed to get RLIMIT_NOFILE: %v", err)
+ }
+ rlimit.Cur = rlimit.Max
+ if err := syscall.Setrlimit(syscall.RLIMIT_NOFILE, &rlimit); err != nil {
+ return fmt.Errorf("failed to set RLIMIT_NOFILE: %v", err)
+ }
+ return nil
+}
+
+var (
+ configPath string
+ debug bool
+
+ tlsCert atomic.Value // *tls.Certificate
+)
+
+func loadConfig() (*config.Server, *suika.Config, error) {
+ var raw *config.Server
+ if configPath != "" {
+ var err error
+ raw, err = config.Load(configPath)
+ if err != nil {
+ return nil, nil, fmt.Errorf("failed to load config file: %v", err)
+ }
+ } else {
+ raw = config.Defaults()
+ }
+
+ var motd string
+ if raw.MOTDPath != "" {
+ b, err := os.ReadFile(raw.MOTDPath)
+ if err != nil {
+ return nil, nil, fmt.Errorf("failed to load MOTD: %v", err)
+ }
+ motd = strings.TrimSuffix(string(b), "\n")
+ }
+
+ if raw.TLS != nil {
+ cert, err := tls.LoadX509KeyPair(raw.TLS.CertPath, raw.TLS.KeyPath)
+ if err != nil {
+ return nil, nil, fmt.Errorf("failed to load TLS certificate and key: %v", err)
+ }
+ tlsCert.Store(&cert)
+ }
+
+ cfg := &suika.Config{
+ Hostname: raw.Hostname,
+ Title: raw.Title,
+ LogPath: raw.LogPath,
+ MaxUserNetworks: raw.MaxUserNetworks,
+ MultiUpstream: raw.MultiUpstream,
+ UpstreamUserIPs: raw.UpstreamUserIPs,
+ MOTD: motd,
+ }
+ return raw, cfg, nil
+}
+
+func main() {
+ var listen []string
+ flag.Var((*stringSliceFlag)(&listen), "listen", "listening address")
+ flag.StringVar(&configPath, "config", "", "path to configuration file")
+ flag.BoolVar(&debug, "debug", false, "enable debug logging")
+ flag.Parse()
+
+ cfg, serverCfg, err := loadConfig()
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ cfg.Listen = append(cfg.Listen, listen...)
+ if len(cfg.Listen) == 0 {
+ cfg.Listen = []string{":6667"}
+ }
+
+ if err := bumpOpenedFileLimit(); err != nil {
+ log.Printf("failed to bump max number of opened files: %v", err)
+ }
+
+ db, err := suika.OpenDB(cfg.SQLDriver, cfg.SQLSource)
+ if err != nil {
+ log.Fatalf("failed to open database: %v", err)
+ }
+
+ var tlsCfg *tls.Config
+ if cfg.TLS != nil {
+ tlsCfg = &tls.Config{
+ GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
+ return tlsCert.Load().(*tls.Certificate), nil
+ },
+ }
+ }
+
+ srv := suika.NewServer(db)
+ srv.SetConfig(serverCfg)
+ srv.Logger = suika.NewLogger(log.Writer(), debug)
+
+ for _, listen := range cfg.Listen {
+ listen := listen // copy
+ listenURI := listen
+ if !strings.Contains(listenURI, ":/") {
+ // This is a raw domain name, make it an URL with an empty scheme
+ listenURI = "//" + listenURI
+ }
+ u, err := url.Parse(listenURI)
+ if err != nil {
+ log.Fatalf("failed to parse listen URI %q: %v", listen, err)
+ }
+
+ switch u.Scheme {
+ case "ircs", "":
+ if tlsCfg == nil {
+ log.Fatalf("failed to listen on %q: missing TLS configuration", listen)
+ }
+ host := u.Host
+ if _, _, err := net.SplitHostPort(host); err != nil {
+ host = host + ":6697"
+ }
+ ircsTLSCfg := tlsCfg.Clone()
+ ircsTLSCfg.NextProtos = []string{"irc"}
+ lc := net.ListenConfig{
+ KeepAlive: downstreamKeepAlive,
+ }
+ l, err := lc.Listen(context.Background(), "tcp", host)
+ if err != nil {
+ log.Fatalf("failed to start TLS listener on %q: %v", listen, err)
+ }
+ ln := tls.NewListener(l, ircsTLSCfg)
+ go func() {
+ if err := srv.Serve(ln); err != nil {
+ log.Printf("serving %q: %v", listen, err)
+ }
+ }()
+ case "irc":
+ host := u.Host
+ if _, _, err := net.SplitHostPort(host); err != nil {
+ host = host + ":6667"
+ }
+ lc := net.ListenConfig{
+ KeepAlive: downstreamKeepAlive,
+ }
+ ln, err := lc.Listen(context.Background(), "tcp", host)
+ if err != nil {
+ log.Fatalf("failed to start listener on %q: %v", listen, err)
+ }
+ go func() {
+ if err := srv.Serve(ln); err != nil {
+ log.Printf("serving %q: %v", listen, err)
+ }
+ }()
+ case "unix":
+ ln, err := net.Listen("unix", u.Path)
+ if err != nil {
+ log.Fatalf("failed to start listener on %q: %v", listen, err)
+ }
+ go func() {
+ if err := srv.Serve(ln); err != nil {
+ log.Printf("serving %q: %v", listen, err)
+ }
+ }()
+ default:
+ log.Fatalf("failed to listen on %q: unsupported scheme", listen)
+ }
+
+ log.Printf("starting suika version %v\n", suika.FullVersion())
+ log.Printf("server listening on %q", listen)
+ }
+
+ sigCh := make(chan os.Signal, 1)
+ signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)
+
+ if err := srv.Start(); err != nil {
+ log.Fatal(err)
+ }
+
+ for sig := range sigCh {
+ switch sig {
+ case syscall.SIGHUP:
+ log.Print("reloading configuration")
+ _, serverCfg, err := loadConfig()
+ if err != nil {
+ log.Printf("failed to reloading configuration: %v", err)
+ } else {
+ srv.SetConfig(serverCfg)
+ }
+ case syscall.SIGINT, syscall.SIGTERM:
+ log.Print("shutting down server")
+ srv.Shutdown()
+ return
+ }
+ }
+}
--- /dev/null
+package main
+
+import (
+ "bufio"
+ "context"
+ "flag"
+ "fmt"
+ "io"
+ "log"
+ "os"
+
+ "marisa.chaotic.ninja/suika"
+ "marisa.chaotic.ninja/suika/config"
+ "golang.org/x/crypto/bcrypt"
+ "golang.org/x/term"
+)
+
+const usage = `usage: suikadb [-config path] <action> [options...]
+
+ create-user <username> [-admin] Create a new user
+ change-password <username> Change password for a user
+ help Show this help message
+`
+
+func init() {
+ flag.Usage = func() {
+ fmt.Fprintf(flag.CommandLine.Output(), usage)
+ }
+}
+
+func main() {
+ var configPath string
+ flag.StringVar(&configPath, "config", "", "path to configuration file")
+ flag.Parse()
+
+ var cfg *config.Server
+ if configPath != "" {
+ var err error
+ cfg, err = config.Load(configPath)
+ if err != nil {
+ log.Fatalf("failed to load config file: %v", err)
+ }
+ } else {
+ cfg = config.Defaults()
+ }
+
+ db, err := suika.OpenDB(cfg.SQLDriver, cfg.SQLSource)
+ if err != nil {
+ log.Fatalf("failed to open database: %v", err)
+ }
+
+ ctx := context.Background()
+
+ switch cmd := flag.Arg(0); cmd {
+ case "create-user":
+ username := flag.Arg(1)
+ if username == "" {
+ flag.Usage()
+ os.Exit(1)
+ }
+
+ fs := flag.NewFlagSet("", flag.ExitOnError)
+ admin := fs.Bool("admin", false, "make the new user admin")
+ fs.Parse(flag.Args()[2:])
+
+ password, err := readPassword()
+ if err != nil {
+ log.Fatalf("failed to read password: %v", err)
+ }
+
+ hashed, err := bcrypt.GenerateFromPassword(password, bcrypt.DefaultCost)
+ if err != nil {
+ log.Fatalf("failed to hash password: %v", err)
+ }
+
+ user := suika.User{
+ Username: username,
+ Password: string(hashed),
+ Admin: *admin,
+ }
+ if err := db.StoreUser(ctx, &user); err != nil {
+ log.Fatalf("failed to create user: %v", err)
+ }
+ case "change-password":
+ username := flag.Arg(1)
+ if username == "" {
+ flag.Usage()
+ os.Exit(1)
+ }
+
+ user, err := db.GetUser(ctx, username)
+ if err != nil {
+ log.Fatalf("failed to get user: %v", err)
+ }
+
+ password, err := readPassword()
+ if err != nil {
+ log.Fatalf("failed to read password: %v", err)
+ }
+
+ hashed, err := bcrypt.GenerateFromPassword(password, bcrypt.DefaultCost)
+ if err != nil {
+ log.Fatalf("failed to hash password: %v", err)
+ }
+
+ user.Password = string(hashed)
+ if err := db.StoreUser(ctx, user); err != nil {
+ log.Fatalf("failed to update password: %v", err)
+ }
+ case "version":
+ fmt.Printf("%v\n", suika.FullVersion())
+ default:
+ flag.Usage()
+ if cmd != "help" {
+ os.Exit(1)
+ }
+ }
+}
+
+func readPassword() ([]byte, error) {
+ var password []byte
+ var err error
+ fd := int(os.Stdin.Fd())
+
+ if term.IsTerminal(fd) {
+ fmt.Printf("Password: ")
+ password, err = term.ReadPassword(int(os.Stdin.Fd()))
+ if err != nil {
+ return nil, err
+ }
+ fmt.Printf("\n")
+ } else {
+ fmt.Fprintf(os.Stderr, "Warning: Reading password from stdin.\n")
+ // TODO: the buffering messes up repeated calls to readPassword
+ scanner := bufio.NewScanner(os.Stdin)
+ if !scanner.Scan() {
+ if err := scanner.Err(); err != nil {
+ return nil, err
+ }
+ return nil, io.ErrUnexpectedEOF
+ }
+ password = scanner.Bytes()
+
+ if len(password) == 0 {
+ return nil, fmt.Errorf("zero length password")
+ }
+ }
+
+ return password, nil
+}
--- /dev/null
+db sqlite3 /var/lib/suika/main.db
+log fs /var/lib/suika/logs/
--- /dev/null
+package config
+
+import (
+ "fmt"
+ "net"
+ "os"
+ "strconv"
+
+ "git.sr.ht/~emersion/go-scfg"
+)
+
+type TLS struct {
+ CertPath, KeyPath string
+}
+
+type Server struct {
+ Listen []string
+ TLS *TLS
+ Hostname string
+ Title string
+ MOTDPath string
+
+ SQLDriver string
+ SQLSource string
+ LogPath string
+
+ MaxUserNetworks int
+ MultiUpstream bool
+ UpstreamUserIPs []*net.IPNet
+}
+
+func Defaults() *Server {
+ hostname, err := os.Hostname()
+ if err != nil {
+ hostname = "localhost"
+ }
+ return &Server{
+ Hostname: hostname,
+ SQLDriver: "sqlite3",
+ SQLSource: "suika.db",
+ MaxUserNetworks: -1,
+ MultiUpstream: true,
+ }
+}
+
+func Load(path string) (*Server, error) {
+ cfg, err := scfg.Load(path)
+ if err != nil {
+ return nil, err
+ }
+ return parse(cfg)
+}
+
+func parse(cfg scfg.Block) (*Server, error) {
+ srv := Defaults()
+ for _, d := range cfg {
+ switch d.Name {
+ case "listen":
+ var uri string
+ if err := d.ParseParams(&uri); err != nil {
+ return nil, err
+ }
+ srv.Listen = append(srv.Listen, uri)
+ case "hostname":
+ if err := d.ParseParams(&srv.Hostname); err != nil {
+ return nil, err
+ }
+ case "title":
+ if err := d.ParseParams(&srv.Title); err != nil {
+ return nil, err
+ }
+ case "motd":
+ if err := d.ParseParams(&srv.MOTDPath); err != nil {
+ return nil, err
+ }
+ case "tls":
+ tls := &TLS{}
+ if err := d.ParseParams(&tls.CertPath, &tls.KeyPath); err != nil {
+ return nil, err
+ }
+ srv.TLS = tls
+ case "db":
+ if err := d.ParseParams(&srv.SQLDriver, &srv.SQLSource); err != nil {
+ return nil, err
+ }
+ case "log":
+ var driver string
+ if err := d.ParseParams(&driver, &srv.LogPath); err != nil {
+ return nil, err
+ }
+ if driver != "fs" {
+ return nil, fmt.Errorf("directive %q: unknown driver %q", d.Name, driver)
+ }
+ case "max-user-networks":
+ var max string
+ if err := d.ParseParams(&max); err != nil {
+ return nil, err
+ }
+ var err error
+ if srv.MaxUserNetworks, err = strconv.Atoi(max); err != nil {
+ return nil, fmt.Errorf("directive %q: %v", d.Name, err)
+ }
+ case "multi-upstream-mode":
+ var str string
+ if err := d.ParseParams(&str); err != nil {
+ return nil, err
+ }
+ v, err := strconv.ParseBool(str)
+ if err != nil {
+ return nil, fmt.Errorf("directive %q: %v", d.Name, err)
+ }
+ srv.MultiUpstream = v
+ case "upstream-user-ip":
+ if len(srv.UpstreamUserIPs) > 0 {
+ return nil, fmt.Errorf("directive %q: can only be specified once", d.Name)
+ }
+ var hasIPv4, hasIPv6 bool
+ for _, s := range d.Params {
+ _, n, err := net.ParseCIDR(s)
+ if err != nil {
+ return nil, fmt.Errorf("directive %q: failed to parse CIDR: %v", d.Name, err)
+ }
+ if n.IP.To4() == nil {
+ if hasIPv6 {
+ return nil, fmt.Errorf("directive %q: found two IPv6 CIDRs", d.Name)
+ }
+ hasIPv6 = true
+ } else {
+ if hasIPv4 {
+ return nil, fmt.Errorf("directive %q: found two IPv4 CIDRs", d.Name)
+ }
+ hasIPv4 = true
+ }
+ srv.UpstreamUserIPs = append(srv.UpstreamUserIPs, n)
+ }
+ default:
+ return nil, fmt.Errorf("unknown directive %q", d.Name)
+ }
+ }
+
+ return srv, nil
+}
--- /dev/null
+package suika
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "net"
+ "sync"
+ "time"
+
+ "golang.org/x/time/rate"
+ "gopkg.in/irc.v3"
+)
+
+// ircConn is a generic IRC connection. It's similar to net.Conn but focuses on
+// reading and writing IRC messages.
+type ircConn interface {
+ ReadMessage() (*irc.Message, error)
+ WriteMessage(*irc.Message) error
+ Close() error
+ SetReadDeadline(time.Time) error
+ SetWriteDeadline(time.Time) error
+ RemoteAddr() net.Addr
+ LocalAddr() net.Addr
+}
+
+func newNetIRCConn(c net.Conn) ircConn {
+ type netConn net.Conn
+ return struct {
+ *irc.Conn
+ netConn
+ }{irc.NewConn(c), c}
+}
+
+type connOptions struct {
+ Logger Logger
+ RateLimitDelay time.Duration
+ RateLimitBurst int
+}
+
+type conn struct {
+ conn ircConn
+ srv *Server
+ logger Logger
+
+ lock sync.Mutex
+ outgoing chan<- *irc.Message
+ closed bool
+ closedCh chan struct{}
+}
+
+func newConn(srv *Server, ic ircConn, options *connOptions) *conn {
+ outgoing := make(chan *irc.Message, 64)
+ c := &conn{
+ conn: ic,
+ srv: srv,
+ outgoing: outgoing,
+ logger: options.Logger,
+ closedCh: make(chan struct{}),
+ }
+
+ go func() {
+ ctx, cancel := c.NewContext(context.Background())
+ defer cancel()
+
+ rl := rate.NewLimiter(rate.Every(options.RateLimitDelay), options.RateLimitBurst)
+ for msg := range outgoing {
+ if err := rl.Wait(ctx); err != nil {
+ break
+ }
+
+ c.logger.Debugf("sent: %v", msg)
+ c.conn.SetWriteDeadline(time.Now().Add(writeTimeout))
+ if err := c.conn.WriteMessage(msg); err != nil {
+ c.logger.Printf("failed to write message: %v", err)
+ break
+ }
+ }
+ if err := c.conn.Close(); err != nil && !isErrClosed(err) {
+ c.logger.Printf("failed to close connection: %v", err)
+ } else {
+ c.logger.Debugf("connection closed")
+ }
+ // Drain the outgoing channel to prevent SendMessage from blocking
+ for range outgoing {
+ // This space is intentionally left blank
+ }
+ }()
+
+ c.logger.Debugf("new connection")
+ return c
+}
+
+func (c *conn) isClosed() bool {
+ c.lock.Lock()
+ defer c.lock.Unlock()
+ return c.closed
+}
+
+// Close closes the connection. It is safe to call from any goroutine.
+func (c *conn) Close() error {
+ c.lock.Lock()
+ defer c.lock.Unlock()
+
+ if c.closed {
+ return fmt.Errorf("connection already closed")
+ }
+
+ err := c.conn.Close()
+ c.closed = true
+ close(c.outgoing)
+ close(c.closedCh)
+ return err
+}
+
+func (c *conn) ReadMessage() (*irc.Message, error) {
+ msg, err := c.conn.ReadMessage()
+ if isErrClosed(err) {
+ return nil, io.EOF
+ } else if err != nil {
+ return nil, err
+ }
+
+ c.logger.Debugf("received: %v", msg)
+ return msg, nil
+}
+
+// SendMessage queues a new outgoing message. It is safe to call from any
+// goroutine.
+//
+// If the connection is closed before the message is sent, SendMessage silently
+// drops the message.
+func (c *conn) SendMessage(ctx context.Context, msg *irc.Message) {
+ c.lock.Lock()
+ defer c.lock.Unlock()
+
+ if c.closed {
+ return
+ }
+
+ select {
+ case c.outgoing <- msg:
+ // Success
+ case <-ctx.Done():
+ c.logger.Printf("failed to send message: %v", ctx.Err())
+ }
+}
+
+func (c *conn) RemoteAddr() net.Addr {
+ return c.conn.RemoteAddr()
+}
+
+func (c *conn) LocalAddr() net.Addr {
+ return c.conn.LocalAddr()
+}
+
+// NewContext returns a copy of the parent context with a new Done channel. The
+// returned context's Done channel is closed when the connection is closed,
+// when the returned cancel function is called, or when the parent context's
+// Done channel is closed, whichever happens first.
+//
+// Canceling this context releases resources associated with it, so code should
+// call cancel as soon as the operations running in this Context complete.
+func (c *conn) NewContext(parent context.Context) (context.Context, context.CancelFunc) {
+ ctx, cancel := context.WithCancel(parent)
+
+ go func() {
+ defer cancel()
+
+ select {
+ case <-ctx.Done():
+ // The parent context has been cancelled, or the caller has called
+ // cancel()
+ case <-c.closedCh:
+ // The connection has been closed
+ }
+ }()
+
+ return ctx, cancel
+}
--- /dev/null
+#!/bin/sh -eu
+
+# Converts a log dir to its case-mapped form.
+#
+# suika needs to be stopped for this script to work properly. The script may
+# re-order messages that happened within the same second interval if merging
+# two daily log files is necessary.
+#
+# usage: casemap-logs.sh <directory>
+
+root="$1"
+
+for net_dir in "$root"/*/*; do
+ for chan in $(ls "$net_dir"); do
+ cm_chan="$(echo $chan | tr '[:upper:]' '[:lower:]')"
+ if [ "$chan" = "$cm_chan" ]; then
+ continue
+ fi
+
+ if ! [ -d "$net_dir/$cm_chan" ]; then
+ echo >&2 "Moving case-mapped channel dir: '$net_dir/$chan' -> '$cm_chan'"
+ mv "$net_dir/$chan" "$net_dir/$cm_chan"
+ continue
+ fi
+
+ echo "Merging case-mapped channel dir: '$net_dir/$chan' -> '$cm_chan'"
+ for day in $(ls "$net_dir/$chan"); do
+ if ! [ -e "$net_dir/$cm_chan/$day" ]; then
+ echo >&2 " Moving log file: '$day'"
+ mv "$net_dir/$chan/$day" "$net_dir/$cm_chan/$day"
+ continue
+ fi
+
+ echo >&2 " Merging log file: '$day'"
+ sort "$net_dir/$chan/$day" "$net_dir/$cm_chan/$day" >"$net_dir/$cm_chan/$day.new"
+ mv "$net_dir/$cm_chan/$day.new" "$net_dir/$cm_chan/$day"
+ rm "$net_dir/$chan/$day"
+ done
+
+ rmdir "$net_dir/$chan"
+ done
+done
--- /dev/null
+# Clients
+
+This page describes how to configure IRC clients to better integrate with soju.
+
+Also see the [IRCv3 support tables] for a more general list of clients.
+
+# catgirl
+
+catgirl doesn't properly implement cap-3.2, so many capabilities will be
+disabled. catgirl developers have publicly stated that supporting bouncers such
+as soju is a non-goal.
+
+# [Emacs]
+
+There are two clients provided with Emacs. They require some setup to work
+properly.
+
+## Erc
+
+You need to explicitly set the username, which is the defcustom
+`erc-email-userid`.
+
+```elisp
+(setq erc-email-userid "<username>/irc.libera.chat") ;; Example with Libera.Chat
+(defun run-erc ()
+ (interactive)
+ (erc-tls :server "<server>"
+ :port 6697
+ :nick "<nick>"
+ :password "<password>"))
+```
+
+Then run `M-x run-erc`.
+
+## Rcirc
+
+The only thing needed here is the general config:
+
+```elisp
+(setq rcirc-server-alist
+ '(("<server>"
+ :port 6697
+ :encryption tls
+ :nick "<nick>"
+ :user-name "<username>/irc.libera.chat" ;; Example with Libera.Chat
+ :password "<password>")))
+```
+
+Then run `M-x irc`.
+
+# [gamja]
+
+gamja has been designed together with soju, so should have excellent
+integration. gamja supports many IRCv3 features including chat history.
+gamja also provides UI to manage soju networks via the
+`soju.im/bouncer-networks` extension.
+
+# [goguma]
+
+Much like gamja, goguma has been designed together with soju, so should have
+excellent integration. goguma supports many IRCv3 features including chat
+history. goguma should seamlessly connect to all networks configured in soju via
+the `soju.im/bouncer-networks` extension.
+
+# [Hexchat]
+
+Hexchat has support for a small set of IRCv3 capabilities. To prevent
+automatically reconnecting to channels parted from soju, and prevent buffering
+outgoing messages:
+
+ /set irc_reconnect_rejoin off
+ /set net_throttle off
+
+# [senpai]
+
+senpai is being developed with soju in mind, so should have excellent
+integration. senpai supports many IRCv3 features including chat history.
+
+# [Weechat]
+
+A [Weechat script] is available to provide better integration with soju.
+The script will automatically connect to all of your networks once a
+single connection to soju is set up in Weechat.
+
+On WeeChat 3.2-, no IRCv3 capabilities are enabled by default. To enable them:
+
+ /set irc.server_default.capabilities account-notify,away-notify,cap-notify,chghost,extended-join,invite-notify,multi-prefix,server-time,userhost-in-names
+ /save
+ /reconnect -all
+
+See `/help cap` for more information.
+
+[IRCv3 support tables]: https://ircv3.net/software/clients
+[gamja]: https://sr.ht/~emersion/gamja/
+[goguma]: https://sr.ht/~emersion/goguma/
+[senpai]: https://sr.ht/~taiite/senpai/
+[Weechat]: https://weechat.org/
+[Weechat script]: https://github.com/weechat/scripts/blob/master/python/soju.py
+[Hexchat]: https://hexchat.github.io/
+[Emacs]: https://www.gnu.org/software/emacs/
--- /dev/null
+package suika
+
+import (
+ "context"
+ "fmt"
+ "net/url"
+ "strings"
+ "time"
+)
+
+type Database interface {
+ Close() error
+ Stats(ctx context.Context) (*DatabaseStats, error)
+
+ ListUsers(ctx context.Context) ([]User, error)
+ GetUser(ctx context.Context, username string) (*User, error)
+ StoreUser(ctx context.Context, user *User) error
+ DeleteUser(ctx context.Context, id int64) error
+
+ ListNetworks(ctx context.Context, userID int64) ([]Network, error)
+ StoreNetwork(ctx context.Context, userID int64, network *Network) error
+ DeleteNetwork(ctx context.Context, id int64) error
+ ListChannels(ctx context.Context, networkID int64) ([]Channel, error)
+ StoreChannel(ctx context.Context, networKID int64, ch *Channel) error
+ DeleteChannel(ctx context.Context, id int64) error
+
+ ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error)
+ StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error
+
+ GetReadReceipt(ctx context.Context, networkID int64, name string) (*ReadReceipt, error)
+ StoreReadReceipt(ctx context.Context, networkID int64, receipt *ReadReceipt) error
+}
+
+func OpenDB(driver, source string) (Database, error) {
+ switch driver {
+ case "sqlite3":
+ return OpenSqliteDB(source)
+ case "postgres":
+ return OpenPostgresDB(source)
+ default:
+ return nil, fmt.Errorf("unsupported database driver: %q", driver)
+ }
+}
+
+type DatabaseStats struct {
+ Users int64
+ Networks int64
+ Channels int64
+}
+
+type User struct {
+ ID int64
+ Username string
+ Password string // hashed
+ Realname string
+ Admin bool
+}
+
+type SASL struct {
+ Mechanism string
+
+ Plain struct {
+ Username string
+ Password string
+ }
+
+ // TLS client certificate authentication.
+ External struct {
+ // X.509 certificate in DER form.
+ CertBlob []byte
+ // PKCS#8 private key in DER form.
+ PrivKeyBlob []byte
+ }
+}
+
+type Network struct {
+ ID int64
+ Name string
+ Addr string
+ Nick string
+ Username string
+ Realname string
+ Pass string
+ ConnectCommands []string
+ SASL SASL
+ Enabled bool
+}
+
+func (net *Network) GetName() string {
+ if net.Name != "" {
+ return net.Name
+ }
+ return net.Addr
+}
+
+func (net *Network) URL() (*url.URL, error) {
+ s := net.Addr
+ if !strings.Contains(s, "://") {
+ // This is a raw domain name, make it an URL with the default scheme
+ s = "ircs://" + s
+ }
+
+ u, err := url.Parse(s)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse upstream server URL: %v", err)
+ }
+
+ return u, nil
+}
+
+func GetNick(user *User, net *Network) string {
+ if net.Nick != "" {
+ return net.Nick
+ }
+ return user.Username
+}
+
+func GetUsername(user *User, net *Network) string {
+ if net.Username != "" {
+ return net.Username
+ }
+ return GetNick(user, net)
+}
+
+func GetRealname(user *User, net *Network) string {
+ if net.Realname != "" {
+ return net.Realname
+ }
+ if user.Realname != "" {
+ return user.Realname
+ }
+ return GetNick(user, net)
+}
+
+type MessageFilter int
+
+const (
+ // TODO: use customizable user defaults for FilterDefault
+ FilterDefault MessageFilter = iota
+ FilterNone
+ FilterHighlight
+ FilterMessage
+)
+
+func parseFilter(filter string) (MessageFilter, error) {
+ switch filter {
+ case "default":
+ return FilterDefault, nil
+ case "none":
+ return FilterNone, nil
+ case "highlight":
+ return FilterHighlight, nil
+ case "message":
+ return FilterMessage, nil
+ }
+ return 0, fmt.Errorf("unknown filter: %q", filter)
+}
+
+type Channel struct {
+ ID int64
+ Name string
+ Key string
+
+ Detached bool
+ DetachedInternalMsgID string
+
+ RelayDetached MessageFilter
+ ReattachOn MessageFilter
+ DetachAfter time.Duration
+ DetachOn MessageFilter
+}
+
+type DeliveryReceipt struct {
+ ID int64
+ Target string // channel or nick
+ Client string
+ InternalMsgID string
+}
+
+type ReadReceipt struct {
+ ID int64
+ Target string // channel or nick
+ Timestamp time.Time
+}
--- /dev/null
+package suika
+
+import (
+ "context"
+ "database/sql"
+ _ "embed"
+ "errors"
+ "fmt"
+ "math"
+ "strings"
+ "time"
+
+ _ "github.com/lib/pq"
+)
+
+const postgresQueryTimeout = 5 * time.Second
+
+const postgresConfigSchema = `
+CREATE TABLE IF NOT EXISTS "Config" (
+ id SMALLINT PRIMARY KEY,
+ version INTEGER NOT NULL,
+ CHECK(id = 1)
+);
+`
+//go:embed suika_psql_schema.sql
+var postgresSchema string
+
+var postgresMigrations = []string{
+ "", // migration #0 is reserved for schema initialization
+ `ALTER TABLE "Network" ALTER COLUMN nick DROP NOT NULL`,
+ `
+ CREATE TYPE sasl_mechanism AS ENUM ('PLAIN', 'EXTERNAL');
+ ALTER TABLE "Network"
+ ALTER COLUMN sasl_mechanism
+ TYPE sasl_mechanism
+ USING sasl_mechanism::sasl_mechanism;
+ `,
+ `
+ CREATE TABLE IF NOT EXISTS "ReadReceipt" (
+ id SERIAL PRIMARY KEY,
+ network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
+ target VARCHAR(255) NOT NULL,
+ timestamp TIMESTAMP WITH TIME ZONE NOT NULL,
+ UNIQUE(network, target)
+ );
+ `,
+}
+
+type PostgresDB struct {
+ db *sql.DB
+}
+
+func OpenPostgresDB(source string) (Database, error) {
+ sqlPostgresDB, err := sql.Open("postgres", source)
+ if err != nil {
+ return nil, err
+ }
+
+ db := &PostgresDB{db: sqlPostgresDB}
+ if err := db.upgrade(); err != nil {
+ sqlPostgresDB.Close()
+ return nil, err
+ }
+
+ return db, nil
+}
+
+func (db *PostgresDB) upgrade() error {
+ tx, err := db.db.Begin()
+ if err != nil {
+ return err
+ }
+ defer tx.Rollback()
+
+ if _, err := tx.Exec(postgresConfigSchema); err != nil {
+ return fmt.Errorf("failed to create Config table: %s", err)
+ }
+
+ var version int
+ err = tx.QueryRow(`SELECT version FROM "Config"`).Scan(&version)
+ if err != nil && !errors.Is(err, sql.ErrNoRows) {
+ return fmt.Errorf("failed to query schema version: %s", err)
+ }
+
+ if version == len(postgresMigrations) {
+ return nil
+ }
+ if version > len(postgresMigrations) {
+ return fmt.Errorf("suika (version %d) older than schema (version %d)", len(postgresMigrations), version)
+ }
+
+ if version == 0 {
+ if _, err := tx.Exec(postgresSchema); err != nil {
+ return fmt.Errorf("failed to initialize schema: %s", err)
+ }
+ } else {
+ for i := version; i < len(postgresMigrations); i++ {
+ if _, err := tx.Exec(postgresMigrations[i]); err != nil {
+ return fmt.Errorf("failed to execute migration #%v: %v", i, err)
+ }
+ }
+ }
+
+ _, err = tx.Exec(`INSERT INTO "Config" (id, version) VALUES (1, $1)
+ ON CONFLICT (id) DO UPDATE SET version = $1`, len(postgresMigrations))
+ if err != nil {
+ return fmt.Errorf("failed to bump schema version: %v", err)
+ }
+
+ return tx.Commit()
+}
+
+func (db *PostgresDB) Close() error {
+ return db.db.Close()
+}
+
+func (db *PostgresDB) Stats(ctx context.Context) (*DatabaseStats, error) {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ var stats DatabaseStats
+ row := db.db.QueryRowContext(ctx, `SELECT
+ (SELECT COUNT(*) FROM "User") AS users,
+ (SELECT COUNT(*) FROM "Network") AS networks,
+ (SELECT COUNT(*) FROM "Channel") AS channels`)
+ if err := row.Scan(&stats.Users, &stats.Networks, &stats.Channels); err != nil {
+ return nil, err
+ }
+
+ return &stats, nil
+}
+
+func (db *PostgresDB) ListUsers(ctx context.Context) ([]User, error) {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ rows, err := db.db.QueryContext(ctx,
+ `SELECT id, username, password, admin, realname FROM "User"`)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ var users []User
+ for rows.Next() {
+ var user User
+ var password, realname sql.NullString
+ if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &realname); err != nil {
+ return nil, err
+ }
+ user.Password = password.String
+ user.Realname = realname.String
+ users = append(users, user)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+
+ return users, nil
+}
+
+func (db *PostgresDB) GetUser(ctx context.Context, username string) (*User, error) {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ user := &User{Username: username}
+
+ var password, realname sql.NullString
+ row := db.db.QueryRowContext(ctx,
+ `SELECT id, password, admin, realname FROM "User" WHERE username = $1`,
+ username)
+ if err := row.Scan(&user.ID, &password, &user.Admin, &realname); err != nil {
+ return nil, err
+ }
+ user.Password = password.String
+ user.Realname = realname.String
+ return user, nil
+}
+
+func (db *PostgresDB) StoreUser(ctx context.Context, user *User) error {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ password := toNullString(user.Password)
+ realname := toNullString(user.Realname)
+
+ var err error
+ if user.ID == 0 {
+ err = db.db.QueryRowContext(ctx, `
+ INSERT INTO "User" (username, password, admin, realname)
+ VALUES ($1, $2, $3, $4)
+ RETURNING id`,
+ user.Username, password, user.Admin, realname).Scan(&user.ID)
+ } else {
+ _, err = db.db.ExecContext(ctx, `
+ UPDATE "User"
+ SET password = $1, admin = $2, realname = $3
+ WHERE id = $4`,
+ password, user.Admin, realname, user.ID)
+ }
+ return err
+}
+
+func (db *PostgresDB) DeleteUser(ctx context.Context, id int64) error {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ _, err := db.db.ExecContext(ctx, `DELETE FROM "User" WHERE id = $1`, id)
+ return err
+}
+
+func (db *PostgresDB) ListNetworks(ctx context.Context, userID int64) ([]Network, error) {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ rows, err := db.db.QueryContext(ctx, `
+ SELECT id, name, addr, nick, username, realname, pass, connect_commands, sasl_mechanism,
+ sasl_plain_username, sasl_plain_password, sasl_external_cert, sasl_external_key, enabled
+ FROM "Network"
+ WHERE "user" = $1`, userID)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ var networks []Network
+ for rows.Next() {
+ var net Network
+ var name, nick, username, realname, pass, connectCommands sql.NullString
+ var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
+ err := rows.Scan(&net.ID, &name, &net.Addr, &nick, &username, &realname,
+ &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword,
+ &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob, &net.Enabled)
+ if err != nil {
+ return nil, err
+ }
+ net.Name = name.String
+ net.Nick = nick.String
+ net.Username = username.String
+ net.Realname = realname.String
+ net.Pass = pass.String
+ if connectCommands.Valid {
+ net.ConnectCommands = strings.Split(connectCommands.String, "\r\n")
+ }
+ net.SASL.Mechanism = saslMechanism.String
+ net.SASL.Plain.Username = saslPlainUsername.String
+ net.SASL.Plain.Password = saslPlainPassword.String
+ networks = append(networks, net)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+
+ return networks, nil
+}
+
+func (db *PostgresDB) StoreNetwork(ctx context.Context, userID int64, network *Network) error {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ netName := toNullString(network.Name)
+ nick := toNullString(network.Nick)
+ netUsername := toNullString(network.Username)
+ realname := toNullString(network.Realname)
+ pass := toNullString(network.Pass)
+ connectCommands := toNullString(strings.Join(network.ConnectCommands, "\r\n"))
+
+ var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
+ if network.SASL.Mechanism != "" {
+ saslMechanism = toNullString(network.SASL.Mechanism)
+ switch network.SASL.Mechanism {
+ case "PLAIN":
+ saslPlainUsername = toNullString(network.SASL.Plain.Username)
+ saslPlainPassword = toNullString(network.SASL.Plain.Password)
+ network.SASL.External.CertBlob = nil
+ network.SASL.External.PrivKeyBlob = nil
+ case "EXTERNAL":
+ // keep saslPlain* nil
+ default:
+ return fmt.Errorf("suika: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
+ }
+ }
+
+ var err error
+ if network.ID == 0 {
+ err = db.db.QueryRowContext(ctx, `
+ INSERT INTO "Network" ("user", name, addr, nick, username, realname, pass, connect_commands,
+ sasl_mechanism, sasl_plain_username, sasl_plain_password, sasl_external_cert,
+ sasl_external_key, enabled)
+ VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
+ RETURNING id`,
+ userID, netName, network.Addr, nick, netUsername, realname, pass, connectCommands,
+ saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
+ network.SASL.External.PrivKeyBlob, network.Enabled).Scan(&network.ID)
+ } else {
+ _, err = db.db.ExecContext(ctx, `
+ UPDATE "Network"
+ SET name = $2, addr = $3, nick = $4, username = $5, realname = $6, pass = $7,
+ connect_commands = $8, sasl_mechanism = $9, sasl_plain_username = $10,
+ sasl_plain_password = $11, sasl_external_cert = $12, sasl_external_key = $13,
+ enabled = $14
+ WHERE id = $1`,
+ network.ID, netName, network.Addr, nick, netUsername, realname, pass, connectCommands,
+ saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
+ network.SASL.External.PrivKeyBlob, network.Enabled)
+ }
+ return err
+}
+
+func (db *PostgresDB) DeleteNetwork(ctx context.Context, id int64) error {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ _, err := db.db.ExecContext(ctx, `DELETE FROM "Network" WHERE id = $1`, id)
+ return err
+}
+
+func (db *PostgresDB) ListChannels(ctx context.Context, networkID int64) ([]Channel, error) {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ rows, err := db.db.QueryContext(ctx, `
+ SELECT id, name, key, detached, detached_internal_msgid, relay_detached, reattach_on, detach_after,
+ detach_on
+ FROM "Channel"
+ WHERE network = $1`, networkID)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ var channels []Channel
+ for rows.Next() {
+ var ch Channel
+ var key, detachedInternalMsgID sql.NullString
+ var detachAfter int64
+ if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached, &detachedInternalMsgID, &ch.RelayDetached, &ch.ReattachOn, &detachAfter, &ch.DetachOn); err != nil {
+ return nil, err
+ }
+ ch.Key = key.String
+ ch.DetachedInternalMsgID = detachedInternalMsgID.String
+ ch.DetachAfter = time.Duration(detachAfter) * time.Second
+ channels = append(channels, ch)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+
+ return channels, nil
+}
+
+func (db *PostgresDB) StoreChannel(ctx context.Context, networkID int64, ch *Channel) error {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ key := toNullString(ch.Key)
+ detachAfter := int64(math.Ceil(ch.DetachAfter.Seconds()))
+
+ var err error
+ if ch.ID == 0 {
+ err = db.db.QueryRowContext(ctx, `
+ INSERT INTO "Channel" (network, name, key, detached, detached_internal_msgid, relay_detached, reattach_on,
+ detach_after, detach_on)
+ VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
+ RETURNING id`,
+ networkID, ch.Name, key, ch.Detached, toNullString(ch.DetachedInternalMsgID),
+ ch.RelayDetached, ch.ReattachOn, detachAfter, ch.DetachOn).Scan(&ch.ID)
+ } else {
+ _, err = db.db.ExecContext(ctx, `
+ UPDATE "Channel"
+ SET name = $2, key = $3, detached = $4, detached_internal_msgid = $5,
+ relay_detached = $6, reattach_on = $7, detach_after = $8, detach_on = $9
+ WHERE id = $1`,
+ ch.ID, ch.Name, key, ch.Detached, toNullString(ch.DetachedInternalMsgID),
+ ch.RelayDetached, ch.ReattachOn, detachAfter, ch.DetachOn)
+ }
+ return err
+}
+
+func (db *PostgresDB) DeleteChannel(ctx context.Context, id int64) error {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ _, err := db.db.ExecContext(ctx, `DELETE FROM "Channel" WHERE id = $1`, id)
+ return err
+}
+
+func (db *PostgresDB) ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error) {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ rows, err := db.db.QueryContext(ctx, `
+ SELECT id, target, client, internal_msgid
+ FROM "DeliveryReceipt"
+ WHERE network = $1`, networkID)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ var receipts []DeliveryReceipt
+ for rows.Next() {
+ var rcpt DeliveryReceipt
+ if err := rows.Scan(&rcpt.ID, &rcpt.Target, &rcpt.Client, &rcpt.InternalMsgID); err != nil {
+ return nil, err
+ }
+ receipts = append(receipts, rcpt)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+
+ return receipts, nil
+}
+
+func (db *PostgresDB) StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ tx, err := db.db.Begin()
+ if err != nil {
+ return err
+ }
+ defer tx.Rollback()
+
+ _, err = tx.ExecContext(ctx,
+ `DELETE FROM "DeliveryReceipt" WHERE network = $1 AND client = $2`,
+ networkID, client)
+ if err != nil {
+ return err
+ }
+
+ stmt, err := tx.PrepareContext(ctx, `
+ INSERT INTO "DeliveryReceipt" (network, target, client, internal_msgid)
+ VALUES ($1, $2, $3, $4)
+ RETURNING id`)
+ if err != nil {
+ return err
+ }
+ defer stmt.Close()
+
+ for i := range receipts {
+ rcpt := &receipts[i]
+ err := stmt.
+ QueryRowContext(ctx, networkID, rcpt.Target, client, rcpt.InternalMsgID).
+ Scan(&rcpt.ID)
+ if err != nil {
+ return err
+ }
+ }
+
+ return tx.Commit()
+}
+
+func (db *PostgresDB) GetReadReceipt(ctx context.Context, networkID int64, name string) (*ReadReceipt, error) {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ receipt := &ReadReceipt{
+ Target: name,
+ }
+
+ row := db.db.QueryRowContext(ctx,
+ `SELECT id, timestamp FROM "ReadReceipt" WHERE network = $1 AND target = $2`,
+ networkID, name)
+ if err := row.Scan(&receipt.ID, &receipt.Timestamp); err != nil {
+ if err == sql.ErrNoRows {
+ return nil, nil
+ }
+ return nil, err
+ }
+ return receipt, nil
+}
+
+func (db *PostgresDB) StoreReadReceipt(ctx context.Context, networkID int64, receipt *ReadReceipt) error {
+ ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
+ defer cancel()
+
+ var err error
+ if receipt.ID != 0 {
+ _, err = db.db.ExecContext(ctx, `
+ UPDATE "ReadReceipt"
+ SET timestamp = $1
+ WHERE id = $2`,
+ receipt.Timestamp, receipt.ID)
+ } else {
+ err = db.db.QueryRowContext(ctx, `
+ INSERT INTO "ReadReceipt" (network, target, timestamp)
+ VALUES ($1, $2, $3)
+ RETURNING id`,
+ networkID, receipt.Target, receipt.Timestamp).Scan(&receipt.ID)
+ }
+ return err
+}
--- /dev/null
+package suika
+
+import (
+ "database/sql"
+ "os"
+ "testing"
+)
+
+// PostgreSQL version 0 schema. DO NOT EDIT.
+const postgresV0Schema = `
+CREATE TABLE "Config" (
+ id SMALLINT PRIMARY KEY,
+ version INTEGER NOT NULL,
+ CHECK(id = 1)
+);
+
+INSERT INTO "Config" (id, version) VALUES (1, 1);
+
+CREATE TABLE "User" (
+ id SERIAL PRIMARY KEY,
+ username VARCHAR(255) NOT NULL UNIQUE,
+ password VARCHAR(255),
+ admin BOOLEAN NOT NULL DEFAULT FALSE,
+ realname VARCHAR(255)
+);
+
+CREATE TABLE "Network" (
+ id SERIAL PRIMARY KEY,
+ name VARCHAR(255),
+ "user" INTEGER NOT NULL REFERENCES "User"(id) ON DELETE CASCADE,
+ addr VARCHAR(255) NOT NULL,
+ nick VARCHAR(255) NOT NULL,
+ username VARCHAR(255),
+ realname VARCHAR(255),
+ pass VARCHAR(255),
+ connect_commands VARCHAR(1023),
+ sasl_mechanism VARCHAR(255),
+ sasl_plain_username VARCHAR(255),
+ sasl_plain_password VARCHAR(255),
+ sasl_external_cert BYTEA DEFAULT NULL,
+ sasl_external_key BYTEA DEFAULT NULL,
+ enabled BOOLEAN NOT NULL DEFAULT TRUE,
+ UNIQUE("user", addr, nick),
+ UNIQUE("user", name)
+);
+
+CREATE TABLE "Channel" (
+ id SERIAL PRIMARY KEY,
+ network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
+ name VARCHAR(255) NOT NULL,
+ key VARCHAR(255),
+ detached BOOLEAN NOT NULL DEFAULT FALSE,
+ detached_internal_msgid VARCHAR(255),
+ relay_detached INTEGER NOT NULL DEFAULT 0,
+ reattach_on INTEGER NOT NULL DEFAULT 0,
+ detach_after INTEGER NOT NULL DEFAULT 0,
+ detach_on INTEGER NOT NULL DEFAULT 0,
+ UNIQUE(network, name)
+);
+
+CREATE TABLE "DeliveryReceipt" (
+ id SERIAL PRIMARY KEY,
+ network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
+ target VARCHAR(255) NOT NULL,
+ client VARCHAR(255) NOT NULL DEFAULT '',
+ internal_msgid VARCHAR(255) NOT NULL,
+ UNIQUE(network, target, client)
+);
+`
+
+func openTempPostgresDB(t *testing.T) *sql.DB {
+ source, ok := os.LookupEnv("SOJU_TEST_POSTGRES")
+ if !ok {
+ t.Skip("set SOJU_TEST_POSTGRES to a connection string to execute PostgreSQL tests")
+ }
+
+ db, err := sql.Open("postgres", source)
+ if err != nil {
+ t.Fatalf("failed to connect to PostgreSQL: %v", err)
+ }
+
+ // Store all tables in a temporary schema which will be dropped when the
+ // connection to PostgreSQL is closed.
+ db.SetMaxOpenConns(1)
+ if _, err := db.Exec("SET search_path TO pg_temp"); err != nil {
+ t.Fatalf("failed to set PostgreSQL search_path: %v", err)
+ }
+
+ return db
+}
+
+func TestPostgresMigrations(t *testing.T) {
+ sqlDB := openTempPostgresDB(t)
+ if _, err := sqlDB.Exec(postgresV0Schema); err != nil {
+ t.Fatalf("DB.Exec() failed for v0 schema: %v", err)
+ }
+
+ db := &PostgresDB{db: sqlDB}
+ defer db.Close()
+
+ if err := db.upgrade(); err != nil {
+ t.Fatalf("PostgresDB.Upgrade() failed: %v", err)
+ }
+}
--- /dev/null
+package suika
+
+import (
+ "context"
+ "database/sql"
+ _ "embed"
+ "fmt"
+ "math"
+ "strings"
+ "sync"
+ "time"
+
+ _ "modernc.org/sqlite"
+)
+
+const sqliteQueryTimeout = 5 * time.Second
+
+//go:embed suika_sqlite_schema.sql
+var sqliteSchema string
+
+var sqliteMigrations = []string{
+ "", // migration #0 is reserved for schema initialization
+ "ALTER TABLE Network ADD COLUMN connect_commands VARCHAR(1023)",
+ "ALTER TABLE Channel ADD COLUMN detached INTEGER NOT NULL DEFAULT 0",
+ "ALTER TABLE Network ADD COLUMN sasl_external_cert BLOB DEFAULT NULL",
+ "ALTER TABLE Network ADD COLUMN sasl_external_key BLOB DEFAULT NULL",
+ "ALTER TABLE User ADD COLUMN admin INTEGER NOT NULL DEFAULT 0",
+ `
+ CREATE TABLE IF NOT EXISTS UserNew (
+ id INTEGER PRIMARY KEY,
+ username VARCHAR(255) NOT NULL UNIQUE,
+ password VARCHAR(255),
+ admin INTEGER NOT NULL DEFAULT 0
+ );
+ INSERT INTO UserNew SELECT rowid, username, password, admin FROM User;
+ DROP TABLE User;
+ ALTER TABLE UserNew RENAME TO User;
+ `,
+ `
+ CREATE TABLE IF NOT EXISTS NetworkNew (
+ id INTEGER PRIMARY KEY,
+ name VARCHAR(255),
+ user INTEGER NOT NULL,
+ addr VARCHAR(255) NOT NULL,
+ nick VARCHAR(255) NOT NULL,
+ username VARCHAR(255),
+ realname VARCHAR(255),
+ pass VARCHAR(255),
+ connect_commands VARCHAR(1023),
+ sasl_mechanism VARCHAR(255),
+ sasl_plain_username VARCHAR(255),
+ sasl_plain_password VARCHAR(255),
+ sasl_external_cert BLOB DEFAULT NULL,
+ sasl_external_key BLOB DEFAULT NULL,
+ FOREIGN KEY(user) REFERENCES User(id),
+ UNIQUE(user, addr, nick),
+ UNIQUE(user, name)
+ );
+ INSERT INTO NetworkNew
+ SELECT Network.id, name, User.id as user, addr, nick,
+ Network.username, realname, pass, connect_commands,
+ sasl_mechanism, sasl_plain_username, sasl_plain_password,
+ sasl_external_cert, sasl_external_key
+ FROM Network
+ JOIN User ON Network.user = User.username;
+ DROP TABLE Network;
+ ALTER TABLE NetworkNew RENAME TO Network;
+ `,
+ `
+ ALTER TABLE Channel ADD COLUMN relay_detached INTEGER NOT NULL DEFAULT 0;
+ ALTER TABLE Channel ADD COLUMN reattach_on INTEGER NOT NULL DEFAULT 0;
+ ALTER TABLE Channel ADD COLUMN detach_after INTEGER NOT NULL DEFAULT 0;
+ ALTER TABLE Channel ADD COLUMN detach_on INTEGER NOT NULL DEFAULT 0;
+ `,
+ `
+ CREATE TABLE IF NOT EXISTS DeliveryReceipt (
+ id INTEGER PRIMARY KEY,
+ network INTEGER NOT NULL,
+ target VARCHAR(255) NOT NULL,
+ client VARCHAR(255),
+ internal_msgid VARCHAR(255) NOT NULL,
+ FOREIGN KEY(network) REFERENCES Network(id),
+ UNIQUE(network, target, client)
+ );
+ `,
+ "ALTER TABLE Channel ADD COLUMN detached_internal_msgid VARCHAR(255)",
+ "ALTER TABLE Network ADD COLUMN enabled INTEGER NOT NULL DEFAULT 1",
+ "ALTER TABLE User ADD COLUMN realname VARCHAR(255)",
+ `
+ CREATE TABLE IF NOT EXISTS NetworkNew (
+ id INTEGER PRIMARY KEY,
+ name TEXT,
+ user INTEGER NOT NULL,
+ addr TEXT NOT NULL,
+ nick TEXT,
+ username TEXT,
+ realname TEXT,
+ pass TEXT,
+ connect_commands TEXT,
+ sasl_mechanism TEXT,
+ sasl_plain_username TEXT,
+ sasl_plain_password TEXT,
+ sasl_external_cert BLOB,
+ sasl_external_key BLOB,
+ enabled INTEGER NOT NULL DEFAULT 1,
+ FOREIGN KEY(user) REFERENCES User(id),
+ UNIQUE(user, addr, nick),
+ UNIQUE(user, name)
+ );
+ INSERT INTO NetworkNew
+ SELECT id, name, user, addr, nick, username, realname, pass,
+ connect_commands, sasl_mechanism, sasl_plain_username,
+ sasl_plain_password, sasl_external_cert, sasl_external_key,
+ enabled
+ FROM Network;
+ DROP TABLE Network;
+ ALTER TABLE NetworkNew RENAME TO Network;
+ `,
+ `
+ CREATE TABLE IF NOT EXISTS ReadReceipt (
+ id INTEGER PRIMARY KEY,
+ network INTEGER NOT NULL,
+ target TEXT NOT NULL,
+ timestamp TEXT NOT NULL,
+ FOREIGN KEY(network) REFERENCES Network(id),
+ UNIQUE(network, target)
+ );
+ `,
+}
+
+type SqliteDB struct {
+ lock sync.RWMutex
+ db *sql.DB
+}
+
+func OpenSqliteDB(source string) (Database, error) {
+ sqlSqliteDB, err := sql.Open("sqlite", source)
+ if err != nil {
+ return nil, err
+ }
+
+ db := &SqliteDB{db: sqlSqliteDB}
+ if err := db.upgrade(); err != nil {
+ sqlSqliteDB.Close()
+ return nil, err
+ }
+
+ return db, nil
+}
+
+func (db *SqliteDB) Close() error {
+ db.lock.Lock()
+ defer db.lock.Unlock()
+ return db.db.Close()
+}
+
+func (db *SqliteDB) upgrade() error {
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ var version int
+ if err := db.db.QueryRow("PRAGMA user_version").Scan(&version); err != nil {
+ return fmt.Errorf("failed to query schema version: %v", err)
+ }
+
+ if version == len(sqliteMigrations) {
+ return nil
+ } else if version > len(sqliteMigrations) {
+ return fmt.Errorf("suika (version %d) older than schema (version %d)", len(sqliteMigrations), version)
+ }
+
+ tx, err := db.db.Begin()
+ if err != nil {
+ return err
+ }
+ defer tx.Rollback()
+
+ if version == 0 {
+ if _, err := tx.Exec(sqliteSchema); err != nil {
+ return fmt.Errorf("failed to initialize schema: %v", err)
+ }
+ } else {
+ for i := version; i < len(sqliteMigrations); i++ {
+ if _, err := tx.Exec(sqliteMigrations[i]); err != nil {
+ return fmt.Errorf("failed to execute migration #%v: %v", i, err)
+ }
+ }
+ }
+
+ // For some reason prepared statements don't work here
+ _, err = tx.Exec(fmt.Sprintf("PRAGMA user_version = %d", len(sqliteMigrations)))
+ if err != nil {
+ return fmt.Errorf("failed to bump schema version: %v", err)
+ }
+
+ return tx.Commit()
+}
+
+func (db *SqliteDB) Stats(ctx context.Context) (*DatabaseStats, error) {
+ db.lock.RLock()
+ defer db.lock.RUnlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ var stats DatabaseStats
+ row := db.db.QueryRowContext(ctx, `SELECT
+ (SELECT COUNT(*) FROM User) AS users,
+ (SELECT COUNT(*) FROM Network) AS networks,
+ (SELECT COUNT(*) FROM Channel) AS channels`)
+ if err := row.Scan(&stats.Users, &stats.Networks, &stats.Channels); err != nil {
+ return nil, err
+ }
+
+ return &stats, nil
+}
+
+func toNullString(s string) sql.NullString {
+ return sql.NullString{
+ String: s,
+ Valid: s != "",
+ }
+}
+
+func (db *SqliteDB) ListUsers(ctx context.Context) ([]User, error) {
+ db.lock.RLock()
+ defer db.lock.RUnlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ rows, err := db.db.QueryContext(ctx,
+ "SELECT id, username, password, admin, realname FROM User")
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ var users []User
+ for rows.Next() {
+ var user User
+ var password, realname sql.NullString
+ if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &realname); err != nil {
+ return nil, err
+ }
+ user.Password = password.String
+ user.Realname = realname.String
+ users = append(users, user)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+
+ return users, nil
+}
+
+func (db *SqliteDB) GetUser(ctx context.Context, username string) (*User, error) {
+ db.lock.RLock()
+ defer db.lock.RUnlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ user := &User{Username: username}
+
+ var password, realname sql.NullString
+ row := db.db.QueryRowContext(ctx,
+ "SELECT id, password, admin, realname FROM User WHERE username = ?",
+ username)
+ if err := row.Scan(&user.ID, &password, &user.Admin, &realname); err != nil {
+ return nil, err
+ }
+ user.Password = password.String
+ user.Realname = realname.String
+ return user, nil
+}
+
+func (db *SqliteDB) StoreUser(ctx context.Context, user *User) error {
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ args := []interface{}{
+ sql.Named("username", user.Username),
+ sql.Named("password", toNullString(user.Password)),
+ sql.Named("admin", user.Admin),
+ sql.Named("realname", toNullString(user.Realname)),
+ }
+
+ var err error
+ if user.ID != 0 {
+ _, err = db.db.ExecContext(ctx, `
+ UPDATE User SET password = :password, admin = :admin,
+ realname = :realname WHERE username = :username`,
+ args...)
+ } else {
+ var res sql.Result
+ res, err = db.db.ExecContext(ctx, `
+ INSERT INTO
+ User(username, password, admin, realname)
+ VALUES (:username, :password, :admin, :realname)`,
+ args...)
+ if err != nil {
+ return err
+ }
+ user.ID, err = res.LastInsertId()
+ }
+
+ return err
+}
+
+func (db *SqliteDB) DeleteUser(ctx context.Context, id int64) error {
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ tx, err := db.db.Begin()
+ if err != nil {
+ return err
+ }
+ defer tx.Rollback()
+
+ _, err = tx.ExecContext(ctx, `DELETE FROM DeliveryReceipt
+ WHERE id IN (
+ SELECT DeliveryReceipt.id
+ FROM DeliveryReceipt
+ JOIN Network ON DeliveryReceipt.network = Network.id
+ WHERE Network.user = ?
+ )`, id)
+ if err != nil {
+ return err
+ }
+
+ _, err = tx.ExecContext(ctx, `DELETE FROM ReadReceipt
+ WHERE id IN (
+ SELECT ReadReceipt.id
+ FROM ReadReceipt
+ JOIN Network ON ReadReceipt.network = Network.id
+ WHERE Network.user = ?
+ )`, id)
+ if err != nil {
+ return err
+ }
+
+ _, err = tx.ExecContext(ctx, `DELETE FROM Channel
+ WHERE id IN (
+ SELECT Channel.id
+ FROM Channel
+ JOIN Network ON Channel.network = Network.id
+ WHERE Network.user = ?
+ )`, id)
+ if err != nil {
+ return err
+ }
+
+ _, err = tx.ExecContext(ctx, "DELETE FROM Network WHERE user = ?", id)
+ if err != nil {
+ return err
+ }
+
+ _, err = tx.ExecContext(ctx, "DELETE FROM User WHERE id = ?", id)
+ if err != nil {
+ return err
+ }
+
+ return tx.Commit()
+}
+
+func (db *SqliteDB) ListNetworks(ctx context.Context, userID int64) ([]Network, error) {
+ db.lock.RLock()
+ defer db.lock.RUnlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ rows, err := db.db.QueryContext(ctx, `
+ SELECT id, name, addr, nick, username, realname, pass,
+ connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password,
+ sasl_external_cert, sasl_external_key, enabled
+ FROM Network
+ WHERE user = ?`,
+ userID)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ var networks []Network
+ for rows.Next() {
+ var net Network
+ var name, nick, username, realname, pass, connectCommands sql.NullString
+ var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
+ err := rows.Scan(&net.ID, &name, &net.Addr, &nick, &username, &realname,
+ &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword,
+ &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob, &net.Enabled)
+ if err != nil {
+ return nil, err
+ }
+ net.Name = name.String
+ net.Nick = nick.String
+ net.Username = username.String
+ net.Realname = realname.String
+ net.Pass = pass.String
+ if connectCommands.Valid {
+ net.ConnectCommands = strings.Split(connectCommands.String, "\r\n")
+ }
+ net.SASL.Mechanism = saslMechanism.String
+ net.SASL.Plain.Username = saslPlainUsername.String
+ net.SASL.Plain.Password = saslPlainPassword.String
+ networks = append(networks, net)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+
+ return networks, nil
+}
+
+func (db *SqliteDB) StoreNetwork(ctx context.Context, userID int64, network *Network) error {
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
+ if network.SASL.Mechanism != "" {
+ saslMechanism = toNullString(network.SASL.Mechanism)
+ switch network.SASL.Mechanism {
+ case "PLAIN":
+ saslPlainUsername = toNullString(network.SASL.Plain.Username)
+ saslPlainPassword = toNullString(network.SASL.Plain.Password)
+ network.SASL.External.CertBlob = nil
+ network.SASL.External.PrivKeyBlob = nil
+ case "EXTERNAL":
+ // keep saslPlain* nil
+ default:
+ return fmt.Errorf("suika: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
+ }
+ }
+
+ args := []interface{}{
+ sql.Named("name", toNullString(network.Name)),
+ sql.Named("addr", network.Addr),
+ sql.Named("nick", toNullString(network.Nick)),
+ sql.Named("username", toNullString(network.Username)),
+ sql.Named("realname", toNullString(network.Realname)),
+ sql.Named("pass", toNullString(network.Pass)),
+ sql.Named("connect_commands", toNullString(strings.Join(network.ConnectCommands, "\r\n"))),
+ sql.Named("sasl_mechanism", saslMechanism),
+ sql.Named("sasl_plain_username", saslPlainUsername),
+ sql.Named("sasl_plain_password", saslPlainPassword),
+ sql.Named("sasl_external_cert", network.SASL.External.CertBlob),
+ sql.Named("sasl_external_key", network.SASL.External.PrivKeyBlob),
+ sql.Named("enabled", network.Enabled),
+
+ sql.Named("id", network.ID), // only for UPDATE
+ sql.Named("user", userID), // only for INSERT
+ }
+
+ var err error
+ if network.ID != 0 {
+ _, err = db.db.ExecContext(ctx, `
+ UPDATE Network
+ SET name = :name, addr = :addr, nick = :nick, username = :username,
+ realname = :realname, pass = :pass, connect_commands = :connect_commands,
+ sasl_mechanism = :sasl_mechanism, sasl_plain_username = :sasl_plain_username, sasl_plain_password = :sasl_plain_password,
+ sasl_external_cert = :sasl_external_cert, sasl_external_key = :sasl_external_key,
+ enabled = :enabled
+ WHERE id = :id`, args...)
+ } else {
+ var res sql.Result
+ res, err = db.db.ExecContext(ctx, `
+ INSERT INTO Network(user, name, addr, nick, username, realname, pass,
+ connect_commands, sasl_mechanism, sasl_plain_username,
+ sasl_plain_password, sasl_external_cert, sasl_external_key, enabled)
+ VALUES (:user, :name, :addr, :nick, :username, :realname, :pass,
+ :connect_commands, :sasl_mechanism, :sasl_plain_username,
+ :sasl_plain_password, :sasl_external_cert, :sasl_external_key, :enabled)`,
+ args...)
+ if err != nil {
+ return err
+ }
+ network.ID, err = res.LastInsertId()
+ }
+ return err
+}
+
+func (db *SqliteDB) DeleteNetwork(ctx context.Context, id int64) error {
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ tx, err := db.db.Begin()
+ if err != nil {
+ return err
+ }
+ defer tx.Rollback()
+
+ _, err = tx.ExecContext(ctx, "DELETE FROM DeliveryReceipt WHERE network = ?", id)
+ if err != nil {
+ return err
+ }
+
+ _, err = tx.ExecContext(ctx, "DELETE FROM ReadReceipt WHERE network = ?", id)
+ if err != nil {
+ return err
+ }
+
+ _, err = tx.ExecContext(ctx, "DELETE FROM Channel WHERE network = ?", id)
+ if err != nil {
+ return err
+ }
+
+ _, err = tx.ExecContext(ctx, "DELETE FROM Network WHERE id = ?", id)
+ if err != nil {
+ return err
+ }
+
+ return tx.Commit()
+}
+
+func (db *SqliteDB) ListChannels(ctx context.Context, networkID int64) ([]Channel, error) {
+ db.lock.RLock()
+ defer db.lock.RUnlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ rows, err := db.db.QueryContext(ctx, `SELECT
+ id, name, key, detached, detached_internal_msgid,
+ relay_detached, reattach_on, detach_after, detach_on
+ FROM Channel
+ WHERE network = ?`, networkID)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ var channels []Channel
+ for rows.Next() {
+ var ch Channel
+ var key, detachedInternalMsgID sql.NullString
+ var detachAfter int64
+ if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached, &detachedInternalMsgID, &ch.RelayDetached, &ch.ReattachOn, &detachAfter, &ch.DetachOn); err != nil {
+ return nil, err
+ }
+ ch.Key = key.String
+ ch.DetachedInternalMsgID = detachedInternalMsgID.String
+ ch.DetachAfter = time.Duration(detachAfter) * time.Second
+ channels = append(channels, ch)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+
+ return channels, nil
+}
+
+func (db *SqliteDB) StoreChannel(ctx context.Context, networkID int64, ch *Channel) error {
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ args := []interface{}{
+ sql.Named("network", networkID),
+ sql.Named("name", ch.Name),
+ sql.Named("key", toNullString(ch.Key)),
+ sql.Named("detached", ch.Detached),
+ sql.Named("detached_internal_msgid", toNullString(ch.DetachedInternalMsgID)),
+ sql.Named("relay_detached", ch.RelayDetached),
+ sql.Named("reattach_on", ch.ReattachOn),
+ sql.Named("detach_after", int64(math.Ceil(ch.DetachAfter.Seconds()))),
+ sql.Named("detach_on", ch.DetachOn),
+
+ sql.Named("id", ch.ID), // only for UPDATE
+ }
+
+ var err error
+ if ch.ID != 0 {
+ _, err = db.db.ExecContext(ctx, `UPDATE Channel
+ SET network = :network, name = :name, key = :key, detached = :detached,
+ detached_internal_msgid = :detached_internal_msgid, relay_detached = :relay_detached,
+ reattach_on = :reattach_on, detach_after = :detach_after, detach_on = :detach_on
+ WHERE id = :id`, args...)
+ } else {
+ var res sql.Result
+ res, err = db.db.ExecContext(ctx, `INSERT INTO Channel(network, name, key, detached, detached_internal_msgid, relay_detached, reattach_on, detach_after, detach_on)
+ VALUES (:network, :name, :key, :detached, :detached_internal_msgid, :relay_detached, :reattach_on, :detach_after, :detach_on)`, args...)
+ if err != nil {
+ return err
+ }
+ ch.ID, err = res.LastInsertId()
+ }
+ return err
+}
+
+func (db *SqliteDB) DeleteChannel(ctx context.Context, id int64) error {
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ _, err := db.db.ExecContext(ctx, "DELETE FROM Channel WHERE id = ?", id)
+ return err
+}
+
+func (db *SqliteDB) ListDeliveryReceipts(ctx context.Context, networkID int64) ([]DeliveryReceipt, error) {
+ db.lock.RLock()
+ defer db.lock.RUnlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ rows, err := db.db.QueryContext(ctx, `
+ SELECT id, target, client, internal_msgid
+ FROM DeliveryReceipt
+ WHERE network = ?`, networkID)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ var receipts []DeliveryReceipt
+ for rows.Next() {
+ var rcpt DeliveryReceipt
+ var client sql.NullString
+ if err := rows.Scan(&rcpt.ID, &rcpt.Target, &client, &rcpt.InternalMsgID); err != nil {
+ return nil, err
+ }
+ rcpt.Client = client.String
+ receipts = append(receipts, rcpt)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+
+ return receipts, nil
+}
+
+func (db *SqliteDB) StoreClientDeliveryReceipts(ctx context.Context, networkID int64, client string, receipts []DeliveryReceipt) error {
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ tx, err := db.db.Begin()
+ if err != nil {
+ return err
+ }
+ defer tx.Rollback()
+
+ _, err = tx.ExecContext(ctx, "DELETE FROM DeliveryReceipt WHERE network = ? AND client IS ?",
+ networkID, toNullString(client))
+ if err != nil {
+ return err
+ }
+
+ for i := range receipts {
+ rcpt := &receipts[i]
+
+ res, err := tx.ExecContext(ctx, `
+ INSERT INTO DeliveryReceipt(network, target, client, internal_msgid)
+ VALUES (:network, :target, :client, :internal_msgid)`,
+ sql.Named("network", networkID),
+ sql.Named("target", rcpt.Target),
+ sql.Named("client", toNullString(client)),
+ sql.Named("internal_msgid", rcpt.InternalMsgID))
+ if err != nil {
+ return err
+ }
+ rcpt.ID, err = res.LastInsertId()
+ if err != nil {
+ return err
+ }
+ }
+
+ return tx.Commit()
+}
+
+func (db *SqliteDB) GetReadReceipt(ctx context.Context, networkID int64, name string) (*ReadReceipt, error) {
+ db.lock.RLock()
+ defer db.lock.RUnlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ receipt := &ReadReceipt{
+ Target: name,
+ }
+
+ row := db.db.QueryRowContext(ctx, `
+ SELECT id, timestamp FROM ReadReceipt WHERE network = :network AND target = :target`,
+ sql.Named("network", networkID),
+ sql.Named("target", name),
+ )
+ var timestamp string
+ if err := row.Scan(&receipt.ID, ×tamp); err != nil {
+ if err == sql.ErrNoRows {
+ return nil, nil
+ }
+ return nil, err
+ }
+ if t, err := time.Parse(serverTimeLayout, timestamp); err != nil {
+ return nil, err
+ } else {
+ receipt.Timestamp = t
+ }
+ return receipt, nil
+}
+
+func (db *SqliteDB) StoreReadReceipt(ctx context.Context, networkID int64, receipt *ReadReceipt) error {
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
+ defer cancel()
+
+ args := []interface{}{
+ sql.Named("id", receipt.ID),
+ sql.Named("timestamp", formatServerTime(receipt.Timestamp)),
+ sql.Named("network", networkID),
+ sql.Named("target", receipt.Target),
+ }
+
+ var err error
+ if receipt.ID != 0 {
+ _, err = db.db.ExecContext(ctx, `
+ UPDATE ReadReceipt SET timestamp = :timestamp WHERE id = :id`,
+ args...)
+ } else {
+ var res sql.Result
+ res, err = db.db.ExecContext(ctx, `
+ INSERT INTO
+ ReadReceipt(network, target, timestamp)
+ VALUES (:network, :target, :timestamp)`,
+ args...)
+ if err != nil {
+ return err
+ }
+ receipt.ID, err = res.LastInsertId()
+ }
+
+ return err
+}
--- /dev/null
+package suika
+
+import (
+ "database/sql"
+ "testing"
+)
+
+// SQLite version 0 schema. DO NOT EDIT.
+const sqliteV0Schema = `
+CREATE TABLE User (
+ username VARCHAR(255) NOT NULL UNIQUE,
+ password VARCHAR(255)
+);
+
+CREATE TABLE Network (
+ id INTEGER PRIMARY KEY,
+ name VARCHAR(255),
+ user VARCHAR(255) NOT NULL,
+ addr VARCHAR(255) NOT NULL,
+ nick VARCHAR(255) NOT NULL,
+ username VARCHAR(255),
+ realname VARCHAR(255),
+ pass VARCHAR(255),
+ sasl_mechanism VARCHAR(255),
+ sasl_plain_username VARCHAR(255),
+ sasl_plain_password VARCHAR(255),
+ UNIQUE(user, addr, nick),
+ UNIQUE(user, name)
+);
+
+CREATE TABLE Channel (
+ id INTEGER PRIMARY KEY,
+ network INTEGER NOT NULL,
+ name VARCHAR(255) NOT NULL,
+ key VARCHAR(255),
+ FOREIGN KEY(network) REFERENCES Network(id),
+ UNIQUE(network, name)
+);
+
+PRAGMA user_version = 1;
+`
+
+func TestSqliteMigrations(t *testing.T) {
+ sqlDB, err := sql.Open("sqlite", ":memory:")
+ if err != nil {
+ t.Fatalf("failed to create temporary SQLite database: %v", err)
+ }
+
+ if _, err := sqlDB.Exec(sqliteV0Schema); err != nil {
+ t.Fatalf("DB.Exec() failed for v0 schema: %v", err)
+ }
+
+ db := &SqliteDB{db: sqlDB}
+ defer db.Close()
+
+ if err := db.upgrade(); err != nil {
+ t.Fatalf("SqliteDB.Upgrade() failed: %v", err)
+ }
+}
--- /dev/null
+// Package suika is a hard-fork of the 0.3 series of soju, an user-friendly IRC bouncer in Go.
+//
+// # Copyright (C) 2020 The soju Contributors
+// # Copyright (C) 2023-present Izuru Yakumo et al.
+//
+// suika is covered by the AGPLv3 license:
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published
+// by the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <https://www.gnu.org/licenses/>.
+package suika
--- /dev/null
+.Dd $Mdocdate$
+.Dt SUIKA-CONFIG 5
+.Os
+.Sh NAME
+.Nm suika-config
+.Nd Configuration file for the IRC bouncer
+.Sh SYNOPSIS
+.Bk -words
+listen ircs://
+.Pp
+tls cert.pem key.pem
+.Pp
+hostname example.org
+.Ek
+.Sh DESCRIPTION
+This document describes the format of the configuration
+file used by
+.Xr suika 1
+.Sh OPTIONS
+.Bl -tag -width Ds
+.It listen Ar uri
+With this you can control on what
+ports/protocols
+.Xr suika 1
+listens on, it supports
+irc (cleartext IRC), ircs (IRC with TLS), and unix
+(IRC over Unix domain sockets)
+.It hostname Ar hostname
+Server hostname, if unset, the system one is used.
+.It title Ar title
+Server title, this will be sent as the ISUPPORT NETWORK value when
+clients don't select a specific network.
+.It tls Ar cert Ar key
+Enable TLS support, the certificate and key files must be
+PEM-encoded.
+.It db Ar driver Ar path
+Set the database driver for user, network and channel storage.
+By default a SQLite 3 database is opened in
+.Pa ./suika.db
+Supported drivers are sqlite and postgres, the former
+expects a path to the database file, and the latter
+a space-separated list of key=value parameters,
+e.g. host=localhost dbname=suika
+.It log fs Ar path
+Path to the bouncer logs directory, or empty to disable
+logging.
+By default, logging is disabled.
+.It max-user-networks Ar limit
+Maximum number of networks per user, by default
+there is no limit.
+.It motd Ar path
+Path to the MOTD file, its contents are sent to clients
+which aren't bound to a particular network.
+By default, no MOTD is sent.
+.It multi-upstream-mode Ar bool
+Globally enable or disable multi-upstream mode.
+By default, it is enabled.
+.It upstream-user-ip Ar cidr
+Enable per-user-IP addresses.
+One IPv4 and/or one IPv6 range can be specified in CIDR notation.
+One IP address per range will be assigned to each user as the
+source address when connecting to an upstream network.
+This can be useful to avoid having the whole bouncer banned from
+an upstream network because of one malicious user.
+.El
+.Sh AUTHORS
+.An Simon Ser Aq Mt contact@emersion.fr
+.An The soju Contributors
+.Sh MAINTAINERS
+.An Izuru Yakumo Aq Mt yakumo.izuru@chaotic.ninja
--- /dev/null
+.Dd $Mdocdate$
+.Dt SUIKA-ZNC-IMPORT 1
+.Os
+.Sh NAME
+.Nm suika-znc-import
+.Nd Migration utility for moving from ZNC
+.Sh SYNOPSIS
+.Nm
+.Op Fl config Ar suika config file
+.Op Fl user Ar username
+.Op Fl network Ar name
+.Sh DESCRIPTION
+Imports configuration from a ZNC file.
+Users and networks are merged if they already exist in the
+.Xr suika 1
+database.
+ZNC settings overwrite existing
+.Xr suika 1
+settings
+.Sh OPTIONS
+.Bl -tag -width Ds
+.It config Ar suika config file
+Path to
+.Xr suika-config 5
+.It user Ar username
+Limit import to username, may be specified multiple times.
+It network Ar name
+Limit import to network, may be specified multiple times.
+.El
+.Sh AUTHORS
+.An Simon Ser Aq Mt contact@emersion.fr
+.An The soju Contributors
+.Sh MAINTAINERS
+.An Izuru Yakumo Aq Mt yakumo.izuru@chaotic.ninja
--- /dev/null
+.Dd $Mdocdate$
+.Dt SUIKA 1
+.Os
+.Sh NAME
+.Nm suika
+.Nd A drunk as hell IRC bouncer, named after Suika Ibuki from Touhou Project
+.Sh SYNOPSIS
+.Nm
+.Op Fl config Ar path
+.Op Fl debug
+.Op Fl listen Ar uri
+.Sh DESCRIPTION
+.Nm
+is an user-friendly IRC bouncer, it connects to upstream
+IRC servers on behalf of the user to provide extra features.
+.Bl -tag -width 6n
+.It Multiple separate users sharing the same bouncer
+.It Clients connecting to multiple upstream servers (via a single connection)
+.It Sending the backlog with per-client buffers
+.El
+.Pp
+When joining a channel, the channel will be saved
+and automatically joined on the next connection.
+When registering or authenticating with NickServ, the credentials will be saved
+and automatically used on the next connection if the server supports SASL.
+When parting a channel with the reason "detach", the channel will be
+detached instead of being left.
+When all clients are disconnected from the bouncer,
+the user is automatically marked as away.
+.Pp
+.Nm
+supports two connection modes:
+.Bl -tag -width 6n
+.It Single upstream mode
+One downstream connection maps to one upstream connection
+.Pp
+To enable this mode, connect to the bouncer
+with the username "<username>/<network>".
+.Pp
+If the bouncer isn't connected to the upstream server,
+it will get automatically added.
+.Pp
+Then channels can be joined and parted as if
+you were directly connected to the upstream server.
+.It Multiple upstream mode
+One downstream connection maps to multiple upstream connections.
+Channels and nicks are suffixed with the network name.
+To join a channel, you need to use the suffix too: /join #channel/network.
+Same applies to messages sent to users.
+.El
+.Pp
+For per-client history to work, clients need to indicate their name.
+This can be done by adding a "@<client>" suffix to the username.
+.Pp
+.Nm
+will reload the configuration file, the TLS certificate/key and
+the MOTD file when it receives the HUP signal.
+The configuration options listen, db and log cannot be reloaded.
+.Pp
+Administrators can broadcast a message to all bouncer users via
+/notice $<hostname> <text>, or via /notice $<text> in multi-upstream mode.
+All currently connected bouncer users will receive the message
+from the special BouncerServ service.
+.Sh AUTHORS
+.An Simon Ser Aq Mt contact@emersion.fr
+.An The soju Contributors
+.Sh MAINTAINERS
+.An Izuru Yakumo Aq Mt yakumo.izuru@chaotic.ninja
--- /dev/null
+.Dd $Mdocdate$
+.Dt SUIKADB 1
+.Os
+.Sh NAME
+.Nm suikadb
+.Nd Basic user manipulation for
+.Xr suika 1
+.Sh SYNOPSIS
+.Nm
+.Op create-user
+.Op change-password
+.Sh AUTHORS
+.An Simon Ser Aq Mt contact@emersion.fr
+.An The soju Contributors
+.Sh MAINTAINERS
+.An Izuru Yakumo Aq Mt yakumo.izuru@chaotic.ninja
--- /dev/null
+package suika
+
+import (
+ "bytes"
+ "context"
+ "crypto/tls"
+ "encoding/base64"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/emersion/go-sasl"
+ "golang.org/x/crypto/bcrypt"
+ "gopkg.in/irc.v3"
+)
+
+type ircError struct {
+ Message *irc.Message
+}
+
+func (err ircError) Error() string {
+ return err.Message.String()
+}
+
+func newUnknownCommandError(cmd string) ircError {
+ return ircError{&irc.Message{
+ Command: irc.ERR_UNKNOWNCOMMAND,
+ Params: []string{
+ "*",
+ cmd,
+ "Unknown command",
+ },
+ }}
+}
+
+func newNeedMoreParamsError(cmd string) ircError {
+ return ircError{&irc.Message{
+ Command: irc.ERR_NEEDMOREPARAMS,
+ Params: []string{
+ "*",
+ cmd,
+ "Not enough parameters",
+ },
+ }}
+}
+
+func newChatHistoryError(subcommand string, target string) ircError {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"CHATHISTORY", "MESSAGE_ERROR", subcommand, target, "Messages could not be retrieved"},
+ }}
+}
+
+// authError is an authentication error.
+type authError struct {
+ // Internal error cause. This will not be revealed to the user.
+ err error
+ // Error cause which can safely be sent to the user without compromising
+ // security.
+ reason string
+}
+
+func (err *authError) Error() string {
+ return err.err.Error()
+}
+
+func (err *authError) Unwrap() error {
+ return err.err
+}
+
+// authErrorReason returns the user-friendly reason of an authentication
+// failure.
+func authErrorReason(err error) string {
+ if authErr, ok := err.(*authError); ok {
+ return authErr.reason
+ } else {
+ return "Authentication failed"
+ }
+}
+
+func newInvalidUsernameOrPasswordError(err error) error {
+ return &authError{
+ err: err,
+ reason: "Invalid username or password",
+ }
+}
+
+func parseBouncerNetID(subcommand, s string) (int64, error) {
+ id, err := strconv.ParseInt(s, 10, 64)
+ if err != nil {
+ return 0, ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"BOUNCER", "INVALID_NETID", subcommand, s, "Invalid network ID"},
+ }}
+ }
+ return id, nil
+}
+
+func fillNetworkAddrAttrs(attrs irc.Tags, network *Network) {
+ u, err := network.URL()
+ if err != nil {
+ return
+ }
+
+ hasHostPort := true
+ switch u.Scheme {
+ case "ircs":
+ attrs["tls"] = irc.TagValue("1")
+ case "irc":
+ attrs["tls"] = irc.TagValue("0")
+ default: // e.g. unix://
+ hasHostPort = false
+ }
+ if host, port, err := net.SplitHostPort(u.Host); err == nil && hasHostPort {
+ attrs["host"] = irc.TagValue(host)
+ attrs["port"] = irc.TagValue(port)
+ } else if hasHostPort {
+ attrs["host"] = irc.TagValue(u.Host)
+ }
+}
+
+func getNetworkAttrs(network *network) irc.Tags {
+ state := "disconnected"
+ if uc := network.conn; uc != nil {
+ state = "connected"
+ }
+
+ attrs := irc.Tags{
+ "name": irc.TagValue(network.GetName()),
+ "state": irc.TagValue(state),
+ "nickname": irc.TagValue(GetNick(&network.user.User, &network.Network)),
+ }
+
+ if network.Username != "" {
+ attrs["username"] = irc.TagValue(network.Username)
+ }
+ if realname := GetRealname(&network.user.User, &network.Network); realname != "" {
+ attrs["realname"] = irc.TagValue(realname)
+ }
+
+ fillNetworkAddrAttrs(attrs, &network.Network)
+
+ return attrs
+}
+
+func networkAddrFromAttrs(attrs irc.Tags) string {
+ host, ok := attrs.GetTag("host")
+ if !ok {
+ return ""
+ }
+
+ addr := host
+ if port, ok := attrs.GetTag("port"); ok {
+ addr += ":" + port
+ }
+
+ if tlsStr, ok := attrs.GetTag("tls"); ok && tlsStr == "0" {
+ addr = "irc://" + tlsStr
+ }
+
+ return addr
+}
+
+func updateNetworkAttrs(record *Network, attrs irc.Tags, subcommand string) error {
+ addrAttrs := irc.Tags{}
+ fillNetworkAddrAttrs(addrAttrs, record)
+
+ updateAddr := false
+ for k, v := range attrs {
+ s := string(v)
+ switch k {
+ case "host", "port", "tls":
+ updateAddr = true
+ addrAttrs[k] = v
+ case "name":
+ record.Name = s
+ case "nickname":
+ record.Nick = s
+ case "username":
+ record.Username = s
+ case "realname":
+ record.Realname = s
+ case "pass":
+ record.Pass = s
+ default:
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"BOUNCER", "UNKNOWN_ATTRIBUTE", subcommand, k, "Unknown attribute"},
+ }}
+ }
+ }
+
+ if updateAddr {
+ record.Addr = networkAddrFromAttrs(addrAttrs)
+ if record.Addr == "" {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"BOUNCER", "NEED_ATTRIBUTE", subcommand, "host", "Missing required host attribute"},
+ }}
+ }
+ }
+
+ return nil
+}
+
+// illegalNickChars is the list of characters forbidden in a nickname.
+//
+// ' ' and ':' break the IRC message wire format
+// '@' and '!' break prefixes
+// '*' breaks masks and is the reserved nickname for registration
+// '?' breaks masks
+// '$' breaks server masks in PRIVMSG/NOTICE
+// ',' breaks lists
+// '.' is reserved for server names
+const illegalNickChars = " :@!*?$,."
+
+// permanentDownstreamCaps is the list of always-supported downstream
+// capabilities.
+var permanentDownstreamCaps = map[string]string{
+ "batch": "",
+ "cap-notify": "",
+ "echo-message": "",
+ "invite-notify": "",
+ "message-tags": "",
+ "server-time": "",
+ "setname": "",
+
+ "soju.im/bouncer-networks": "",
+ "soju.im/bouncer-networks-notify": "",
+ "soju.im/read": "",
+}
+
+// needAllDownstreamCaps is the list of downstream capabilities that
+// require support from all upstreams to be enabled
+var needAllDownstreamCaps = map[string]string{
+ "account-notify": "",
+ "account-tag": "",
+ "away-notify": "",
+ "extended-join": "",
+ "multi-prefix": "",
+
+ "draft/extended-monitor": "",
+}
+
+// passthroughIsupport is the set of ISUPPORT tokens that are directly passed
+// through from the upstream server to downstream clients.
+//
+// This is only effective in single-upstream mode.
+var passthroughIsupport = map[string]bool{
+ "AWAYLEN": true,
+ "BOT": true,
+ "CHANLIMIT": true,
+ "CHANMODES": true,
+ "CHANNELLEN": true,
+ "CHANTYPES": true,
+ "CLIENTTAGDENY": true,
+ "ELIST": true,
+ "EXCEPTS": true,
+ "EXTBAN": true,
+ "HOSTLEN": true,
+ "INVEX": true,
+ "KICKLEN": true,
+ "MAXLIST": true,
+ "MAXTARGETS": true,
+ "MODES": true,
+ "MONITOR": true,
+ "NAMELEN": true,
+ "NETWORK": true,
+ "NICKLEN": true,
+ "PREFIX": true,
+ "SAFELIST": true,
+ "TARGMAX": true,
+ "TOPICLEN": true,
+ "USERLEN": true,
+ "UTF8ONLY": true,
+ "WHOX": true,
+}
+
+type downstreamSASL struct {
+ server sasl.Server
+ plainUsername, plainPassword string
+ pendingResp bytes.Buffer
+}
+
+type downstreamConn struct {
+ conn
+
+ id uint64
+
+ registered bool
+ user *user
+ nick string
+ nickCM string
+ rawUsername string
+ networkName string
+ clientName string
+ realname string
+ hostname string
+ account string // RPL_LOGGEDIN/OUT state
+ password string // empty after authentication
+ network *network // can be nil
+ isMultiUpstream bool
+
+ negotiatingCaps bool
+ capVersion int
+ supportedCaps map[string]string
+ caps map[string]bool
+ sasl *downstreamSASL
+
+ lastBatchRef uint64
+
+ monitored casemapMap
+}
+
+func newDownstreamConn(srv *Server, ic ircConn, id uint64) *downstreamConn {
+ remoteAddr := ic.RemoteAddr().String()
+ logger := &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", remoteAddr)}
+ options := connOptions{Logger: logger}
+ dc := &downstreamConn{
+ conn: *newConn(srv, ic, &options),
+ id: id,
+ nick: "*",
+ nickCM: "*",
+ supportedCaps: make(map[string]string),
+ caps: make(map[string]bool),
+ monitored: newCasemapMap(0),
+ }
+ dc.hostname = remoteAddr
+ if host, _, err := net.SplitHostPort(dc.hostname); err == nil {
+ dc.hostname = host
+ }
+ for k, v := range permanentDownstreamCaps {
+ dc.supportedCaps[k] = v
+ }
+ dc.supportedCaps["sasl"] = "PLAIN"
+ // TODO: this is racy, we should only enable chathistory after
+ // authentication and then check that user.msgStore implements
+ // chatHistoryMessageStore
+ if srv.Config().LogPath != "" {
+ dc.supportedCaps["draft/chathistory"] = ""
+ }
+ return dc
+}
+
+func (dc *downstreamConn) prefix() *irc.Prefix {
+ return &irc.Prefix{
+ Name: dc.nick,
+ User: dc.user.Username,
+ Host: dc.hostname,
+ }
+}
+
+func (dc *downstreamConn) forEachNetwork(f func(*network)) {
+ if dc.network != nil {
+ f(dc.network)
+ } else if dc.isMultiUpstream {
+ for _, network := range dc.user.networks {
+ f(network)
+ }
+ }
+}
+
+func (dc *downstreamConn) forEachUpstream(f func(*upstreamConn)) {
+ if dc.network == nil && !dc.isMultiUpstream {
+ return
+ }
+ dc.user.forEachUpstream(func(uc *upstreamConn) {
+ if dc.network != nil && uc.network != dc.network {
+ return
+ }
+ f(uc)
+ })
+}
+
+// upstream returns the upstream connection, if any. If there are zero or if
+// there are multiple upstream connections, it returns nil.
+func (dc *downstreamConn) upstream() *upstreamConn {
+ if dc.network == nil {
+ return nil
+ }
+ return dc.network.conn
+}
+
+func isOurNick(net *network, nick string) bool {
+ // TODO: this doesn't account for nick changes
+ if net.conn != nil {
+ return net.casemap(nick) == net.conn.nickCM
+ }
+ // We're not currently connected to the upstream connection, so we don't
+ // know whether this name is our nickname. Best-effort: use the network's
+ // configured nickname and hope it was the one being used when we were
+ // connected.
+ return net.casemap(nick) == net.casemap(GetNick(&net.user.User, &net.Network))
+}
+
+// marshalEntity converts an upstream entity name (ie. channel or nick) into a
+// downstream entity name.
+//
+// This involves adding a "/<network>" suffix if the entity isn't the current
+// user.
+func (dc *downstreamConn) marshalEntity(net *network, name string) string {
+ if isOurNick(net, name) {
+ return dc.nick
+ }
+ name = partialCasemap(net.casemap, name)
+ if dc.network != nil {
+ if dc.network != net {
+ panic("suika: tried to marshal an entity for another network")
+ }
+ return name
+ }
+ return name + "/" + net.GetName()
+}
+
+func (dc *downstreamConn) marshalUserPrefix(net *network, prefix *irc.Prefix) *irc.Prefix {
+ if isOurNick(net, prefix.Name) {
+ return dc.prefix()
+ }
+ prefix.Name = partialCasemap(net.casemap, prefix.Name)
+ if dc.network != nil {
+ if dc.network != net {
+ panic("suika: tried to marshal a user prefix for another network")
+ }
+ return prefix
+ }
+ return &irc.Prefix{
+ Name: prefix.Name + "/" + net.GetName(),
+ User: prefix.User,
+ Host: prefix.Host,
+ }
+}
+
+// unmarshalEntityNetwork converts a downstream entity name (ie. channel or
+// nick) into an upstream entity name.
+//
+// This involves removing the "/<network>" suffix.
+func (dc *downstreamConn) unmarshalEntityNetwork(name string) (*network, string, error) {
+ if dc.network != nil {
+ return dc.network, name, nil
+ }
+ if !dc.isMultiUpstream {
+ return nil, "", ircError{&irc.Message{
+ Command: irc.ERR_NOSUCHCHANNEL,
+ Params: []string{dc.nick, name, "Cannot interact with channels and users on the bouncer connection. Did you mean to use a specific network?"},
+ }}
+ }
+
+ var net *network
+ if i := strings.LastIndexByte(name, '/'); i >= 0 {
+ network := name[i+1:]
+ name = name[:i]
+
+ for _, n := range dc.user.networks {
+ if network == n.GetName() {
+ net = n
+ break
+ }
+ }
+ }
+
+ if net == nil {
+ return nil, "", ircError{&irc.Message{
+ Command: irc.ERR_NOSUCHCHANNEL,
+ Params: []string{dc.nick, name, "Missing network suffix in name"},
+ }}
+ }
+
+ return net, name, nil
+}
+
+// unmarshalEntity is the same as unmarshalEntityNetwork, but returns the
+// upstream connection and fails if the upstream is disconnected.
+func (dc *downstreamConn) unmarshalEntity(name string) (*upstreamConn, string, error) {
+ net, name, err := dc.unmarshalEntityNetwork(name)
+ if err != nil {
+ return nil, "", err
+ }
+
+ if net.conn == nil {
+ return nil, "", ircError{&irc.Message{
+ Command: irc.ERR_NOSUCHCHANNEL,
+ Params: []string{dc.nick, name, "Disconnected from upstream network"},
+ }}
+ }
+
+ return net.conn, name, nil
+}
+
+func (dc *downstreamConn) unmarshalText(uc *upstreamConn, text string) string {
+ if dc.upstream() != nil {
+ return text
+ }
+ // TODO: smarter parsing that ignores URLs
+ return strings.ReplaceAll(text, "/"+uc.network.GetName(), "")
+}
+
+func (dc *downstreamConn) ReadMessage() (*irc.Message, error) {
+ msg, err := dc.conn.ReadMessage()
+ if err != nil {
+ return nil, err
+ }
+ return msg, nil
+}
+
+func (dc *downstreamConn) readMessages(ch chan<- event) error {
+ for {
+ msg, err := dc.ReadMessage()
+ if errors.Is(err, io.EOF) {
+ break
+ } else if err != nil {
+ return fmt.Errorf("failed to read IRC command: %v", err)
+ }
+
+ ch <- eventDownstreamMessage{msg, dc}
+ }
+
+ return nil
+}
+
+// SendMessage sends an outgoing message.
+//
+// This can only called from the user goroutine.
+func (dc *downstreamConn) SendMessage(msg *irc.Message) {
+ if !dc.caps["message-tags"] {
+ if msg.Command == "TAGMSG" {
+ return
+ }
+ msg = msg.Copy()
+ for name := range msg.Tags {
+ supported := false
+ switch name {
+ case "time":
+ supported = dc.caps["server-time"]
+ case "account":
+ supported = dc.caps["account"]
+ }
+ if !supported {
+ delete(msg.Tags, name)
+ }
+ }
+ }
+ if !dc.caps["batch"] && msg.Tags["batch"] != "" {
+ msg = msg.Copy()
+ delete(msg.Tags, "batch")
+ }
+ if msg.Command == "JOIN" && !dc.caps["extended-join"] {
+ msg.Params = msg.Params[:1]
+ }
+ if msg.Command == "SETNAME" && !dc.caps["setname"] {
+ return
+ }
+ if msg.Command == "AWAY" && !dc.caps["away-notify"] {
+ return
+ }
+ if msg.Command == "ACCOUNT" && !dc.caps["account-notify"] {
+ return
+ }
+ if msg.Command == "READ" && !dc.caps["soju.im/read"] {
+ return
+ }
+
+ dc.conn.SendMessage(context.TODO(), msg)
+}
+
+func (dc *downstreamConn) SendBatch(typ string, params []string, tags irc.Tags, f func(batchRef irc.TagValue)) {
+ dc.lastBatchRef++
+ ref := fmt.Sprintf("%v", dc.lastBatchRef)
+
+ if dc.caps["batch"] {
+ dc.SendMessage(&irc.Message{
+ Tags: tags,
+ Prefix: dc.srv.prefix(),
+ Command: "BATCH",
+ Params: append([]string{"+" + ref, typ}, params...),
+ })
+ }
+
+ f(irc.TagValue(ref))
+
+ if dc.caps["batch"] {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "BATCH",
+ Params: []string{"-" + ref},
+ })
+ }
+}
+
+// sendMessageWithID sends an outgoing message with the specified internal ID.
+func (dc *downstreamConn) sendMessageWithID(msg *irc.Message, id string) {
+ dc.SendMessage(msg)
+
+ if id == "" || !dc.messageSupportsBacklog(msg) {
+ return
+ }
+
+ dc.sendPing(id)
+}
+
+// advanceMessageWithID advances history to the specified message ID without
+// sending a message. This is useful e.g. for self-messages when echo-message
+// isn't enabled.
+func (dc *downstreamConn) advanceMessageWithID(msg *irc.Message, id string) {
+ if id == "" || !dc.messageSupportsBacklog(msg) {
+ return
+ }
+
+ dc.sendPing(id)
+}
+
+// ackMsgID acknowledges that a message has been received.
+func (dc *downstreamConn) ackMsgID(id string) {
+ netID, entity, err := parseMsgID(id, nil)
+ if err != nil {
+ dc.logger.Printf("failed to ACK message ID %q: %v", id, err)
+ return
+ }
+
+ network := dc.user.getNetworkByID(netID)
+ if network == nil {
+ return
+ }
+
+ network.delivered.StoreID(entity, dc.clientName, id)
+}
+
+func (dc *downstreamConn) sendPing(msgID string) {
+ token := "suika-msgid-" + msgID
+ dc.SendMessage(&irc.Message{
+ Command: "PING",
+ Params: []string{token},
+ })
+}
+
+func (dc *downstreamConn) handlePong(token string) {
+ if !strings.HasPrefix(token, "suika-msgid-") {
+ dc.logger.Printf("received unrecognized PONG token %q", token)
+ return
+ }
+ msgID := strings.TrimPrefix(token, "suika-msgid-")
+ dc.ackMsgID(msgID)
+}
+
+// marshalMessage re-formats a message coming from an upstream connection so
+// that it's suitable for being sent on this downstream connection. Only
+// messages that may appear in logs are supported, except MODE messages which
+// may only appear in single-upstream mode.
+func (dc *downstreamConn) marshalMessage(msg *irc.Message, net *network) *irc.Message {
+ msg = msg.Copy()
+ msg.Prefix = dc.marshalUserPrefix(net, msg.Prefix)
+
+ if dc.network != nil {
+ return msg
+ }
+
+ switch msg.Command {
+ case "PRIVMSG", "NOTICE", "TAGMSG":
+ msg.Params[0] = dc.marshalEntity(net, msg.Params[0])
+ case "NICK":
+ // Nick change for another user
+ msg.Params[0] = dc.marshalEntity(net, msg.Params[0])
+ case "JOIN", "PART":
+ msg.Params[0] = dc.marshalEntity(net, msg.Params[0])
+ case "KICK":
+ msg.Params[0] = dc.marshalEntity(net, msg.Params[0])
+ msg.Params[1] = dc.marshalEntity(net, msg.Params[1])
+ case "TOPIC":
+ msg.Params[0] = dc.marshalEntity(net, msg.Params[0])
+ case "QUIT", "SETNAME":
+ // This space is intentionally left blank
+ default:
+ panic(fmt.Sprintf("unexpected %q message", msg.Command))
+ }
+
+ return msg
+}
+
+func (dc *downstreamConn) handleMessage(ctx context.Context, msg *irc.Message) error {
+ ctx, cancel := dc.conn.NewContext(ctx)
+ defer cancel()
+
+ ctx, cancel = context.WithTimeout(ctx, handleDownstreamMessageTimeout)
+ defer cancel()
+
+ switch msg.Command {
+ case "QUIT":
+ return dc.Close()
+ default:
+ if dc.registered {
+ return dc.handleMessageRegistered(ctx, msg)
+ } else {
+ return dc.handleMessageUnregistered(ctx, msg)
+ }
+ }
+}
+
+func (dc *downstreamConn) handleMessageUnregistered(ctx context.Context, msg *irc.Message) error {
+ switch msg.Command {
+ case "NICK":
+ var nick string
+ if err := parseMessageParams(msg, &nick); err != nil {
+ return err
+ }
+ if nick == "" || strings.ContainsAny(nick, illegalNickChars) {
+ return ircError{&irc.Message{
+ Command: irc.ERR_ERRONEUSNICKNAME,
+ Params: []string{dc.nick, nick, "contains illegal characters"},
+ }}
+ }
+ nickCM := casemapASCII(nick)
+ if nickCM == serviceNickCM {
+ return ircError{&irc.Message{
+ Command: irc.ERR_NICKNAMEINUSE,
+ Params: []string{dc.nick, nick, "Nickname reserved for bouncer service"},
+ }}
+ }
+ dc.nick = nick
+ dc.nickCM = nickCM
+ case "USER":
+ if err := parseMessageParams(msg, &dc.rawUsername, nil, nil, &dc.realname); err != nil {
+ return err
+ }
+ case "PASS":
+ if err := parseMessageParams(msg, &dc.password); err != nil {
+ return err
+ }
+ case "CAP":
+ var subCmd string
+ if err := parseMessageParams(msg, &subCmd); err != nil {
+ return err
+ }
+ if err := dc.handleCapCommand(subCmd, msg.Params[1:]); err != nil {
+ return err
+ }
+ case "AUTHENTICATE":
+ credentials, err := dc.handleAuthenticateCommand(msg)
+ if err != nil {
+ return err
+ } else if credentials == nil {
+ break
+ }
+
+ if err := dc.authenticate(ctx, credentials.plainUsername, credentials.plainPassword); err != nil {
+ dc.logger.Printf("SASL authentication error for user %q: %v", credentials.plainUsername, err)
+ dc.endSASL(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_SASLFAIL,
+ Params: []string{dc.nick, authErrorReason(err)},
+ })
+ break
+ }
+
+ // Technically we should send RPL_LOGGEDIN here. However we use
+ // RPL_LOGGEDIN to mirror the upstream connection status. Let's
+ // see how many clients that breaks. See:
+ // https://github.com/ircv3/ircv3-specifications/pull/476
+ dc.endSASL(nil)
+ case "BOUNCER":
+ var subcommand string
+ if err := parseMessageParams(msg, &subcommand); err != nil {
+ return err
+ }
+
+ switch strings.ToUpper(subcommand) {
+ case "BIND":
+ var idStr string
+ if err := parseMessageParams(msg, nil, &idStr); err != nil {
+ return err
+ }
+
+ if dc.user == nil {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"BOUNCER", "ACCOUNT_REQUIRED", "BIND", "Authentication needed to bind to bouncer network"},
+ }}
+ }
+
+ id, err := parseBouncerNetID(subcommand, idStr)
+ if err != nil {
+ return err
+ }
+
+ var match *network
+ for _, net := range dc.user.networks {
+ if net.ID == id {
+ match = net
+ break
+ }
+ }
+ if match == nil {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"BOUNCER", "INVALID_NETID", idStr, "Unknown network ID"},
+ }}
+ }
+
+ dc.networkName = match.GetName()
+ }
+ default:
+ dc.logger.Printf("unhandled message: %v", msg)
+ return newUnknownCommandError(msg.Command)
+ }
+ if dc.rawUsername != "" && dc.nick != "*" && !dc.negotiatingCaps {
+ return dc.register(ctx)
+ }
+ return nil
+}
+
+func (dc *downstreamConn) handleCapCommand(cmd string, args []string) error {
+ cmd = strings.ToUpper(cmd)
+
+ switch cmd {
+ case "LS":
+ if len(args) > 0 {
+ var err error
+ if dc.capVersion, err = strconv.Atoi(args[0]); err != nil {
+ return err
+ }
+ }
+ if !dc.registered && dc.capVersion >= 302 {
+ // Let downstream show everything it supports, and trim
+ // down the available capabilities when upstreams are
+ // known.
+ for k, v := range needAllDownstreamCaps {
+ dc.supportedCaps[k] = v
+ }
+ }
+
+ caps := make([]string, 0, len(dc.supportedCaps))
+ for k, v := range dc.supportedCaps {
+ if dc.capVersion >= 302 && v != "" {
+ caps = append(caps, k+"="+v)
+ } else {
+ caps = append(caps, k)
+ }
+ }
+
+ // TODO: multi-line replies
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "CAP",
+ Params: []string{dc.nick, "LS", strings.Join(caps, " ")},
+ })
+
+ if dc.capVersion >= 302 {
+ // CAP version 302 implicitly enables cap-notify
+ dc.caps["cap-notify"] = true
+ }
+
+ if !dc.registered {
+ dc.negotiatingCaps = true
+ }
+ case "LIST":
+ var caps []string
+ for name, enabled := range dc.caps {
+ if enabled {
+ caps = append(caps, name)
+ }
+ }
+
+ // TODO: multi-line replies
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "CAP",
+ Params: []string{dc.nick, "LIST", strings.Join(caps, " ")},
+ })
+ case "REQ":
+ if len(args) == 0 {
+ return ircError{&irc.Message{
+ Command: err_invalidcapcmd,
+ Params: []string{dc.nick, cmd, "Missing argument in CAP REQ command"},
+ }}
+ }
+
+ // TODO: atomically ack/nak the whole capability set
+ caps := strings.Fields(args[0])
+ ack := true
+ for _, name := range caps {
+ name = strings.ToLower(name)
+ enable := !strings.HasPrefix(name, "-")
+ if !enable {
+ name = strings.TrimPrefix(name, "-")
+ }
+
+ if enable == dc.caps[name] {
+ continue
+ }
+
+ _, ok := dc.supportedCaps[name]
+ if !ok {
+ ack = false
+ break
+ }
+
+ if name == "cap-notify" && dc.capVersion >= 302 && !enable {
+ // cap-notify cannot be disabled with CAP version 302
+ ack = false
+ break
+ }
+
+ dc.caps[name] = enable
+ }
+
+ reply := "NAK"
+ if ack {
+ reply = "ACK"
+ }
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "CAP",
+ Params: []string{dc.nick, reply, args[0]},
+ })
+
+ if !dc.registered {
+ dc.negotiatingCaps = true
+ }
+ case "END":
+ dc.negotiatingCaps = false
+ default:
+ return ircError{&irc.Message{
+ Command: err_invalidcapcmd,
+ Params: []string{dc.nick, cmd, "Unknown CAP command"},
+ }}
+ }
+ return nil
+}
+
+func (dc *downstreamConn) handleAuthenticateCommand(msg *irc.Message) (result *downstreamSASL, err error) {
+ defer func() {
+ if err != nil {
+ dc.sasl = nil
+ }
+ }()
+
+ if !dc.caps["sasl"] {
+ return nil, ircError{&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_SASLFAIL,
+ Params: []string{dc.nick, "AUTHENTICATE requires the \"sasl\" capability to be enabled"},
+ }}
+ }
+ if len(msg.Params) == 0 {
+ return nil, ircError{&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_SASLFAIL,
+ Params: []string{dc.nick, "Missing AUTHENTICATE argument"},
+ }}
+ }
+ if msg.Params[0] == "*" {
+ return nil, ircError{&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_SASLABORTED,
+ Params: []string{dc.nick, "SASL authentication aborted"},
+ }}
+ }
+
+ var resp []byte
+ if dc.sasl == nil {
+ mech := strings.ToUpper(msg.Params[0])
+ var server sasl.Server
+ switch mech {
+ case "PLAIN":
+ server = sasl.NewPlainServer(sasl.PlainAuthenticator(func(identity, username, password string) error {
+ dc.sasl.plainUsername = username
+ dc.sasl.plainPassword = password
+ return nil
+ }))
+ default:
+ return nil, ircError{&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_SASLFAIL,
+ Params: []string{dc.nick, fmt.Sprintf("Unsupported SASL mechanism %q", mech)},
+ }}
+ }
+
+ dc.sasl = &downstreamSASL{server: server}
+ } else {
+ chunk := msg.Params[0]
+ if chunk == "+" {
+ chunk = ""
+ }
+
+ if dc.sasl.pendingResp.Len()+len(chunk) > 10*1024 {
+ return nil, ircError{&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_SASLFAIL,
+ Params: []string{dc.nick, "Response too long"},
+ }}
+ }
+
+ dc.sasl.pendingResp.WriteString(chunk)
+
+ if len(chunk) == maxSASLLength {
+ return nil, nil // Multi-line response, wait for the next command
+ }
+
+ resp, err = base64.StdEncoding.DecodeString(dc.sasl.pendingResp.String())
+ if err != nil {
+ return nil, ircError{&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_SASLFAIL,
+ Params: []string{dc.nick, "Invalid base64-encoded response"},
+ }}
+ }
+
+ dc.sasl.pendingResp.Reset()
+ }
+
+ challenge, done, err := dc.sasl.server.Next(resp)
+ if err != nil {
+ return nil, err
+ } else if done {
+ return dc.sasl, nil
+ } else {
+ challengeStr := "+"
+ if len(challenge) > 0 {
+ challengeStr = base64.StdEncoding.EncodeToString(challenge)
+ }
+
+ // TODO: multi-line messages
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "AUTHENTICATE",
+ Params: []string{challengeStr},
+ })
+ return nil, nil
+ }
+}
+
+func (dc *downstreamConn) endSASL(msg *irc.Message) {
+ if dc.sasl == nil {
+ return
+ }
+
+ dc.sasl = nil
+
+ if msg != nil {
+ dc.SendMessage(msg)
+ } else {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_SASLSUCCESS,
+ Params: []string{dc.nick, "SASL authentication successful"},
+ })
+ }
+}
+
+func (dc *downstreamConn) setSupportedCap(name, value string) {
+ prevValue, hasPrev := dc.supportedCaps[name]
+ changed := !hasPrev || prevValue != value
+ dc.supportedCaps[name] = value
+
+ if !dc.caps["cap-notify"] || !changed {
+ return
+ }
+
+ cap := name
+ if value != "" && dc.capVersion >= 302 {
+ cap = name + "=" + value
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "CAP",
+ Params: []string{dc.nick, "NEW", cap},
+ })
+}
+
+func (dc *downstreamConn) unsetSupportedCap(name string) {
+ _, hasPrev := dc.supportedCaps[name]
+ delete(dc.supportedCaps, name)
+ delete(dc.caps, name)
+
+ if !dc.caps["cap-notify"] || !hasPrev {
+ return
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "CAP",
+ Params: []string{dc.nick, "DEL", name},
+ })
+}
+
+func (dc *downstreamConn) updateSupportedCaps() {
+ supportedCaps := make(map[string]bool)
+ for cap := range needAllDownstreamCaps {
+ supportedCaps[cap] = true
+ }
+ dc.forEachUpstream(func(uc *upstreamConn) {
+ for cap, supported := range supportedCaps {
+ supportedCaps[cap] = supported && uc.caps[cap]
+ }
+ })
+
+ for cap, supported := range supportedCaps {
+ if supported {
+ dc.setSupportedCap(cap, needAllDownstreamCaps[cap])
+ } else {
+ dc.unsetSupportedCap(cap)
+ }
+ }
+
+ if uc := dc.upstream(); uc != nil && uc.supportsSASL("PLAIN") {
+ dc.setSupportedCap("sasl", "PLAIN")
+ } else if dc.network != nil {
+ dc.unsetSupportedCap("sasl")
+ }
+
+ if uc := dc.upstream(); uc != nil && uc.caps["draft/account-registration"] {
+ // Strip "before-connect", because we require downstreams to be fully
+ // connected before attempting account registration.
+ values := strings.Split(uc.supportedCaps["draft/account-registration"], ",")
+ for i, v := range values {
+ if v == "before-connect" {
+ values = append(values[:i], values[i+1:]...)
+ break
+ }
+ }
+ dc.setSupportedCap("draft/account-registration", strings.Join(values, ","))
+ } else {
+ dc.unsetSupportedCap("draft/account-registration")
+ }
+
+ if _, ok := dc.user.msgStore.(chatHistoryMessageStore); ok && dc.network != nil {
+ dc.setSupportedCap("draft/event-playback", "")
+ } else {
+ dc.unsetSupportedCap("draft/event-playback")
+ }
+}
+
+func (dc *downstreamConn) updateNick() {
+ if uc := dc.upstream(); uc != nil && uc.nick != dc.nick {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.prefix(),
+ Command: "NICK",
+ Params: []string{uc.nick},
+ })
+ dc.nick = uc.nick
+ dc.nickCM = casemapASCII(dc.nick)
+ }
+}
+
+func (dc *downstreamConn) updateRealname() {
+ if uc := dc.upstream(); uc != nil && uc.realname != dc.realname && dc.caps["setname"] {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.prefix(),
+ Command: "SETNAME",
+ Params: []string{uc.realname},
+ })
+ dc.realname = uc.realname
+ }
+}
+
+func (dc *downstreamConn) updateAccount() {
+ var account string
+ if dc.network == nil {
+ account = dc.user.Username
+ } else if uc := dc.upstream(); uc != nil {
+ account = uc.account
+ } else {
+ return
+ }
+
+ if dc.account == account || !dc.caps["sasl"] {
+ return
+ }
+
+ if account != "" {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_LOGGEDIN,
+ Params: []string{dc.nick, dc.prefix().String(), account, "You are logged in as " + account},
+ })
+ } else {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_LOGGEDOUT,
+ Params: []string{dc.nick, dc.prefix().String(), "You are logged out"},
+ })
+ }
+
+ dc.account = account
+}
+
+func sanityCheckServer(ctx context.Context, addr string) error {
+ ctx, cancel := context.WithTimeout(ctx, 15*time.Second)
+ defer cancel()
+
+ conn, err := new(tls.Dialer).DialContext(ctx, "tcp", addr)
+ if err != nil {
+ return err
+ }
+
+ return conn.Close()
+}
+
+func unmarshalUsername(rawUsername string) (username, client, network string) {
+ username = rawUsername
+
+ i := strings.IndexAny(username, "/@")
+ j := strings.LastIndexAny(username, "/@")
+ if i >= 0 {
+ username = rawUsername[:i]
+ }
+ if j >= 0 {
+ if rawUsername[j] == '@' {
+ client = rawUsername[j+1:]
+ } else {
+ network = rawUsername[j+1:]
+ }
+ }
+ if i >= 0 && j >= 0 && i < j {
+ if rawUsername[i] == '@' {
+ client = rawUsername[i+1 : j]
+ } else {
+ network = rawUsername[i+1 : j]
+ }
+ }
+
+ return username, client, network
+}
+
+func (dc *downstreamConn) authenticate(ctx context.Context, username, password string) error {
+ username, clientName, networkName := unmarshalUsername(username)
+
+ u, err := dc.srv.db.GetUser(ctx, username)
+ if err != nil {
+ return newInvalidUsernameOrPasswordError(fmt.Errorf("user not found: %w", err))
+ }
+
+ // Password auth disabled
+ if u.Password == "" {
+ return newInvalidUsernameOrPasswordError(fmt.Errorf("password auth disabled"))
+ }
+
+ err = bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(password))
+ if err != nil {
+ return newInvalidUsernameOrPasswordError(fmt.Errorf("wrong password"))
+ }
+
+ dc.user = dc.srv.getUser(username)
+ if dc.user == nil {
+ return fmt.Errorf("user not active")
+ }
+ dc.clientName = clientName
+ dc.networkName = networkName
+ return nil
+}
+
+func (dc *downstreamConn) register(ctx context.Context) error {
+ if dc.registered {
+ panic("tried to register twice")
+ }
+
+ if dc.sasl != nil {
+ dc.endSASL(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_SASLABORTED,
+ Params: []string{dc.nick, "SASL authentication aborted"},
+ })
+ }
+
+ password := dc.password
+ dc.password = ""
+ if dc.user == nil {
+ if password == "" {
+ if dc.caps["sasl"] {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"*", "ACCOUNT_REQUIRED", "Authentication required"},
+ }}
+ } else {
+ return ircError{&irc.Message{
+ Command: irc.ERR_PASSWDMISMATCH,
+ Params: []string{dc.nick, "Authentication required"},
+ }}
+ }
+ }
+
+ if err := dc.authenticate(ctx, dc.rawUsername, password); err != nil {
+ dc.logger.Printf("PASS authentication error for user %q: %v", dc.rawUsername, err)
+ return ircError{&irc.Message{
+ Command: irc.ERR_PASSWDMISMATCH,
+ Params: []string{dc.nick, authErrorReason(err)},
+ }}
+ }
+ }
+
+ _, fallbackClientName, fallbackNetworkName := unmarshalUsername(dc.rawUsername)
+ if dc.clientName == "" {
+ dc.clientName = fallbackClientName
+ } else if fallbackClientName != "" && dc.clientName != fallbackClientName {
+ return ircError{&irc.Message{
+ Command: irc.ERR_ERRONEUSNICKNAME,
+ Params: []string{dc.nick, "Client name mismatch in usernames"},
+ }}
+ }
+ if dc.networkName == "" {
+ dc.networkName = fallbackNetworkName
+ } else if fallbackNetworkName != "" && dc.networkName != fallbackNetworkName {
+ return ircError{&irc.Message{
+ Command: irc.ERR_ERRONEUSNICKNAME,
+ Params: []string{dc.nick, "Network name mismatch in usernames"},
+ }}
+ }
+
+ dc.registered = true
+ dc.logger.Printf("registration complete for user %q", dc.user.Username)
+ return nil
+}
+
+func (dc *downstreamConn) loadNetwork(ctx context.Context) error {
+ if dc.networkName == "" {
+ return nil
+ }
+
+ network := dc.user.getNetwork(dc.networkName)
+ if network == nil {
+ addr := dc.networkName
+ if !strings.ContainsRune(addr, ':') {
+ addr = addr + ":6697"
+ }
+
+ dc.logger.Printf("trying to connect to new network %q", addr)
+ if err := sanityCheckServer(ctx, addr); err != nil {
+ dc.logger.Printf("failed to connect to %q: %v", addr, err)
+ return ircError{&irc.Message{
+ Command: irc.ERR_PASSWDMISMATCH,
+ Params: []string{dc.nick, fmt.Sprintf("Failed to connect to %q", dc.networkName)},
+ }}
+ }
+
+ // Some clients only allow specifying the nickname (and use the
+ // nickname as a username too). Strip the network name from the
+ // nickname when auto-saving networks.
+ nick, _, _ := unmarshalUsername(dc.nick)
+
+ dc.logger.Printf("auto-saving network %q", dc.networkName)
+ var err error
+ network, err = dc.user.createNetwork(ctx, &Network{
+ Addr: dc.networkName,
+ Nick: nick,
+ Enabled: true,
+ })
+ if err != nil {
+ return err
+ }
+ }
+
+ dc.network = network
+ return nil
+}
+
+func (dc *downstreamConn) welcome(ctx context.Context) error {
+ if dc.user == nil || !dc.registered {
+ panic("tried to welcome an unregistered connection")
+ }
+
+ remoteAddr := dc.conn.RemoteAddr().String()
+ dc.logger = &prefixLogger{dc.srv.Logger, fmt.Sprintf("user %q: downstream %q: ", dc.user.Username, remoteAddr)}
+
+ // TODO: doing this might take some time. We should do it in dc.register
+ // instead, but we'll potentially be adding a new network and this must be
+ // done in the user goroutine.
+ if err := dc.loadNetwork(ctx); err != nil {
+ return err
+ }
+
+ if dc.network == nil && !dc.caps["soju.im/bouncer-networks"] && dc.srv.Config().MultiUpstream {
+ dc.isMultiUpstream = true
+ }
+
+ dc.updateSupportedCaps()
+
+ isupport := []string{
+ fmt.Sprintf("CHATHISTORY=%v", chatHistoryLimit),
+ "CASEMAPPING=ascii",
+ }
+
+ if dc.network != nil {
+ isupport = append(isupport, fmt.Sprintf("BOUNCER_NETID=%v", dc.network.ID))
+ }
+ if title := dc.srv.Config().Title; dc.network == nil && title != "" {
+ isupport = append(isupport, "NETWORK="+encodeISUPPORT(title))
+ }
+ if dc.network == nil && !dc.isMultiUpstream {
+ isupport = append(isupport, "WHOX")
+ }
+
+ if uc := dc.upstream(); uc != nil {
+ for k := range passthroughIsupport {
+ v, ok := uc.isupport[k]
+ if !ok {
+ continue
+ }
+ if v != nil {
+ isupport = append(isupport, fmt.Sprintf("%v=%v", k, *v))
+ } else {
+ isupport = append(isupport, k)
+ }
+ }
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_WELCOME,
+ Params: []string{dc.nick, "Welcome to suika, " + dc.nick},
+ })
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_YOURHOST,
+ Params: []string{dc.nick, "Your host is " + dc.srv.Config().Hostname},
+ })
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_MYINFO,
+ Params: []string{dc.nick, dc.srv.Config().Hostname, "suika", "aiwroO", "OovaimnqpsrtklbeI"},
+ })
+ for _, msg := range generateIsupport(dc.srv.prefix(), dc.nick, isupport) {
+ dc.SendMessage(msg)
+ }
+ if uc := dc.upstream(); uc != nil {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_UMODEIS,
+ Params: []string{dc.nick, "+" + string(uc.modes)},
+ })
+ }
+ if dc.network == nil && !dc.isMultiUpstream && dc.user.Admin {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_UMODEIS,
+ Params: []string{dc.nick, "+o"},
+ })
+ }
+
+ dc.updateNick()
+ dc.updateRealname()
+ dc.updateAccount()
+
+ if motd := dc.user.srv.Config().MOTD; motd != "" && dc.network == nil {
+ for _, msg := range generateMOTD(dc.srv.prefix(), dc.nick, motd) {
+ dc.SendMessage(msg)
+ }
+ } else {
+ motdHint := "No MOTD"
+ if dc.network != nil {
+ motdHint = "Use /motd to read the message of the day"
+ }
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_NOMOTD,
+ Params: []string{dc.nick, motdHint},
+ })
+ }
+
+ if dc.caps["soju.im/bouncer-networks-notify"] {
+ dc.SendBatch("soju.im/bouncer-networks", nil, nil, func(batchRef irc.TagValue) {
+ for _, network := range dc.user.networks {
+ idStr := fmt.Sprintf("%v", network.ID)
+ attrs := getNetworkAttrs(network)
+ dc.SendMessage(&irc.Message{
+ Tags: irc.Tags{"batch": batchRef},
+ Prefix: dc.srv.prefix(),
+ Command: "BOUNCER",
+ Params: []string{"NETWORK", idStr, attrs.String()},
+ })
+ }
+ })
+ }
+
+ dc.forEachUpstream(func(uc *upstreamConn) {
+ for _, entry := range uc.channels.innerMap {
+ ch := entry.value.(*upstreamChannel)
+ if !ch.complete {
+ continue
+ }
+ record := uc.network.channels.Value(ch.Name)
+ if record != nil && record.Detached {
+ continue
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.prefix(),
+ Command: "JOIN",
+ Params: []string{dc.marshalEntity(ch.conn.network, ch.Name)},
+ })
+
+ forwardChannel(ctx, dc, ch)
+ }
+ })
+
+ dc.forEachNetwork(func(net *network) {
+ if dc.caps["draft/chathistory"] || dc.user.msgStore == nil {
+ return
+ }
+
+ // Only send history if we're the first connected client with that name
+ // for the network
+ firstClient := true
+ dc.user.forEachDownstream(func(c *downstreamConn) {
+ if c != dc && c.clientName == dc.clientName && c.network == dc.network {
+ firstClient = false
+ }
+ })
+ if firstClient {
+ net.delivered.ForEachTarget(func(target string) {
+ lastDelivered := net.delivered.LoadID(target, dc.clientName)
+ if lastDelivered == "" {
+ return
+ }
+
+ dc.sendTargetBacklog(ctx, net, target, lastDelivered)
+
+ // Fast-forward history to last message
+ targetCM := net.casemap(target)
+ lastID, err := dc.user.msgStore.LastMsgID(&net.Network, targetCM, time.Now())
+ if err != nil {
+ dc.logger.Printf("failed to get last message ID: %v", err)
+ return
+ }
+ net.delivered.StoreID(target, dc.clientName, lastID)
+ })
+ }
+ })
+
+ return nil
+}
+
+// messageSupportsBacklog checks whether the provided message can be sent as
+// part of an history batch.
+func (dc *downstreamConn) messageSupportsBacklog(msg *irc.Message) bool {
+ // Don't replay all messages, because that would mess up client
+ // state. For instance we just sent the list of users, sending
+ // PART messages for one of these users would be incorrect.
+ switch msg.Command {
+ case "PRIVMSG", "NOTICE":
+ return true
+ }
+ return false
+}
+
+func (dc *downstreamConn) sendTargetBacklog(ctx context.Context, net *network, target, msgID string) {
+ if dc.caps["draft/chathistory"] || dc.user.msgStore == nil {
+ return
+ }
+
+ ch := net.channels.Value(target)
+
+ ctx, cancel := context.WithTimeout(ctx, backlogTimeout)
+ defer cancel()
+
+ targetCM := net.casemap(target)
+ history, err := dc.user.msgStore.LoadLatestID(ctx, &net.Network, targetCM, msgID, backlogLimit)
+ if err != nil {
+ dc.logger.Printf("failed to send backlog for %q: %v", target, err)
+ return
+ }
+
+ dc.SendBatch("chathistory", []string{dc.marshalEntity(net, target)}, nil, func(batchRef irc.TagValue) {
+ for _, msg := range history {
+ if ch != nil && ch.Detached {
+ if net.detachedMessageNeedsRelay(ch, msg) {
+ dc.relayDetachedMessage(net, msg)
+ }
+ } else {
+ msg.Tags["batch"] = batchRef
+ dc.SendMessage(dc.marshalMessage(msg, net))
+ }
+ }
+ })
+}
+
+func (dc *downstreamConn) relayDetachedMessage(net *network, msg *irc.Message) {
+ if msg.Command != "PRIVMSG" && msg.Command != "NOTICE" {
+ return
+ }
+
+ sender := msg.Prefix.Name
+ target, text := msg.Params[0], msg.Params[1]
+ if net.isHighlight(msg) {
+ sendServiceNOTICE(dc, fmt.Sprintf("highlight in %v: <%v> %v", dc.marshalEntity(net, target), sender, text))
+ } else {
+ sendServiceNOTICE(dc, fmt.Sprintf("message in %v: <%v> %v", dc.marshalEntity(net, target), sender, text))
+ }
+}
+
+func (dc *downstreamConn) runUntilRegistered() error {
+ ctx, cancel := context.WithTimeout(context.TODO(), downstreamRegisterTimeout)
+ defer cancel()
+
+ // Close the connection with an error if the deadline is exceeded
+ go func() {
+ <-ctx.Done()
+ if err := ctx.Err(); err == context.DeadlineExceeded {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "ERROR",
+ Params: []string{"Connection registration timed out"},
+ })
+ dc.Close()
+ }
+ }()
+
+ for !dc.registered {
+ msg, err := dc.ReadMessage()
+ if err != nil {
+ return fmt.Errorf("failed to read IRC command: %w", err)
+ }
+
+ err = dc.handleMessage(ctx, msg)
+ if ircErr, ok := err.(ircError); ok {
+ ircErr.Message.Prefix = dc.srv.prefix()
+ dc.SendMessage(ircErr.Message)
+ } else if err != nil {
+ return fmt.Errorf("failed to handle IRC command %q: %v", msg, err)
+ }
+ }
+
+ return nil
+}
+
+func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.Message) error {
+ switch msg.Command {
+ case "CAP":
+ var subCmd string
+ if err := parseMessageParams(msg, &subCmd); err != nil {
+ return err
+ }
+ if err := dc.handleCapCommand(subCmd, msg.Params[1:]); err != nil {
+ return err
+ }
+ case "PING":
+ var source, destination string
+ if err := parseMessageParams(msg, &source); err != nil {
+ return err
+ }
+ if len(msg.Params) > 1 {
+ destination = msg.Params[1]
+ }
+ hostname := dc.srv.Config().Hostname
+ if destination != "" && destination != hostname {
+ return ircError{&irc.Message{
+ Command: irc.ERR_NOSUCHSERVER,
+ Params: []string{dc.nick, destination, "No such server"},
+ }}
+ }
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "PONG",
+ Params: []string{hostname, source},
+ })
+ return nil
+ case "PONG":
+ if len(msg.Params) == 0 {
+ return newNeedMoreParamsError(msg.Command)
+ }
+ token := msg.Params[len(msg.Params)-1]
+ dc.handlePong(token)
+ case "USER":
+ return ircError{&irc.Message{
+ Command: irc.ERR_ALREADYREGISTERED,
+ Params: []string{dc.nick, "You may not reregister"},
+ }}
+ case "NICK":
+ var rawNick string
+ if err := parseMessageParams(msg, &rawNick); err != nil {
+ return err
+ }
+
+ nick := rawNick
+ var upstream *upstreamConn
+ if dc.upstream() == nil {
+ uc, unmarshaledNick, err := dc.unmarshalEntity(nick)
+ if err == nil { // NICK nick/network: NICK only on a specific upstream
+ upstream = uc
+ nick = unmarshaledNick
+ }
+ }
+
+ if nick == "" || strings.ContainsAny(nick, illegalNickChars) {
+ return ircError{&irc.Message{
+ Command: irc.ERR_ERRONEUSNICKNAME,
+ Params: []string{dc.nick, rawNick, "contains illegal characters"},
+ }}
+ }
+ if casemapASCII(nick) == serviceNickCM {
+ return ircError{&irc.Message{
+ Command: irc.ERR_NICKNAMEINUSE,
+ Params: []string{dc.nick, rawNick, "Nickname reserved for bouncer service"},
+ }}
+ }
+
+ var err error
+ dc.forEachNetwork(func(n *network) {
+ if err != nil || (upstream != nil && upstream.network != n) {
+ return
+ }
+ n.Nick = nick
+ err = dc.srv.db.StoreNetwork(ctx, dc.user.ID, &n.Network)
+ })
+ if err != nil {
+ return err
+ }
+
+ dc.forEachUpstream(func(uc *upstreamConn) {
+ if upstream != nil && upstream != uc {
+ return
+ }
+ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
+ Command: "NICK",
+ Params: []string{nick},
+ })
+ })
+
+ if dc.upstream() == nil && upstream == nil && dc.nick != nick {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.prefix(),
+ Command: "NICK",
+ Params: []string{nick},
+ })
+ dc.nick = nick
+ dc.nickCM = casemapASCII(dc.nick)
+ }
+ case "SETNAME":
+ var realname string
+ if err := parseMessageParams(msg, &realname); err != nil {
+ return err
+ }
+
+ // If the client just resets to the default, just wipe the per-network
+ // preference
+ storeRealname := realname
+ if realname == dc.user.Realname {
+ storeRealname = ""
+ }
+
+ var storeErr error
+ var needUpdate []Network
+ dc.forEachNetwork(func(n *network) {
+ // We only need to call updateNetwork for upstreams that don't
+ // support setname
+ if uc := n.conn; uc != nil && uc.caps["setname"] {
+ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
+ Command: "SETNAME",
+ Params: []string{realname},
+ })
+
+ n.Realname = storeRealname
+ if err := dc.srv.db.StoreNetwork(ctx, dc.user.ID, &n.Network); err != nil {
+ dc.logger.Printf("failed to store network realname: %v", err)
+ storeErr = err
+ }
+ return
+ }
+
+ record := n.Network // copy network record because we'll mutate it
+ record.Realname = storeRealname
+ needUpdate = append(needUpdate, record)
+ })
+
+ // Walk the network list as a second step, because updateNetwork
+ // mutates the original list
+ for _, record := range needUpdate {
+ if _, err := dc.user.updateNetwork(ctx, &record); err != nil {
+ dc.logger.Printf("failed to update network realname: %v", err)
+ storeErr = err
+ }
+ }
+ if storeErr != nil {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"SETNAME", "CANNOT_CHANGE_REALNAME", "Failed to update realname"},
+ }}
+ }
+
+ if dc.upstream() == nil {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.prefix(),
+ Command: "SETNAME",
+ Params: []string{realname},
+ })
+ }
+ case "JOIN":
+ var namesStr string
+ if err := parseMessageParams(msg, &namesStr); err != nil {
+ return err
+ }
+
+ var keys []string
+ if len(msg.Params) > 1 {
+ keys = strings.Split(msg.Params[1], ",")
+ }
+
+ for i, name := range strings.Split(namesStr, ",") {
+ uc, upstreamName, err := dc.unmarshalEntity(name)
+ if err != nil {
+ return err
+ }
+
+ var key string
+ if len(keys) > i {
+ key = keys[i]
+ }
+
+ if !uc.isChannel(upstreamName) {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_NOSUCHCHANNEL,
+ Params: []string{name, "Not a channel name"},
+ })
+ continue
+ }
+
+ // Most servers ignore duplicate JOIN messages. We ignore them here
+ // because some clients automatically send JOIN messages in bulk
+ // when reconnecting to the bouncer. We don't want to flood the
+ // upstream connection with these.
+ if !uc.channels.Has(upstreamName) {
+ params := []string{upstreamName}
+ if key != "" {
+ params = append(params, key)
+ }
+ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
+ Command: "JOIN",
+ Params: params,
+ })
+ }
+
+ ch := uc.network.channels.Value(upstreamName)
+ if ch != nil {
+ // Don't clear the channel key if there's one set
+ // TODO: add a way to unset the channel key
+ if key != "" {
+ ch.Key = key
+ }
+ uc.network.attach(ctx, ch)
+ } else {
+ ch = &Channel{
+ Name: upstreamName,
+ Key: key,
+ }
+ uc.network.channels.SetValue(upstreamName, ch)
+ }
+ if err := dc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil {
+ dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err)
+ }
+ }
+ case "PART":
+ var namesStr string
+ if err := parseMessageParams(msg, &namesStr); err != nil {
+ return err
+ }
+
+ var reason string
+ if len(msg.Params) > 1 {
+ reason = msg.Params[1]
+ }
+
+ for _, name := range strings.Split(namesStr, ",") {
+ uc, upstreamName, err := dc.unmarshalEntity(name)
+ if err != nil {
+ return err
+ }
+
+ if strings.EqualFold(reason, "detach") {
+ ch := uc.network.channels.Value(upstreamName)
+ if ch != nil {
+ uc.network.detach(ch)
+ } else {
+ ch = &Channel{
+ Name: name,
+ Detached: true,
+ }
+ uc.network.channels.SetValue(upstreamName, ch)
+ }
+ if err := dc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil {
+ dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err)
+ }
+ } else {
+ params := []string{upstreamName}
+ if reason != "" {
+ params = append(params, reason)
+ }
+ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
+ Command: "PART",
+ Params: params,
+ })
+
+ if err := uc.network.deleteChannel(ctx, upstreamName); err != nil {
+ dc.logger.Printf("failed to delete channel %q: %v", upstreamName, err)
+ }
+ }
+ }
+ case "KICK":
+ var channelStr, userStr string
+ if err := parseMessageParams(msg, &channelStr, &userStr); err != nil {
+ return err
+ }
+
+ channels := strings.Split(channelStr, ",")
+ users := strings.Split(userStr, ",")
+
+ var reason string
+ if len(msg.Params) > 2 {
+ reason = msg.Params[2]
+ }
+
+ if len(channels) != 1 && len(channels) != len(users) {
+ return ircError{&irc.Message{
+ Command: irc.ERR_BADCHANMASK,
+ Params: []string{dc.nick, channelStr, "Bad channel mask"},
+ }}
+ }
+
+ for i, user := range users {
+ var channel string
+ if len(channels) == 1 {
+ channel = channels[0]
+ } else {
+ channel = channels[i]
+ }
+
+ ucChannel, upstreamChannel, err := dc.unmarshalEntity(channel)
+ if err != nil {
+ return err
+ }
+
+ ucUser, upstreamUser, err := dc.unmarshalEntity(user)
+ if err != nil {
+ return err
+ }
+
+ if ucChannel != ucUser {
+ return ircError{&irc.Message{
+ Command: irc.ERR_USERNOTINCHANNEL,
+ Params: []string{dc.nick, user, channel, "They are on another network"},
+ }}
+ }
+ uc := ucChannel
+
+ params := []string{upstreamChannel, upstreamUser}
+ if reason != "" {
+ params = append(params, reason)
+ }
+ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
+ Command: "KICK",
+ Params: params,
+ })
+ }
+ case "MODE":
+ var name string
+ if err := parseMessageParams(msg, &name); err != nil {
+ return err
+ }
+
+ var modeStr string
+ if len(msg.Params) > 1 {
+ modeStr = msg.Params[1]
+ }
+
+ if casemapASCII(name) == dc.nickCM {
+ if modeStr != "" {
+ if uc := dc.upstream(); uc != nil {
+ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
+ Command: "MODE",
+ Params: []string{uc.nick, modeStr},
+ })
+ } else {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_UMODEUNKNOWNFLAG,
+ Params: []string{dc.nick, "Cannot change user mode in multi-upstream mode"},
+ })
+ }
+ } else {
+ var userMode string
+ if uc := dc.upstream(); uc != nil {
+ userMode = string(uc.modes)
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_UMODEIS,
+ Params: []string{dc.nick, "+" + userMode},
+ })
+ }
+ return nil
+ }
+
+ uc, upstreamName, err := dc.unmarshalEntity(name)
+ if err != nil {
+ return err
+ }
+
+ if !uc.isChannel(upstreamName) {
+ return ircError{&irc.Message{
+ Command: irc.ERR_USERSDONTMATCH,
+ Params: []string{dc.nick, "Cannot change mode for other users"},
+ }}
+ }
+
+ if modeStr != "" {
+ params := []string{upstreamName, modeStr}
+ params = append(params, msg.Params[2:]...)
+ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
+ Command: "MODE",
+ Params: params,
+ })
+ } else {
+ ch := uc.channels.Value(upstreamName)
+ if ch == nil {
+ return ircError{&irc.Message{
+ Command: irc.ERR_NOSUCHCHANNEL,
+ Params: []string{dc.nick, name, "No such channel"},
+ }}
+ }
+
+ if ch.modes == nil {
+ // we haven't received the initial RPL_CHANNELMODEIS yet
+ // ignore the request, we will broadcast the modes later when we receive RPL_CHANNELMODEIS
+ return nil
+ }
+
+ modeStr, modeParams := ch.modes.Format()
+ params := []string{dc.nick, name, modeStr}
+ params = append(params, modeParams...)
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_CHANNELMODEIS,
+ Params: params,
+ })
+ if ch.creationTime != "" {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: rpl_creationtime,
+ Params: []string{dc.nick, name, ch.creationTime},
+ })
+ }
+ }
+ case "TOPIC":
+ var channel string
+ if err := parseMessageParams(msg, &channel); err != nil {
+ return err
+ }
+
+ uc, upstreamName, err := dc.unmarshalEntity(channel)
+ if err != nil {
+ return err
+ }
+
+ if len(msg.Params) > 1 { // setting topic
+ topic := msg.Params[1]
+ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
+ Command: "TOPIC",
+ Params: []string{upstreamName, topic},
+ })
+ } else { // getting topic
+ ch := uc.channels.Value(upstreamName)
+ if ch == nil {
+ return ircError{&irc.Message{
+ Command: irc.ERR_NOSUCHCHANNEL,
+ Params: []string{dc.nick, upstreamName, "No such channel"},
+ }}
+ }
+ sendTopic(dc, ch)
+ }
+ case "LIST":
+ network := dc.network
+ if network == nil && len(msg.Params) > 0 {
+ var err error
+ network, msg.Params[0], err = dc.unmarshalEntityNetwork(msg.Params[0])
+ if err != nil {
+ return err
+ }
+ }
+ if network == nil {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_LISTEND,
+ Params: []string{dc.nick, "LIST without a network suffix is not supported in multi-upstream mode"},
+ })
+ return nil
+ }
+
+ uc := network.conn
+ if uc == nil {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_LISTEND,
+ Params: []string{dc.nick, "Disconnected from upstream server"},
+ })
+ return nil
+ }
+
+ uc.enqueueCommand(dc, msg)
+ case "NAMES":
+ if len(msg.Params) == 0 {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_ENDOFNAMES,
+ Params: []string{dc.nick, "*", "End of /NAMES list"},
+ })
+ return nil
+ }
+
+ channels := strings.Split(msg.Params[0], ",")
+ for _, channel := range channels {
+ uc, upstreamName, err := dc.unmarshalEntity(channel)
+ if err != nil {
+ return err
+ }
+
+ ch := uc.channels.Value(upstreamName)
+ if ch != nil {
+ sendNames(dc, ch)
+ } else {
+ // NAMES on a channel we have not joined, ask upstream
+ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
+ Command: "NAMES",
+ Params: []string{upstreamName},
+ })
+ }
+ }
+ // For WHOX docs, see:
+ // - http://faerion.sourceforge.net/doc/irc/whox.var
+ // - https://github.com/quakenet/snircd/blob/master/doc/readme.who
+ // Note, many features aren't widely implemented, such as flags and mask2
+ case "WHO":
+ if len(msg.Params) == 0 {
+ // TODO: support WHO without parameters
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_ENDOFWHO,
+ Params: []string{dc.nick, "*", "End of /WHO list"},
+ })
+ return nil
+ }
+
+ // Clients will use the first mask to match RPL_ENDOFWHO
+ endOfWhoToken := msg.Params[0]
+
+ // TODO: add support for WHOX mask2
+ mask := msg.Params[0]
+ var options string
+ if len(msg.Params) > 1 {
+ options = msg.Params[1]
+ }
+
+ optionsParts := strings.SplitN(options, "%", 2)
+ // TODO: add support for WHOX flags in optionsParts[0]
+ var fields, whoxToken string
+ if len(optionsParts) == 2 {
+ optionsParts := strings.SplitN(optionsParts[1], ",", 2)
+ fields = strings.ToLower(optionsParts[0])
+ if len(optionsParts) == 2 && strings.Contains(fields, "t") {
+ whoxToken = optionsParts[1]
+ }
+ }
+
+ // TODO: support mixed bouncer/upstream WHO queries
+ maskCM := casemapASCII(mask)
+ if dc.network == nil && maskCM == dc.nickCM {
+ // TODO: support AWAY (H/G) in self WHO reply
+ flags := "H"
+ if dc.user.Admin {
+ flags += "*"
+ }
+ info := whoxInfo{
+ Token: whoxToken,
+ Username: dc.user.Username,
+ Hostname: dc.hostname,
+ Server: dc.srv.Config().Hostname,
+ Nickname: dc.nick,
+ Flags: flags,
+ Account: dc.user.Username,
+ Realname: dc.realname,
+ }
+ dc.SendMessage(generateWHOXReply(dc.srv.prefix(), dc.nick, fields, &info))
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_ENDOFWHO,
+ Params: []string{dc.nick, endOfWhoToken, "End of /WHO list"},
+ })
+ return nil
+ }
+ if maskCM == serviceNickCM {
+ info := whoxInfo{
+ Token: whoxToken,
+ Username: servicePrefix.User,
+ Hostname: servicePrefix.Host,
+ Server: dc.srv.Config().Hostname,
+ Nickname: serviceNick,
+ Flags: "H*",
+ Account: serviceNick,
+ Realname: serviceRealname,
+ }
+ dc.SendMessage(generateWHOXReply(dc.srv.prefix(), dc.nick, fields, &info))
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_ENDOFWHO,
+ Params: []string{dc.nick, endOfWhoToken, "End of /WHO list"},
+ })
+ return nil
+ }
+
+ // TODO: properly support WHO masks
+ uc, upstreamMask, err := dc.unmarshalEntity(mask)
+ if err != nil {
+ return err
+ }
+
+ params := []string{upstreamMask}
+ if options != "" {
+ params = append(params, options)
+ }
+
+ uc.enqueueCommand(dc, &irc.Message{
+ Command: "WHO",
+ Params: params,
+ })
+ case "WHOIS":
+ if len(msg.Params) == 0 {
+ return ircError{&irc.Message{
+ Command: irc.ERR_NONICKNAMEGIVEN,
+ Params: []string{dc.nick, "No nickname given"},
+ }}
+ }
+
+ var target, mask string
+ if len(msg.Params) == 1 {
+ target = ""
+ mask = msg.Params[0]
+ } else {
+ target = msg.Params[0]
+ mask = msg.Params[1]
+ }
+ // TODO: support multiple WHOIS users
+ if i := strings.IndexByte(mask, ','); i >= 0 {
+ mask = mask[:i]
+ }
+
+ if dc.network == nil && casemapASCII(mask) == dc.nickCM {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_WHOISUSER,
+ Params: []string{dc.nick, dc.nick, dc.user.Username, dc.hostname, "*", dc.realname},
+ })
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_WHOISSERVER,
+ Params: []string{dc.nick, dc.nick, dc.srv.Config().Hostname, "suika"},
+ })
+ if dc.user.Admin {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_WHOISOPERATOR,
+ Params: []string{dc.nick, dc.nick, "is a bouncer administrator"},
+ })
+ }
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: rpl_whoisaccount,
+ Params: []string{dc.nick, dc.nick, dc.user.Username, "is logged in as"},
+ })
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_ENDOFWHOIS,
+ Params: []string{dc.nick, dc.nick, "End of /WHOIS list"},
+ })
+ return nil
+ }
+ if casemapASCII(mask) == serviceNickCM {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_WHOISUSER,
+ Params: []string{dc.nick, serviceNick, servicePrefix.User, servicePrefix.Host, "*", serviceRealname},
+ })
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_WHOISSERVER,
+ Params: []string{dc.nick, serviceNick, dc.srv.Config().Hostname, "suika"},
+ })
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_WHOISOPERATOR,
+ Params: []string{dc.nick, serviceNick, "is the bouncer service"},
+ })
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: rpl_whoisaccount,
+ Params: []string{dc.nick, serviceNick, serviceNick, "is logged in as"},
+ })
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_ENDOFWHOIS,
+ Params: []string{dc.nick, serviceNick, "End of /WHOIS list"},
+ })
+ return nil
+ }
+
+ // TODO: support WHOIS masks
+ uc, upstreamNick, err := dc.unmarshalEntity(mask)
+ if err != nil {
+ return err
+ }
+
+ var params []string
+ if target != "" {
+ if target == mask { // WHOIS nick nick
+ params = []string{upstreamNick, upstreamNick}
+ } else {
+ params = []string{target, upstreamNick}
+ }
+ } else {
+ params = []string{upstreamNick}
+ }
+
+ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
+ Command: "WHOIS",
+ Params: params,
+ })
+ case "PRIVMSG", "NOTICE":
+ var targetsStr, text string
+ if err := parseMessageParams(msg, &targetsStr, &text); err != nil {
+ return err
+ }
+ tags := copyClientTags(msg.Tags)
+
+ for _, name := range strings.Split(targetsStr, ",") {
+ if name == "$"+dc.srv.Config().Hostname || (name == "$*" && dc.network == nil) {
+ // "$" means a server mask follows. If it's the bouncer's
+ // hostname, broadcast the message to all bouncer users.
+ if !dc.user.Admin {
+ return ircError{&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_BADMASK,
+ Params: []string{dc.nick, name, "Permission denied to broadcast message to all bouncer users"},
+ }}
+ }
+
+ dc.logger.Printf("broadcasting bouncer-wide %v: %v", msg.Command, text)
+
+ broadcastTags := tags.Copy()
+ broadcastTags["time"] = irc.TagValue(formatServerTime(time.Now()))
+ broadcastMsg := &irc.Message{
+ Tags: broadcastTags,
+ Prefix: servicePrefix,
+ Command: msg.Command,
+ Params: []string{name, text},
+ }
+ dc.srv.forEachUser(func(u *user) {
+ u.events <- eventBroadcast{broadcastMsg}
+ })
+ continue
+ }
+
+ if dc.network == nil && casemapASCII(name) == dc.nickCM {
+ dc.SendMessage(&irc.Message{
+ Tags: msg.Tags.Copy(),
+ Prefix: dc.prefix(),
+ Command: msg.Command,
+ Params: []string{name, text},
+ })
+ continue
+ }
+
+ if msg.Command == "PRIVMSG" && casemapASCII(name) == serviceNickCM {
+ if dc.caps["echo-message"] {
+ echoTags := tags.Copy()
+ echoTags["time"] = irc.TagValue(formatServerTime(time.Now()))
+ dc.SendMessage(&irc.Message{
+ Tags: echoTags,
+ Prefix: dc.prefix(),
+ Command: msg.Command,
+ Params: []string{name, text},
+ })
+ }
+ handleServicePRIVMSG(ctx, dc, text)
+ continue
+ }
+
+ uc, upstreamName, err := dc.unmarshalEntity(name)
+ if err != nil {
+ return err
+ }
+
+ if msg.Command == "PRIVMSG" && uc.network.casemap(upstreamName) == "nickserv" {
+ dc.handleNickServPRIVMSG(ctx, uc, text)
+ }
+
+ unmarshaledText := text
+ if uc.isChannel(upstreamName) {
+ unmarshaledText = dc.unmarshalText(uc, text)
+ }
+ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
+ Tags: tags,
+ Command: msg.Command,
+ Params: []string{upstreamName, unmarshaledText},
+ })
+
+ echoTags := tags.Copy()
+ echoTags["time"] = irc.TagValue(formatServerTime(time.Now()))
+ if uc.account != "" {
+ echoTags["account"] = irc.TagValue(uc.account)
+ }
+ echoMsg := &irc.Message{
+ Tags: echoTags,
+ Prefix: &irc.Prefix{Name: uc.nick},
+ Command: msg.Command,
+ Params: []string{upstreamName, text},
+ }
+ uc.produce(upstreamName, echoMsg, dc)
+
+ uc.updateChannelAutoDetach(upstreamName)
+ }
+ case "TAGMSG":
+ var targetsStr string
+ if err := parseMessageParams(msg, &targetsStr); err != nil {
+ return err
+ }
+ tags := copyClientTags(msg.Tags)
+
+ for _, name := range strings.Split(targetsStr, ",") {
+ if dc.network == nil && casemapASCII(name) == dc.nickCM {
+ dc.SendMessage(&irc.Message{
+ Tags: msg.Tags.Copy(),
+ Prefix: dc.prefix(),
+ Command: "TAGMSG",
+ Params: []string{name},
+ })
+ continue
+ }
+
+ if casemapASCII(name) == serviceNickCM {
+ continue
+ }
+
+ uc, upstreamName, err := dc.unmarshalEntity(name)
+ if err != nil {
+ return err
+ }
+ if _, ok := uc.caps["message-tags"]; !ok {
+ continue
+ }
+
+ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
+ Tags: tags,
+ Command: "TAGMSG",
+ Params: []string{upstreamName},
+ })
+
+ echoTags := tags.Copy()
+ echoTags["time"] = irc.TagValue(formatServerTime(time.Now()))
+ if uc.account != "" {
+ echoTags["account"] = irc.TagValue(uc.account)
+ }
+ echoMsg := &irc.Message{
+ Tags: echoTags,
+ Prefix: &irc.Prefix{Name: uc.nick},
+ Command: "TAGMSG",
+ Params: []string{upstreamName},
+ }
+ uc.produce(upstreamName, echoMsg, dc)
+
+ uc.updateChannelAutoDetach(upstreamName)
+ }
+ case "INVITE":
+ var user, channel string
+ if err := parseMessageParams(msg, &user, &channel); err != nil {
+ return err
+ }
+
+ ucChannel, upstreamChannel, err := dc.unmarshalEntity(channel)
+ if err != nil {
+ return err
+ }
+
+ ucUser, upstreamUser, err := dc.unmarshalEntity(user)
+ if err != nil {
+ return err
+ }
+
+ if ucChannel != ucUser {
+ return ircError{&irc.Message{
+ Command: irc.ERR_USERNOTINCHANNEL,
+ Params: []string{dc.nick, user, channel, "They are on another network"},
+ }}
+ }
+ uc := ucChannel
+
+ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
+ Command: "INVITE",
+ Params: []string{upstreamUser, upstreamChannel},
+ })
+ case "AUTHENTICATE":
+ // Post-connection-registration AUTHENTICATE is unsupported in
+ // multi-upstream mode, or if the upstream doesn't support SASL
+ uc := dc.upstream()
+ if uc == nil || !uc.caps["sasl"] {
+ return ircError{&irc.Message{
+ Command: irc.ERR_SASLFAIL,
+ Params: []string{dc.nick, "Upstream network authentication not supported"},
+ }}
+ }
+
+ credentials, err := dc.handleAuthenticateCommand(msg)
+ if err != nil {
+ return err
+ }
+
+ if credentials != nil {
+ if uc.saslClient != nil {
+ dc.endSASL(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_SASLFAIL,
+ Params: []string{dc.nick, "Another authentication attempt is already in progress"},
+ })
+ return nil
+ }
+
+ uc.logger.Printf("starting post-registration SASL PLAIN authentication with username %q", credentials.plainUsername)
+ uc.saslClient = sasl.NewPlainClient("", credentials.plainUsername, credentials.plainPassword)
+ uc.enqueueCommand(dc, &irc.Message{
+ Command: "AUTHENTICATE",
+ Params: []string{"PLAIN"},
+ })
+ }
+ case "REGISTER", "VERIFY":
+ // Check number of params here, since we'll use that to save the
+ // credentials on command success
+ if (msg.Command == "REGISTER" && len(msg.Params) < 3) || (msg.Command == "VERIFY" && len(msg.Params) < 2) {
+ return newNeedMoreParamsError(msg.Command)
+ }
+
+ uc := dc.upstream()
+ if uc == nil || !uc.caps["draft/account-registration"] {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{msg.Command, "TEMPORARILY_UNAVAILABLE", "*", "Upstream network account registration not supported"},
+ }}
+ }
+
+ uc.logger.Printf("starting %v with account name %v", msg.Command, msg.Params[0])
+ uc.enqueueCommand(dc, msg)
+ case "MONITOR":
+ // MONITOR is unsupported in multi-upstream mode
+ uc := dc.upstream()
+ if uc == nil {
+ return newUnknownCommandError(msg.Command)
+ }
+ if _, ok := uc.isupport["MONITOR"]; !ok {
+ return newUnknownCommandError(msg.Command)
+ }
+
+ var subcommand string
+ if err := parseMessageParams(msg, &subcommand); err != nil {
+ return err
+ }
+
+ switch strings.ToUpper(subcommand) {
+ case "+", "-":
+ var targets string
+ if err := parseMessageParams(msg, nil, &targets); err != nil {
+ return err
+ }
+ for _, target := range strings.Split(targets, ",") {
+ if subcommand == "+" {
+ // Hard limit, just to avoid having downstreams fill our map
+ if len(dc.monitored.innerMap) >= 1000 {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_MONLISTFULL,
+ Params: []string{dc.nick, "1000", target, "Bouncer monitor list is full"},
+ })
+ continue
+ }
+
+ dc.monitored.SetValue(target, nil)
+
+ if uc.monitored.Has(target) {
+ cmd := irc.RPL_MONOFFLINE
+ if online := uc.monitored.Value(target); online {
+ cmd = irc.RPL_MONONLINE
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: cmd,
+ Params: []string{dc.nick, target},
+ })
+ }
+ } else {
+ dc.monitored.Delete(target)
+ }
+ }
+ uc.updateMonitor()
+ case "C": // clear
+ dc.monitored = newCasemapMap(0)
+ uc.updateMonitor()
+ case "L": // list
+ // TODO: be less lazy and pack the list
+ for _, entry := range dc.monitored.innerMap {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_MONLIST,
+ Params: []string{dc.nick, entry.originalKey},
+ })
+ }
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_ENDOFMONLIST,
+ Params: []string{dc.nick, "End of MONITOR list"},
+ })
+ case "S": // status
+ // TODO: be less lazy and pack the lists
+ for _, entry := range dc.monitored.innerMap {
+ target := entry.originalKey
+
+ cmd := irc.RPL_MONOFFLINE
+ if online := uc.monitored.Value(target); online {
+ cmd = irc.RPL_MONONLINE
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: cmd,
+ Params: []string{dc.nick, target},
+ })
+ }
+ }
+ case "CHATHISTORY":
+ var subcommand string
+ if err := parseMessageParams(msg, &subcommand); err != nil {
+ return err
+ }
+ var target, limitStr string
+ var boundsStr [2]string
+ switch subcommand {
+ case "AFTER", "BEFORE", "LATEST":
+ if err := parseMessageParams(msg, nil, &target, &boundsStr[0], &limitStr); err != nil {
+ return err
+ }
+ case "BETWEEN":
+ if err := parseMessageParams(msg, nil, &target, &boundsStr[0], &boundsStr[1], &limitStr); err != nil {
+ return err
+ }
+ case "TARGETS":
+ if dc.network == nil {
+ // Either an unbound bouncer network, in which case we should return no targets,
+ // or a multi-upstream downstream, but we don't support CHATHISTORY TARGETS for those yet.
+ dc.SendBatch("draft/chathistory-targets", nil, nil, func(batchRef irc.TagValue) {})
+ return nil
+ }
+ if err := parseMessageParams(msg, nil, &boundsStr[0], &boundsStr[1], &limitStr); err != nil {
+ return err
+ }
+ default:
+ // TODO: support AROUND
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"CHATHISTORY", "INVALID_PARAMS", subcommand, "Unknown command"},
+ }}
+ }
+
+ // We don't save history for our service
+ if casemapASCII(target) == serviceNickCM {
+ dc.SendBatch("chathistory", []string{target}, nil, func(batchRef irc.TagValue) {})
+ return nil
+ }
+
+ store, ok := dc.user.msgStore.(chatHistoryMessageStore)
+ if !ok {
+ return ircError{&irc.Message{
+ Command: irc.ERR_UNKNOWNCOMMAND,
+ Params: []string{dc.nick, "CHATHISTORY", "Unknown command"},
+ }}
+ }
+
+ network, entity, err := dc.unmarshalEntityNetwork(target)
+ if err != nil {
+ return err
+ }
+ entity = network.casemap(entity)
+
+ // TODO: support msgid criteria
+ var bounds [2]time.Time
+ bounds[0] = parseChatHistoryBound(boundsStr[0])
+ if subcommand == "LATEST" && boundsStr[0] == "*" {
+ bounds[0] = time.Now()
+ } else if bounds[0].IsZero() {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"CHATHISTORY", "INVALID_PARAMS", subcommand, boundsStr[0], "Invalid first bound"},
+ }}
+ }
+
+ if boundsStr[1] != "" {
+ bounds[1] = parseChatHistoryBound(boundsStr[1])
+ if bounds[1].IsZero() {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"CHATHISTORY", "INVALID_PARAMS", subcommand, boundsStr[1], "Invalid second bound"},
+ }}
+ }
+ }
+
+ limit, err := strconv.Atoi(limitStr)
+ if err != nil || limit < 0 || limit > chatHistoryLimit {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"CHATHISTORY", "INVALID_PARAMS", subcommand, limitStr, "Invalid limit"},
+ }}
+ }
+
+ eventPlayback := dc.caps["draft/event-playback"]
+
+ var history []*irc.Message
+ switch subcommand {
+ case "BEFORE", "LATEST":
+ history, err = store.LoadBeforeTime(ctx, &network.Network, entity, bounds[0], time.Time{}, limit, eventPlayback)
+ case "AFTER":
+ history, err = store.LoadAfterTime(ctx, &network.Network, entity, bounds[0], time.Now(), limit, eventPlayback)
+ case "BETWEEN":
+ if bounds[0].Before(bounds[1]) {
+ history, err = store.LoadAfterTime(ctx, &network.Network, entity, bounds[0], bounds[1], limit, eventPlayback)
+ } else {
+ history, err = store.LoadBeforeTime(ctx, &network.Network, entity, bounds[0], bounds[1], limit, eventPlayback)
+ }
+ case "TARGETS":
+ // TODO: support TARGETS in multi-upstream mode
+ targets, err := store.ListTargets(ctx, &network.Network, bounds[0], bounds[1], limit, eventPlayback)
+ if err != nil {
+ dc.logger.Printf("failed fetching targets for chathistory: %v", err)
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"CHATHISTORY", "MESSAGE_ERROR", subcommand, "Failed to retrieve targets"},
+ }}
+ }
+
+ dc.SendBatch("draft/chathistory-targets", nil, nil, func(batchRef irc.TagValue) {
+ for _, target := range targets {
+ if ch := network.channels.Value(target.Name); ch != nil && ch.Detached {
+ continue
+ }
+
+ dc.SendMessage(&irc.Message{
+ Tags: irc.Tags{"batch": batchRef},
+ Prefix: dc.srv.prefix(),
+ Command: "CHATHISTORY",
+ Params: []string{"TARGETS", target.Name, formatServerTime(target.LatestMessage)},
+ })
+ }
+ })
+
+ return nil
+ }
+ if err != nil {
+ dc.logger.Printf("failed fetching %q messages for chathistory: %v", target, err)
+ return newChatHistoryError(subcommand, target)
+ }
+
+ dc.SendBatch("chathistory", []string{target}, nil, func(batchRef irc.TagValue) {
+ for _, msg := range history {
+ msg.Tags["batch"] = batchRef
+ dc.SendMessage(dc.marshalMessage(msg, network))
+ }
+ })
+ case "READ":
+ var target, criteria string
+ if err := parseMessageParams(msg, &target); err != nil {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"READ", "NEED_MORE_PARAMS", "Missing parameters"},
+ }}
+ }
+ if len(msg.Params) > 1 {
+ criteria = msg.Params[1]
+ }
+
+ // We don't save read receipts for our service
+ if casemapASCII(target) == serviceNickCM {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.prefix(),
+ Command: "READ",
+ Params: []string{target, "*"},
+ })
+ return nil
+ }
+
+ uc, entity, err := dc.unmarshalEntity(target)
+ if err != nil {
+ return err
+ }
+ entityCM := uc.network.casemap(entity)
+
+ r, err := dc.srv.db.GetReadReceipt(ctx, uc.network.ID, entityCM)
+ if err != nil {
+ dc.logger.Printf("failed to get the read receipt for %q: %v", entity, err)
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"READ", "INTERNAL_ERROR", target, "Internal error"},
+ }}
+ } else if r == nil {
+ r = &ReadReceipt{
+ Target: entityCM,
+ }
+ }
+
+ broadcast := false
+ if len(criteria) > 0 {
+ // TODO: support msgid criteria
+ criteriaParts := strings.SplitN(criteria, "=", 2)
+ if len(criteriaParts) != 2 || criteriaParts[0] != "timestamp" {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"READ", "INVALID_PARAMS", criteria, "Unknown criteria"},
+ }}
+ }
+
+ timestamp, err := time.Parse(serverTimeLayout, criteriaParts[1])
+ if err != nil {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"READ", "INVALID_PARAMS", criteria, "Invalid criteria"},
+ }}
+ }
+ now := time.Now()
+ if timestamp.After(now) {
+ timestamp = now
+ }
+ if r.Timestamp.Before(timestamp) {
+ r.Timestamp = timestamp
+ if err := dc.srv.db.StoreReadReceipt(ctx, uc.network.ID, r); err != nil {
+ dc.logger.Printf("failed to store receipt for %q: %v", entity, err)
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"READ", "INTERNAL_ERROR", target, "Internal error"},
+ }}
+ }
+ broadcast = true
+ }
+ }
+
+ timestampStr := "*"
+ if !r.Timestamp.IsZero() {
+ timestampStr = fmt.Sprintf("timestamp=%s", formatServerTime(r.Timestamp))
+ }
+ uc.forEachDownstream(func(d *downstreamConn) {
+ if broadcast || dc.id == d.id {
+ d.SendMessage(&irc.Message{
+ Prefix: d.prefix(),
+ Command: "READ",
+ Params: []string{d.marshalEntity(uc.network, entity), timestampStr},
+ })
+ }
+ })
+ case "BOUNCER":
+ var subcommand string
+ if err := parseMessageParams(msg, &subcommand); err != nil {
+ return err
+ }
+
+ switch strings.ToUpper(subcommand) {
+ case "BIND":
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"BOUNCER", "REGISTRATION_IS_COMPLETED", "BIND", "Cannot bind to a network after registration"},
+ }}
+ case "LISTNETWORKS":
+ dc.SendBatch("soju.im/bouncer-networks", nil, nil, func(batchRef irc.TagValue) {
+ for _, network := range dc.user.networks {
+ idStr := fmt.Sprintf("%v", network.ID)
+ attrs := getNetworkAttrs(network)
+ dc.SendMessage(&irc.Message{
+ Tags: irc.Tags{"batch": batchRef},
+ Prefix: dc.srv.prefix(),
+ Command: "BOUNCER",
+ Params: []string{"NETWORK", idStr, attrs.String()},
+ })
+ }
+ })
+ case "ADDNETWORK":
+ var attrsStr string
+ if err := parseMessageParams(msg, nil, &attrsStr); err != nil {
+ return err
+ }
+ attrs := irc.ParseTags(attrsStr)
+
+ record := &Network{Nick: dc.nick, Enabled: true}
+ if err := updateNetworkAttrs(record, attrs, subcommand); err != nil {
+ return err
+ }
+
+ if record.Nick == dc.user.Username {
+ record.Nick = ""
+ }
+ if record.Realname == dc.user.Realname {
+ record.Realname = ""
+ }
+
+ network, err := dc.user.createNetwork(ctx, record)
+ if err != nil {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"BOUNCER", "UNKNOWN_ERROR", subcommand, fmt.Sprintf("Failed to create network: %v", err)},
+ }}
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "BOUNCER",
+ Params: []string{"ADDNETWORK", fmt.Sprintf("%v", network.ID)},
+ })
+ case "CHANGENETWORK":
+ var idStr, attrsStr string
+ if err := parseMessageParams(msg, nil, &idStr, &attrsStr); err != nil {
+ return err
+ }
+ id, err := parseBouncerNetID(subcommand, idStr)
+ if err != nil {
+ return err
+ }
+ attrs := irc.ParseTags(attrsStr)
+
+ net := dc.user.getNetworkByID(id)
+ if net == nil {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"BOUNCER", "INVALID_NETID", subcommand, idStr, "Invalid network ID"},
+ }}
+ }
+
+ record := net.Network // copy network record because we'll mutate it
+ if err := updateNetworkAttrs(&record, attrs, subcommand); err != nil {
+ return err
+ }
+
+ if record.Nick == dc.user.Username {
+ record.Nick = ""
+ }
+ if record.Realname == dc.user.Realname {
+ record.Realname = ""
+ }
+
+ _, err = dc.user.updateNetwork(ctx, &record)
+ if err != nil {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"BOUNCER", "UNKNOWN_ERROR", subcommand, fmt.Sprintf("Failed to update network: %v", err)},
+ }}
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "BOUNCER",
+ Params: []string{"CHANGENETWORK", idStr},
+ })
+ case "DELNETWORK":
+ var idStr string
+ if err := parseMessageParams(msg, nil, &idStr); err != nil {
+ return err
+ }
+ id, err := parseBouncerNetID(subcommand, idStr)
+ if err != nil {
+ return err
+ }
+
+ net := dc.user.getNetworkByID(id)
+ if net == nil {
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"BOUNCER", "INVALID_NETID", subcommand, idStr, "Invalid network ID"},
+ }}
+ }
+
+ if err := dc.user.deleteNetwork(ctx, net.ID); err != nil {
+ return err
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "BOUNCER",
+ Params: []string{"DELNETWORK", idStr},
+ })
+ default:
+ return ircError{&irc.Message{
+ Command: "FAIL",
+ Params: []string{"BOUNCER", "UNKNOWN_COMMAND", subcommand, "Unknown subcommand"},
+ }}
+ }
+ default:
+ dc.logger.Printf("unhandled message: %v", msg)
+
+ // Only forward unknown commands in single-upstream mode
+ uc := dc.upstream()
+ if uc == nil {
+ return newUnknownCommandError(msg.Command)
+ }
+
+ uc.SendMessageLabeled(ctx, dc.id, msg)
+ }
+ return nil
+}
+
+func (dc *downstreamConn) handleNickServPRIVMSG(ctx context.Context, uc *upstreamConn, text string) {
+ username, password, ok := parseNickServCredentials(text, uc.nick)
+ if ok {
+ uc.network.autoSaveSASLPlain(ctx, username, password)
+ }
+}
+
+func parseNickServCredentials(text, nick string) (username, password string, ok bool) {
+ fields := strings.Fields(text)
+ if len(fields) < 2 {
+ return "", "", false
+ }
+ cmd := strings.ToUpper(fields[0])
+ params := fields[1:]
+ switch cmd {
+ case "REGISTER":
+ username = nick
+ password = params[0]
+ case "IDENTIFY":
+ if len(params) == 1 {
+ username = nick
+ password = params[0]
+ } else {
+ username = params[0]
+ password = params[1]
+ }
+ case "SET":
+ if len(params) == 2 && strings.EqualFold(params[0], "PASSWORD") {
+ username = nick
+ password = params[1]
+ }
+ default:
+ return "", "", false
+ }
+ return username, password, true
+}
--- /dev/null
+module marisa.chaotic.ninja/suika
+
+go 1.20
+
+require (
+ git.sr.ht/~emersion/go-scfg v0.0.0-20211215104734-c2c7a15d6c99
+ git.sr.ht/~sircmpwn/go-bare v0.0.0-20210406120253-ab86bc2846d9
+ github.com/emersion/go-sasl v0.0.0-20220912192320-0145f2c60ead
+ github.com/lib/pq v1.10.7
+ golang.org/x/crypto v0.7.0
+ golang.org/x/term v0.6.0
+ golang.org/x/time v0.3.0
+ gopkg.in/irc.v3 v3.1.4
+ modernc.org/sqlite v1.21.0
+)
+
+require (
+ github.com/dustin/go-humanize v1.0.0 // indirect
+ github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect
+ github.com/google/uuid v1.3.0 // indirect
+ github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect
+ github.com/mattn/go-isatty v0.0.16 // indirect
+ github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
+ github.com/stretchr/testify v1.8.0 // indirect
+ golang.org/x/mod v0.3.0 // indirect
+ golang.org/x/sys v0.6.0 // indirect
+ golang.org/x/tools v0.0.0-20201124115921-2c860bdd6e78 // indirect
+ golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
+ gopkg.in/yaml.v2 v2.4.0 // indirect
+ lukechampine.com/uint128 v1.2.0 // indirect
+ modernc.org/cc/v3 v3.40.0 // indirect
+ modernc.org/ccgo/v3 v3.16.13 // indirect
+ modernc.org/libc v1.22.3 // indirect
+ modernc.org/mathutil v1.5.0 // indirect
+ modernc.org/memory v1.5.0 // indirect
+ modernc.org/opt v0.1.3 // indirect
+ modernc.org/strutil v1.1.3 // indirect
+ modernc.org/token v1.0.1 // indirect
+)
--- /dev/null
+git.sr.ht/~emersion/go-scfg v0.0.0-20211215104734-c2c7a15d6c99 h1:1s8n5uisqkR+BzPgaum6xxIjKmzGrTykJdh+Y3f5Xao=
+git.sr.ht/~emersion/go-scfg v0.0.0-20211215104734-c2c7a15d6c99/go.mod h1:t+Ww6SR24yYnXzEWiNlOY0AFo5E9B73X++10lrSpp4U=
+git.sr.ht/~sircmpwn/getopt v0.0.0-20191230200459-23622cc906b3/go.mod h1:wMEGFFFNuPos7vHmWXfszqImLppbc0wEhh6JBfJIUgw=
+git.sr.ht/~sircmpwn/go-bare v0.0.0-20210406120253-ab86bc2846d9 h1:Ahny8Ud1LjVMMAlt8utUFKhhxJtwBAualvsbc/Sk7cE=
+git.sr.ht/~sircmpwn/go-bare v0.0.0-20210406120253-ab86bc2846d9/go.mod h1:BVJwbDfVjCjoFiKrhkei6NdGcZYpkDkdyCdg1ukytRA=
+github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
+github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo=
+github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
+github.com/emersion/go-sasl v0.0.0-20220912192320-0145f2c60ead h1:fI1Jck0vUrXT8bnphprS1EoVRe2Q5CKCX8iDlpqjQ/Y=
+github.com/emersion/go-sasl v0.0.0-20220912192320-0145f2c60ead/go.mod h1:iL2twTeMvZnrg54ZoPDNfJaJaqy0xIQFuBdrLsmspwQ=
+github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
+github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ=
+github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4=
+github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ=
+github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
+github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
+github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs=
+github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8=
+github.com/lib/pq v1.10.7 h1:p7ZhMD+KsSRozJr34udlUrhboJwWAgCg34+/ZZNvZZw=
+github.com/lib/pq v1.10.7/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
+github.com/mattn/go-isatty v0.0.16 h1:bq3VjFmv/sOjHtdEhmkEV4x1AJtvUvOJ2PFAZ5+peKQ=
+github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
+github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y=
+github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
+github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
+github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
+github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
+github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
+github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
+github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
+github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
+github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
+github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
+github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk=
+github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
+github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
+golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
+golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
+golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
+golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A=
+golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU=
+golang.org/x/mod v0.3.0 h1:RM4zey1++hCTbCVQfnWeKs9/IEsaBLA8vTkd0WVtmH4=
+golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
+golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
+golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
+golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
+golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
+golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ=
+golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/term v0.6.0 h1:clScbb1cHjoCkyRbWwBEUZ5H/tIFu5TAXIqaZD0Gcjw=
+golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U=
+golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
+golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
+golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4=
+golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
+golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
+golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
+golang.org/x/tools v0.0.0-20201124115921-2c860bdd6e78 h1:M8tBwCtWD/cZV9DZpFYRUgaymAYAr+aIUTWzDaM3uPs=
+golang.org/x/tools v0.0.0-20201124115921-2c860bdd6e78/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
+golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
+golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/irc.v3 v3.1.4 h1:DYGMRFbtseXEh+NadmMUFzMraqyuUj4I3iWYFEzDZPc=
+gopkg.in/irc.v3 v3.1.4/go.mod h1:shO2gz8+PVeS+4E6GAny88Z0YVVQSxQghdrMVGQsR9s=
+gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
+gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
+gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
+gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
+gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
+gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+lukechampine.com/uint128 v1.2.0 h1:mBi/5l91vocEN8otkC5bDLhi2KdCticRiwbdB0O+rjI=
+lukechampine.com/uint128 v1.2.0/go.mod h1:c4eWIwlEGaxC/+H1VguhU4PHXNWDCDMUlWdIWl2j1gk=
+modernc.org/cc/v3 v3.40.0 h1:P3g79IUS/93SYhtoeaHW+kRCIrYaxJ27MFPv+7kaTOw=
+modernc.org/cc/v3 v3.40.0/go.mod h1:/bTg4dnWkSXowUO6ssQKnOV0yMVxDYNIsIrzqTFDGH0=
+modernc.org/ccgo/v3 v3.16.13 h1:Mkgdzl46i5F/CNR/Kj80Ri59hC8TKAhZrYSaqvkwzUw=
+modernc.org/ccgo/v3 v3.16.13/go.mod h1:2Quk+5YgpImhPjv2Qsob1DnZ/4som1lJTodubIcoUkY=
+modernc.org/ccorpus v1.11.6 h1:J16RXiiqiCgua6+ZvQot4yUuUy8zxgqbqEEUuGPlISk=
+modernc.org/httpfs v1.0.6 h1:AAgIpFZRXuYnkjftxTAZwMIiwEqAfk8aVB2/oA6nAeM=
+modernc.org/libc v1.22.3 h1:D/g6O5ftAfavceqlLOFwaZuA5KYafKwmr30A6iSqoyY=
+modernc.org/libc v1.22.3/go.mod h1:MQrloYP209xa2zHome2a8HLiLm6k0UT8CoHpV74tOFw=
+modernc.org/mathutil v1.5.0 h1:rV0Ko/6SfM+8G+yKiyI830l3Wuz1zRutdslNoQ0kfiQ=
+modernc.org/mathutil v1.5.0/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E=
+modernc.org/memory v1.5.0 h1:N+/8c5rE6EqugZwHii4IFsaJ7MUhoWX07J5tC/iI5Ds=
+modernc.org/memory v1.5.0/go.mod h1:PkUhL0Mugw21sHPeskwZW4D6VscE/GQJOnIpCnW6pSU=
+modernc.org/opt v0.1.3 h1:3XOZf2yznlhC+ibLltsDGzABUGVx8J6pnFMS3E4dcq4=
+modernc.org/opt v0.1.3/go.mod h1:WdSiB5evDcignE70guQKxYUl14mgWtbClRi5wmkkTX0=
+modernc.org/sqlite v1.21.0 h1:4aP4MdUf15i3R3M2mx6Q90WHKz3nZLoz96zlB6tNdow=
+modernc.org/sqlite v1.21.0/go.mod h1:XwQ0wZPIh1iKb5mkvCJ3szzbhk+tykC8ZWqTRTgYRwI=
+modernc.org/strutil v1.1.3 h1:fNMm+oJklMGYfU9Ylcywl0CO5O6nTfaowNsh2wpPjzY=
+modernc.org/strutil v1.1.3/go.mod h1:MEHNA7PdEnEwLvspRMtWTNnp2nnyvMfkimT1NKNAGbw=
+modernc.org/tcl v1.15.1 h1:mOQwiEK4p7HruMZcwKTZPw/aqtGM4aY00uzWhlKKYws=
+modernc.org/token v1.0.1 h1:A3qvTqOwexpfZZeyI0FeGPDlSWX5pjZu9hF4lU+EKWg=
+modernc.org/token v1.0.1/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=
+modernc.org/z v1.7.0 h1:xkDw/KepgEjeizO2sNco+hqYkU12taxQFqPEmgm1GWE=
--- /dev/null
+package suika
+
+import (
+ "fmt"
+ "sort"
+ "strings"
+ "time"
+ "unicode"
+ "unicode/utf8"
+
+ "gopkg.in/irc.v3"
+)
+
+const (
+ rpl_statsping = "246"
+ rpl_localusers = "265"
+ rpl_globalusers = "266"
+ rpl_creationtime = "329"
+ rpl_topicwhotime = "333"
+ rpl_whospcrpl = "354"
+ rpl_whoisaccount = "330"
+ err_invalidcapcmd = "410"
+)
+
+const (
+ maxMessageLength = 512
+ maxMessageParams = 15
+ maxSASLLength = 400
+)
+
+// The server-time layout, as defined in the IRCv3 spec.
+const serverTimeLayout = "2006-01-02T15:04:05.000Z"
+
+func formatServerTime(t time.Time) string {
+ return t.UTC().Format(serverTimeLayout)
+}
+
+type userModes string
+
+func (ms userModes) Has(c byte) bool {
+ return strings.IndexByte(string(ms), c) >= 0
+}
+
+func (ms *userModes) Add(c byte) {
+ if !ms.Has(c) {
+ *ms += userModes(c)
+ }
+}
+
+func (ms *userModes) Del(c byte) {
+ i := strings.IndexByte(string(*ms), c)
+ if i >= 0 {
+ *ms = (*ms)[:i] + (*ms)[i+1:]
+ }
+}
+
+func (ms *userModes) Apply(s string) error {
+ var plusMinus byte
+ for i := 0; i < len(s); i++ {
+ switch c := s[i]; c {
+ case '+', '-':
+ plusMinus = c
+ default:
+ switch plusMinus {
+ case '+':
+ ms.Add(c)
+ case '-':
+ ms.Del(c)
+ default:
+ return fmt.Errorf("malformed modestring %q: missing plus/minus", s)
+ }
+ }
+ }
+ return nil
+}
+
+type channelModeType byte
+
+// standard channel mode types, as explained in https://modern.ircdocs.horse/#mode-message
+const (
+ // modes that add or remove an address to or from a list
+ modeTypeA channelModeType = iota
+ // modes that change a setting on a channel, and must always have a parameter
+ modeTypeB
+ // modes that change a setting on a channel, and must have a parameter when being set, and no parameter when being unset
+ modeTypeC
+ // modes that change a setting on a channel, and must not have a parameter
+ modeTypeD
+)
+
+var stdChannelModes = map[byte]channelModeType{
+ 'b': modeTypeA, // ban list
+ 'e': modeTypeA, // ban exception list
+ 'I': modeTypeA, // invite exception list
+ 'k': modeTypeB, // channel key
+ 'l': modeTypeC, // channel user limit
+ 'i': modeTypeD, // channel is invite-only
+ 'm': modeTypeD, // channel is moderated
+ 'n': modeTypeD, // channel has no external messages
+ 's': modeTypeD, // channel is secret
+ 't': modeTypeD, // channel has protected topic
+}
+
+type channelModes map[byte]string
+
+// applyChannelModes parses a mode string and mode arguments from a MODE message,
+// and applies the corresponding channel mode and user membership changes on that channel.
+//
+// If ch.modes is nil, channel modes are not updated.
+//
+// needMarshaling is a list of indexes of mode arguments that represent entities
+// that must be marshaled when sent downstream.
+func applyChannelModes(ch *upstreamChannel, modeStr string, arguments []string) (needMarshaling map[int]struct{}, err error) {
+ needMarshaling = make(map[int]struct{}, len(arguments))
+ nextArgument := 0
+ var plusMinus byte
+outer:
+ for i := 0; i < len(modeStr); i++ {
+ mode := modeStr[i]
+ if mode == '+' || mode == '-' {
+ plusMinus = mode
+ continue
+ }
+ if plusMinus != '+' && plusMinus != '-' {
+ return nil, fmt.Errorf("malformed modestring %q: missing plus/minus", modeStr)
+ }
+
+ for _, membership := range ch.conn.availableMemberships {
+ if membership.Mode == mode {
+ if nextArgument >= len(arguments) {
+ return nil, fmt.Errorf("malformed modestring %q: missing mode argument for %c%c", modeStr, plusMinus, mode)
+ }
+ member := arguments[nextArgument]
+ m := ch.Members.Value(member)
+ if m != nil {
+ if plusMinus == '+' {
+ m.Add(ch.conn.availableMemberships, membership)
+ } else {
+ // TODO: for upstreams without multi-prefix, query the user modes again
+ m.Remove(membership)
+ }
+ }
+ needMarshaling[nextArgument] = struct{}{}
+ nextArgument++
+ continue outer
+ }
+ }
+
+ mt, ok := ch.conn.availableChannelModes[mode]
+ if !ok {
+ continue
+ }
+ if mt == modeTypeA {
+ nextArgument++
+ } else if mt == modeTypeB || (mt == modeTypeC && plusMinus == '+') {
+ if plusMinus == '+' {
+ var argument string
+ // some sentitive arguments (such as channel keys) can be omitted for privacy
+ // (this will only happen for RPL_CHANNELMODEIS, never for MODE messages)
+ if nextArgument < len(arguments) {
+ argument = arguments[nextArgument]
+ }
+ if ch.modes != nil {
+ ch.modes[mode] = argument
+ }
+ } else {
+ delete(ch.modes, mode)
+ }
+ nextArgument++
+ } else if mt == modeTypeC || mt == modeTypeD {
+ if plusMinus == '+' {
+ if ch.modes != nil {
+ ch.modes[mode] = ""
+ }
+ } else {
+ delete(ch.modes, mode)
+ }
+ }
+ }
+ return needMarshaling, nil
+}
+
+func (cm channelModes) Format() (modeString string, parameters []string) {
+ var modesWithValues strings.Builder
+ var modesWithoutValues strings.Builder
+ parameters = make([]string, 0, 16)
+ for mode, value := range cm {
+ if value != "" {
+ modesWithValues.WriteString(string(mode))
+ parameters = append(parameters, value)
+ } else {
+ modesWithoutValues.WriteString(string(mode))
+ }
+ }
+ modeString = "+" + modesWithValues.String() + modesWithoutValues.String()
+ return
+}
+
+const stdChannelTypes = "#&+!"
+
+type channelStatus byte
+
+const (
+ channelPublic channelStatus = '='
+ channelSecret channelStatus = '@'
+ channelPrivate channelStatus = '*'
+)
+
+func parseChannelStatus(s string) (channelStatus, error) {
+ if len(s) > 1 {
+ return 0, fmt.Errorf("invalid channel status %q: more than one character", s)
+ }
+ switch cs := channelStatus(s[0]); cs {
+ case channelPublic, channelSecret, channelPrivate:
+ return cs, nil
+ default:
+ return 0, fmt.Errorf("invalid channel status %q: unknown status", s)
+ }
+}
+
+type membership struct {
+ Mode byte
+ Prefix byte
+}
+
+var stdMemberships = []membership{
+ {'q', '~'}, // founder
+ {'a', '&'}, // protected
+ {'o', '@'}, // operator
+ {'h', '%'}, // halfop
+ {'v', '+'}, // voice
+}
+
+// memberships always sorted by descending membership rank
+type memberships []membership
+
+func (m *memberships) Add(availableMemberships []membership, newMembership membership) {
+ l := *m
+ i := 0
+ for _, availableMembership := range availableMemberships {
+ if i >= len(l) {
+ break
+ }
+ if l[i] == availableMembership {
+ if availableMembership == newMembership {
+ // we already have this membership
+ return
+ }
+ i++
+ continue
+ }
+ if availableMembership == newMembership {
+ break
+ }
+ }
+ // insert newMembership at i
+ l = append(l, membership{})
+ copy(l[i+1:], l[i:])
+ l[i] = newMembership
+ *m = l
+}
+
+func (m *memberships) Remove(oldMembership membership) {
+ l := *m
+ for i, currentMembership := range l {
+ if currentMembership == oldMembership {
+ *m = append(l[:i], l[i+1:]...)
+ return
+ }
+ }
+}
+
+func (m memberships) Format(dc *downstreamConn) string {
+ if !dc.caps["multi-prefix"] {
+ if len(m) == 0 {
+ return ""
+ }
+ return string(m[0].Prefix)
+ }
+ prefixes := make([]byte, len(m))
+ for i, membership := range m {
+ prefixes[i] = membership.Prefix
+ }
+ return string(prefixes)
+}
+
+func parseMessageParams(msg *irc.Message, out ...*string) error {
+ if len(msg.Params) < len(out) {
+ return newNeedMoreParamsError(msg.Command)
+ }
+ for i := range out {
+ if out[i] != nil {
+ *out[i] = msg.Params[i]
+ }
+ }
+ return nil
+}
+
+func copyClientTags(tags irc.Tags) irc.Tags {
+ t := make(irc.Tags, len(tags))
+ for k, v := range tags {
+ if strings.HasPrefix(k, "+") {
+ t[k] = v
+ }
+ }
+ return t
+}
+
+type batch struct {
+ Type string
+ Params []string
+ Outer *batch // if not-nil, this batch is nested in Outer
+ Label string
+}
+
+func join(channels, keys []string) []*irc.Message {
+ // Put channels with a key first
+ js := joinSorter{channels, keys}
+ sort.Sort(&js)
+
+ // Two spaces because there are three words (JOIN, channels and keys)
+ maxLength := maxMessageLength - (len("JOIN") + 2)
+
+ var msgs []*irc.Message
+ var channelsBuf, keysBuf strings.Builder
+ for i, channel := range channels {
+ key := keys[i]
+
+ n := channelsBuf.Len() + keysBuf.Len() + 1 + len(channel)
+ if key != "" {
+ n += 1 + len(key)
+ }
+
+ if channelsBuf.Len() > 0 && n > maxLength {
+ // No room for the new channel in this message
+ params := []string{channelsBuf.String()}
+ if keysBuf.Len() > 0 {
+ params = append(params, keysBuf.String())
+ }
+ msgs = append(msgs, &irc.Message{Command: "JOIN", Params: params})
+ channelsBuf.Reset()
+ keysBuf.Reset()
+ }
+
+ if channelsBuf.Len() > 0 {
+ channelsBuf.WriteByte(',')
+ }
+ channelsBuf.WriteString(channel)
+ if key != "" {
+ if keysBuf.Len() > 0 {
+ keysBuf.WriteByte(',')
+ }
+ keysBuf.WriteString(key)
+ }
+ }
+ if channelsBuf.Len() > 0 {
+ params := []string{channelsBuf.String()}
+ if keysBuf.Len() > 0 {
+ params = append(params, keysBuf.String())
+ }
+ msgs = append(msgs, &irc.Message{Command: "JOIN", Params: params})
+ }
+
+ return msgs
+}
+
+func generateIsupport(prefix *irc.Prefix, nick string, tokens []string) []*irc.Message {
+ maxTokens := maxMessageParams - 2 // 2 reserved params: nick + text
+
+ var msgs []*irc.Message
+ for len(tokens) > 0 {
+ var msgTokens []string
+ if len(tokens) > maxTokens {
+ msgTokens = tokens[:maxTokens]
+ tokens = tokens[maxTokens:]
+ } else {
+ msgTokens = tokens
+ tokens = nil
+ }
+
+ msgs = append(msgs, &irc.Message{
+ Prefix: prefix,
+ Command: irc.RPL_ISUPPORT,
+ Params: append(append([]string{nick}, msgTokens...), "are supported"),
+ })
+ }
+
+ return msgs
+}
+
+func generateMOTD(prefix *irc.Prefix, nick string, motd string) []*irc.Message {
+ var msgs []*irc.Message
+ msgs = append(msgs, &irc.Message{
+ Prefix: prefix,
+ Command: irc.RPL_MOTDSTART,
+ Params: []string{nick, fmt.Sprintf("- Message of the Day -")},
+ })
+
+ for _, l := range strings.Split(motd, "\n") {
+ msgs = append(msgs, &irc.Message{
+ Prefix: prefix,
+ Command: irc.RPL_MOTD,
+ Params: []string{nick, l},
+ })
+ }
+
+ msgs = append(msgs, &irc.Message{
+ Prefix: prefix,
+ Command: irc.RPL_ENDOFMOTD,
+ Params: []string{nick, "End of /MOTD command."},
+ })
+
+ return msgs
+}
+
+func generateMonitor(subcmd string, targets []string) []*irc.Message {
+ maxLength := maxMessageLength - len("MONITOR "+subcmd+" ")
+
+ var msgs []*irc.Message
+ var buf []string
+ n := 0
+ for _, target := range targets {
+ if n+len(target)+1 > maxLength {
+ msgs = append(msgs, &irc.Message{
+ Command: "MONITOR",
+ Params: []string{subcmd, strings.Join(buf, ",")},
+ })
+ buf = buf[:0]
+ n = 0
+ }
+
+ buf = append(buf, target)
+ n += len(target) + 1
+ }
+
+ if len(buf) > 0 {
+ msgs = append(msgs, &irc.Message{
+ Command: "MONITOR",
+ Params: []string{subcmd, strings.Join(buf, ",")},
+ })
+ }
+
+ return msgs
+}
+
+type joinSorter struct {
+ channels []string
+ keys []string
+}
+
+func (js *joinSorter) Len() int {
+ return len(js.channels)
+}
+
+func (js *joinSorter) Less(i, j int) bool {
+ if (js.keys[i] != "") != (js.keys[j] != "") {
+ // Only one of the channels has a key
+ return js.keys[i] != ""
+ }
+ return js.channels[i] < js.channels[j]
+}
+
+func (js *joinSorter) Swap(i, j int) {
+ js.channels[i], js.channels[j] = js.channels[j], js.channels[i]
+ js.keys[i], js.keys[j] = js.keys[j], js.keys[i]
+}
+
+// parseCTCPMessage parses a CTCP message. CTCP is defined in
+// https://tools.ietf.org/html/draft-oakley-irc-ctcp-02
+func parseCTCPMessage(msg *irc.Message) (cmd string, params string, ok bool) {
+ if (msg.Command != "PRIVMSG" && msg.Command != "NOTICE") || len(msg.Params) < 2 {
+ return "", "", false
+ }
+ text := msg.Params[1]
+
+ if !strings.HasPrefix(text, "\x01") {
+ return "", "", false
+ }
+ text = strings.Trim(text, "\x01")
+
+ words := strings.SplitN(text, " ", 2)
+ cmd = strings.ToUpper(words[0])
+ if len(words) > 1 {
+ params = words[1]
+ }
+
+ return cmd, params, true
+}
+
+type casemapping func(string) string
+
+func casemapNone(name string) string {
+ return name
+}
+
+// CasemapASCII of name is the canonical representation of name according to the
+// ascii casemapping.
+func casemapASCII(name string) string {
+ nameBytes := []byte(name)
+ for i, r := range nameBytes {
+ if 'A' <= r && r <= 'Z' {
+ nameBytes[i] = r + 'a' - 'A'
+ }
+ }
+ return string(nameBytes)
+}
+
+// casemapRFC1459 of name is the canonical representation of name according to the
+// rfc1459 casemapping.
+func casemapRFC1459(name string) string {
+ nameBytes := []byte(name)
+ for i, r := range nameBytes {
+ if 'A' <= r && r <= 'Z' {
+ nameBytes[i] = r + 'a' - 'A'
+ } else if r == '{' {
+ nameBytes[i] = '['
+ } else if r == '}' {
+ nameBytes[i] = ']'
+ } else if r == '\\' {
+ nameBytes[i] = '|'
+ } else if r == '~' {
+ nameBytes[i] = '^'
+ }
+ }
+ return string(nameBytes)
+}
+
+// casemapRFC1459Strict of name is the canonical representation of name
+// according to the rfc1459-strict casemapping.
+func casemapRFC1459Strict(name string) string {
+ nameBytes := []byte(name)
+ for i, r := range nameBytes {
+ if 'A' <= r && r <= 'Z' {
+ nameBytes[i] = r + 'a' - 'A'
+ } else if r == '{' {
+ nameBytes[i] = '['
+ } else if r == '}' {
+ nameBytes[i] = ']'
+ } else if r == '\\' {
+ nameBytes[i] = '|'
+ }
+ }
+ return string(nameBytes)
+}
+
+func parseCasemappingToken(tokenValue string) (casemap casemapping, ok bool) {
+ switch tokenValue {
+ case "ascii":
+ casemap = casemapASCII
+ case "rfc1459":
+ casemap = casemapRFC1459
+ case "rfc1459-strict":
+ casemap = casemapRFC1459Strict
+ default:
+ return nil, false
+ }
+ return casemap, true
+}
+
+func partialCasemap(higher casemapping, name string) string {
+ nameFullyCM := []byte(higher(name))
+ nameBytes := []byte(name)
+ for i, r := range nameBytes {
+ if !('A' <= r && r <= 'Z') && !('a' <= r && r <= 'z') {
+ nameBytes[i] = nameFullyCM[i]
+ }
+ }
+ return string(nameBytes)
+}
+
+type casemapMap struct {
+ innerMap map[string]casemapEntry
+ casemap casemapping
+}
+
+type casemapEntry struct {
+ originalKey string
+ value interface{}
+}
+
+func newCasemapMap(size int) casemapMap {
+ return casemapMap{
+ innerMap: make(map[string]casemapEntry, size),
+ casemap: casemapNone,
+ }
+}
+
+func (cm *casemapMap) OriginalKey(name string) (key string, ok bool) {
+ entry, ok := cm.innerMap[cm.casemap(name)]
+ if !ok {
+ return "", false
+ }
+ return entry.originalKey, true
+}
+
+func (cm *casemapMap) Has(name string) bool {
+ _, ok := cm.innerMap[cm.casemap(name)]
+ return ok
+}
+
+func (cm *casemapMap) Len() int {
+ return len(cm.innerMap)
+}
+
+func (cm *casemapMap) SetValue(name string, value interface{}) {
+ nameCM := cm.casemap(name)
+ entry, ok := cm.innerMap[nameCM]
+ if !ok {
+ cm.innerMap[nameCM] = casemapEntry{
+ originalKey: name,
+ value: value,
+ }
+ return
+ }
+ entry.value = value
+ cm.innerMap[nameCM] = entry
+}
+
+func (cm *casemapMap) Delete(name string) {
+ delete(cm.innerMap, cm.casemap(name))
+}
+
+func (cm *casemapMap) SetCasemapping(newCasemap casemapping) {
+ cm.casemap = newCasemap
+ newInnerMap := make(map[string]casemapEntry, len(cm.innerMap))
+ for _, entry := range cm.innerMap {
+ newInnerMap[cm.casemap(entry.originalKey)] = entry
+ }
+ cm.innerMap = newInnerMap
+}
+
+type upstreamChannelCasemapMap struct{ casemapMap }
+
+func (cm *upstreamChannelCasemapMap) Value(name string) *upstreamChannel {
+ entry, ok := cm.innerMap[cm.casemap(name)]
+ if !ok {
+ return nil
+ }
+ return entry.value.(*upstreamChannel)
+}
+
+type channelCasemapMap struct{ casemapMap }
+
+func (cm *channelCasemapMap) Value(name string) *Channel {
+ entry, ok := cm.innerMap[cm.casemap(name)]
+ if !ok {
+ return nil
+ }
+ return entry.value.(*Channel)
+}
+
+type membershipsCasemapMap struct{ casemapMap }
+
+func (cm *membershipsCasemapMap) Value(name string) *memberships {
+ entry, ok := cm.innerMap[cm.casemap(name)]
+ if !ok {
+ return nil
+ }
+ return entry.value.(*memberships)
+}
+
+type deliveredCasemapMap struct{ casemapMap }
+
+func (cm *deliveredCasemapMap) Value(name string) deliveredClientMap {
+ entry, ok := cm.innerMap[cm.casemap(name)]
+ if !ok {
+ return nil
+ }
+ return entry.value.(deliveredClientMap)
+}
+
+type monitorCasemapMap struct{ casemapMap }
+
+func (cm *monitorCasemapMap) Value(name string) (online bool) {
+ entry, ok := cm.innerMap[cm.casemap(name)]
+ if !ok {
+ return false
+ }
+ return entry.value.(bool)
+}
+
+func isWordBoundary(r rune) bool {
+ switch r {
+ case '-', '_', '|': // inspired from weechat.look.highlight_regex
+ return false
+ default:
+ return !unicode.IsLetter(r) && !unicode.IsNumber(r)
+ }
+}
+
+func isHighlight(text, nick string) bool {
+ for {
+ i := strings.Index(text, nick)
+ if i < 0 {
+ return false
+ }
+
+ left, _ := utf8.DecodeLastRuneInString(text[:i])
+ right, _ := utf8.DecodeRuneInString(text[i+len(nick):])
+ if isWordBoundary(left) && isWordBoundary(right) {
+ return true
+ }
+
+ text = text[i+len(nick):]
+ }
+}
+
+// parseChatHistoryBound parses the given CHATHISTORY parameter as a bound.
+// The zero time is returned on error.
+func parseChatHistoryBound(param string) time.Time {
+ parts := strings.SplitN(param, "=", 2)
+ if len(parts) != 2 {
+ return time.Time{}
+ }
+ switch parts[0] {
+ case "timestamp":
+ timestamp, err := time.Parse(serverTimeLayout, parts[1])
+ if err != nil {
+ return time.Time{}
+ }
+ return timestamp
+ default:
+ return time.Time{}
+ }
+}
+
+// whoxFields is the list of all WHOX field letters, by order of appearance in
+// RPL_WHOSPCRPL messages.
+var whoxFields = []byte("tcuihsnfdlaor")
+
+type whoxInfo struct {
+ Token string
+ Username string
+ Hostname string
+ Server string
+ Nickname string
+ Flags string
+ Account string
+ Realname string
+}
+
+func (info *whoxInfo) get(field byte) string {
+ switch field {
+ case 't':
+ return info.Token
+ case 'c':
+ return "*"
+ case 'u':
+ return info.Username
+ case 'i':
+ return "255.255.255.255"
+ case 'h':
+ return info.Hostname
+ case 's':
+ return info.Server
+ case 'n':
+ return info.Nickname
+ case 'f':
+ return info.Flags
+ case 'd':
+ return "0"
+ case 'l': // idle time
+ return "0"
+ case 'a':
+ account := "0" // WHOX uses "0" to mean "no account"
+ if info.Account != "" && info.Account != "*" {
+ account = info.Account
+ }
+ return account
+ case 'o':
+ return "0"
+ case 'r':
+ return info.Realname
+ }
+ return ""
+}
+
+func generateWHOXReply(prefix *irc.Prefix, nick, fields string, info *whoxInfo) *irc.Message {
+ if fields == "" {
+ return &irc.Message{
+ Prefix: prefix,
+ Command: irc.RPL_WHOREPLY,
+ Params: []string{nick, "*", info.Username, info.Hostname, info.Server, info.Nickname, info.Flags, "0 " + info.Realname},
+ }
+ }
+
+ fieldSet := make(map[byte]bool)
+ for i := 0; i < len(fields); i++ {
+ fieldSet[fields[i]] = true
+ }
+
+ var values []string
+ for _, field := range whoxFields {
+ if !fieldSet[field] {
+ continue
+ }
+ values = append(values, info.get(field))
+ }
+
+ return &irc.Message{
+ Prefix: prefix,
+ Command: rpl_whospcrpl,
+ Params: append([]string{nick}, values...),
+ }
+}
+
+var isupportEncoder = strings.NewReplacer(" ", "\\x20", "\\", "\\x5C")
+
+func encodeISUPPORT(s string) string {
+ return isupportEncoder.Replace(s)
+}
--- /dev/null
+package suika
+
+import (
+ "testing"
+)
+
+func TestIsHighlight(t *testing.T) {
+ nick := "SojuUser"
+ testCases := []struct {
+ name string
+ text string
+ hl bool
+ }{
+ {"noContains", "hi there Soju User!", false},
+ {"middle", "hi there SojuUser!", true},
+ {"start", "SojuUser: how are you doing?", true},
+ {"end", "maybe ask SojuUser", true},
+ {"inWord", "but OtherSojuUserSan is a different nick", false},
+ {"startWord", "and OtherSojuUser is another different nick", false},
+ {"endWord", "and SojuUserSan is yet a different nick", false},
+ {"underscore", "and SojuUser_san has nothing to do with me", false},
+ {"zeroWidthSpace", "writing S\u200BojuUser shouldn't trigger a highlight", false},
+ }
+
+ for _, tc := range testCases {
+ tc := tc // capture range variable
+ t.Run(tc.name, func(t *testing.T) {
+ hl := isHighlight(tc.text, nick)
+ if hl != tc.hl {
+ t.Errorf("isHighlight(%q, %q) = %v, but want %v", tc.text, nick, hl, tc.hl)
+ }
+ })
+ }
+}
--- /dev/null
+package suika
+
+import (
+ "bytes"
+ "context"
+ "encoding/base64"
+ "fmt"
+ "time"
+
+ "git.sr.ht/~sircmpwn/go-bare"
+ "gopkg.in/irc.v3"
+)
+
+// messageStore is a per-user store for IRC messages.
+type messageStore interface {
+ Close() error
+ // LastMsgID queries the last message ID for the given network, entity and
+ // date. The message ID returned may not refer to a valid message, but can be
+ // used in history queries.
+ LastMsgID(network *Network, entity string, t time.Time) (string, error)
+ // LoadLatestID queries the latest non-event messages for the given network,
+ // entity and date, up to a count of limit messages, sorted from oldest to newest.
+ LoadLatestID(ctx context.Context, network *Network, entity, id string, limit int) ([]*irc.Message, error)
+ Append(network *Network, entity string, msg *irc.Message) (id string, err error)
+}
+
+type chatHistoryTarget struct {
+ Name string
+ LatestMessage time.Time
+}
+
+// chatHistoryMessageStore is a message store that supports chat history
+// operations.
+type chatHistoryMessageStore interface {
+ messageStore
+
+ // ListTargets lists channels and nicknames by time of the latest message.
+ // It returns up to limit targets, starting from start and ending on end,
+ // both excluded. end may be before or after start.
+ // If events is false, only PRIVMSG/NOTICE messages are considered.
+ ListTargets(ctx context.Context, network *Network, start, end time.Time, limit int, events bool) ([]chatHistoryTarget, error)
+ // LoadBeforeTime loads up to limit messages before start down to end. The
+ // returned messages must be between and excluding the provided bounds.
+ // end is before start.
+ // If events is false, only PRIVMSG/NOTICE messages are considered.
+ LoadBeforeTime(ctx context.Context, network *Network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error)
+ // LoadBeforeTime loads up to limit messages after start up to end. The
+ // returned messages must be between and excluding the provided bounds.
+ // end is after start.
+ // If events is false, only PRIVMSG/NOTICE messages are considered.
+ LoadAfterTime(ctx context.Context, network *Network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error)
+}
+
+type msgIDType uint
+
+const (
+ msgIDNone msgIDType = iota
+ msgIDMemory
+ msgIDFS
+)
+
+const msgIDVersion uint = 0
+
+type msgIDHeader struct {
+ Version uint
+ Network bare.Int
+ Target string
+ Type msgIDType
+}
+
+type msgIDBody interface {
+ msgIDType() msgIDType
+}
+
+func formatMsgID(netID int64, target string, body msgIDBody) string {
+ var buf bytes.Buffer
+ w := bare.NewWriter(&buf)
+
+ header := msgIDHeader{
+ Version: msgIDVersion,
+ Network: bare.Int(netID),
+ Target: target,
+ Type: body.msgIDType(),
+ }
+ if err := bare.MarshalWriter(w, &header); err != nil {
+ panic(err)
+ }
+ if err := bare.MarshalWriter(w, body); err != nil {
+ panic(err)
+ }
+ return base64.RawURLEncoding.EncodeToString(buf.Bytes())
+}
+
+func parseMsgID(s string, body msgIDBody) (netID int64, target string, err error) {
+ b, err := base64.RawURLEncoding.DecodeString(s)
+ if err != nil {
+ return 0, "", fmt.Errorf("invalid internal message ID: %v", err)
+ }
+
+ r := bare.NewReader(bytes.NewReader(b))
+
+ var header msgIDHeader
+ if err := bare.UnmarshalBareReader(r, &header); err != nil {
+ return 0, "", fmt.Errorf("invalid internal message ID: %v", err)
+ }
+
+ if header.Version != msgIDVersion {
+ return 0, "", fmt.Errorf("invalid internal message ID: got version %v, want %v", header.Version, msgIDVersion)
+ }
+
+ if body != nil {
+ typ := body.msgIDType()
+ if header.Type != typ {
+ return 0, "", fmt.Errorf("invalid internal message ID: got type %v, want %v", header.Type, typ)
+ }
+
+ if err := bare.UnmarshalBareReader(r, body); err != nil {
+ return 0, "", fmt.Errorf("invalid internal message ID: %v", err)
+ }
+ }
+
+ return int64(header.Network), header.Target, nil
+}
--- /dev/null
+package suika
+
+import (
+ "bufio"
+ "context"
+ "fmt"
+ "io"
+ "os"
+ "path/filepath"
+ "sort"
+ "strings"
+ "time"
+
+ "git.sr.ht/~sircmpwn/go-bare"
+ "gopkg.in/irc.v3"
+)
+
+const (
+ fsMessageStoreMaxFiles = 20
+ fsMessageStoreMaxTries = 100
+)
+
+func escapeFilename(unsafe string) (safe string) {
+ if unsafe == "." {
+ return "-"
+ } else if unsafe == ".." {
+ return "--"
+ } else {
+ return strings.NewReplacer("/", "-", "\\", "-").Replace(unsafe)
+ }
+}
+
+type date struct {
+ Year, Month, Day int
+}
+
+func newDate(t time.Time) date {
+ year, month, day := t.Date()
+ return date{year, int(month), day}
+}
+
+func (d date) Time() time.Time {
+ return time.Date(d.Year, time.Month(d.Month), d.Day, 0, 0, 0, 0, time.Local)
+}
+
+type fsMsgID struct {
+ Date date
+ Offset bare.Int
+}
+
+func (fsMsgID) msgIDType() msgIDType {
+ return msgIDFS
+}
+
+func parseFSMsgID(s string) (netID int64, entity string, t time.Time, offset int64, err error) {
+ var id fsMsgID
+ netID, entity, err = parseMsgID(s, &id)
+ if err != nil {
+ return 0, "", time.Time{}, 0, err
+ }
+ return netID, entity, id.Date.Time(), int64(id.Offset), nil
+}
+
+func formatFSMsgID(netID int64, entity string, t time.Time, offset int64) string {
+ id := fsMsgID{
+ Date: newDate(t),
+ Offset: bare.Int(offset),
+ }
+ return formatMsgID(netID, entity, &id)
+}
+
+type fsMessageStoreFile struct {
+ *os.File
+ lastUse time.Time
+}
+
+// fsMessageStore is a per-user on-disk store for IRC messages.
+//
+// It mimicks the ZNC log layout and format. See the ZNC source:
+// https://github.com/znc/znc/blob/master/modules/log.cpp
+type fsMessageStore struct {
+ root string
+ user *User
+
+ // Write-only files used by Append
+ files map[string]*fsMessageStoreFile // indexed by entity
+}
+
+var _ messageStore = (*fsMessageStore)(nil)
+var _ chatHistoryMessageStore = (*fsMessageStore)(nil)
+
+func newFSMessageStore(root string, user *User) *fsMessageStore {
+ return &fsMessageStore{
+ root: filepath.Join(root, escapeFilename(user.Username)),
+ user: user,
+ files: make(map[string]*fsMessageStoreFile),
+ }
+}
+
+func (ms *fsMessageStore) logPath(network *Network, entity string, t time.Time) string {
+ year, month, day := t.Date()
+ filename := fmt.Sprintf("%04d-%02d-%02d.log", year, month, day)
+ return filepath.Join(ms.root, escapeFilename(network.GetName()), escapeFilename(entity), filename)
+}
+
+// nextMsgID queries the message ID for the next message to be written to f.
+func nextFSMsgID(network *Network, entity string, t time.Time, f *os.File) (string, error) {
+ offset, err := f.Seek(0, io.SeekEnd)
+ if err != nil {
+ return "", fmt.Errorf("failed to query next FS message ID: %v", err)
+ }
+ return formatFSMsgID(network.ID, entity, t, offset), nil
+}
+
+func (ms *fsMessageStore) LastMsgID(network *Network, entity string, t time.Time) (string, error) {
+ p := ms.logPath(network, entity, t)
+ fi, err := os.Stat(p)
+ if os.IsNotExist(err) {
+ return formatFSMsgID(network.ID, entity, t, -1), nil
+ } else if err != nil {
+ return "", fmt.Errorf("failed to query last FS message ID: %v", err)
+ }
+ return formatFSMsgID(network.ID, entity, t, fi.Size()-1), nil
+}
+
+func (ms *fsMessageStore) Append(network *Network, entity string, msg *irc.Message) (string, error) {
+ s := formatMessage(msg)
+ if s == "" {
+ return "", nil
+ }
+
+ var t time.Time
+ if tag, ok := msg.Tags["time"]; ok {
+ var err error
+ t, err = time.Parse(serverTimeLayout, string(tag))
+ if err != nil {
+ return "", fmt.Errorf("failed to parse message time tag: %v", err)
+ }
+ t = t.In(time.Local)
+ } else {
+ t = time.Now()
+ }
+
+ f := ms.files[entity]
+
+ // TODO: handle non-monotonic clock behaviour
+ path := ms.logPath(network, entity, t)
+ if f == nil || f.Name() != path {
+ dir := filepath.Dir(path)
+ if err := os.MkdirAll(dir, 0750); err != nil {
+ return "", fmt.Errorf("failed to create message logs directory %q: %v", dir, err)
+ }
+
+ ff, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0640)
+ if err != nil {
+ return "", fmt.Errorf("failed to open message log file %q: %v", path, err)
+ }
+
+ if f != nil {
+ f.Close()
+ }
+ f = &fsMessageStoreFile{File: ff}
+ ms.files[entity] = f
+ }
+
+ f.lastUse = time.Now()
+
+ if len(ms.files) > fsMessageStoreMaxFiles {
+ entities := make([]string, 0, len(ms.files))
+ for name := range ms.files {
+ entities = append(entities, name)
+ }
+ sort.Slice(entities, func(i, j int) bool {
+ a, b := entities[i], entities[j]
+ return ms.files[a].lastUse.Before(ms.files[b].lastUse)
+ })
+ entities = entities[0 : len(entities)-fsMessageStoreMaxFiles]
+ for _, name := range entities {
+ ms.files[name].Close()
+ delete(ms.files, name)
+ }
+ }
+
+ msgID, err := nextFSMsgID(network, entity, t, f.File)
+ if err != nil {
+ return "", fmt.Errorf("failed to generate message ID: %v", err)
+ }
+
+ _, err = fmt.Fprintf(f, "[%02d:%02d:%02d] %s\n", t.Hour(), t.Minute(), t.Second(), s)
+ if err != nil {
+ return "", fmt.Errorf("failed to log message to %q: %v", f.Name(), err)
+ }
+
+ return msgID, nil
+}
+
+func (ms *fsMessageStore) Close() error {
+ var closeErr error
+ for _, f := range ms.files {
+ if err := f.Close(); err != nil {
+ closeErr = fmt.Errorf("failed to close message store: %v", err)
+ }
+ }
+ return closeErr
+}
+
+// formatMessage formats a message log line. It assumes a well-formed IRC
+// message.
+func formatMessage(msg *irc.Message) string {
+ switch strings.ToUpper(msg.Command) {
+ case "NICK":
+ return fmt.Sprintf("*** %s is now known as %s", msg.Prefix.Name, msg.Params[0])
+ case "JOIN":
+ return fmt.Sprintf("*** Joins: %s (%s@%s)", msg.Prefix.Name, msg.Prefix.User, msg.Prefix.Host)
+ case "PART":
+ var reason string
+ if len(msg.Params) > 1 {
+ reason = msg.Params[1]
+ }
+ return fmt.Sprintf("*** Parts: %s (%s@%s) (%s)", msg.Prefix.Name, msg.Prefix.User, msg.Prefix.Host, reason)
+ case "KICK":
+ nick := msg.Params[1]
+ var reason string
+ if len(msg.Params) > 2 {
+ reason = msg.Params[2]
+ }
+ return fmt.Sprintf("*** %s was kicked by %s (%s)", nick, msg.Prefix.Name, reason)
+ case "QUIT":
+ var reason string
+ if len(msg.Params) > 0 {
+ reason = msg.Params[0]
+ }
+ return fmt.Sprintf("*** Quits: %s (%s@%s) (%s)", msg.Prefix.Name, msg.Prefix.User, msg.Prefix.Host, reason)
+ case "TOPIC":
+ var topic string
+ if len(msg.Params) > 1 {
+ topic = msg.Params[1]
+ }
+ return fmt.Sprintf("*** %s changes topic to '%s'", msg.Prefix.Name, topic)
+ case "MODE":
+ return fmt.Sprintf("*** %s sets mode: %s", msg.Prefix.Name, strings.Join(msg.Params[1:], " "))
+ case "NOTICE":
+ return fmt.Sprintf("-%s- %s", msg.Prefix.Name, msg.Params[1])
+ case "PRIVMSG":
+ if cmd, params, ok := parseCTCPMessage(msg); ok && cmd == "ACTION" {
+ return fmt.Sprintf("* %s %s", msg.Prefix.Name, params)
+ } else {
+ return fmt.Sprintf("<%s> %s", msg.Prefix.Name, msg.Params[1])
+ }
+ default:
+ return ""
+ }
+}
+
+func (ms *fsMessageStore) parseMessage(line string, network *Network, entity string, ref time.Time, events bool) (*irc.Message, time.Time, error) {
+ var hour, minute, second int
+ _, err := fmt.Sscanf(line, "[%02d:%02d:%02d] ", &hour, &minute, &second)
+ if err != nil {
+ return nil, time.Time{}, fmt.Errorf("malformed timestamp prefix: %v", err)
+ }
+ line = line[11:]
+
+ var cmd string
+ var prefix *irc.Prefix
+ var params []string
+ if events && strings.HasPrefix(line, "*** ") {
+ parts := strings.SplitN(line[4:], " ", 2)
+ if len(parts) != 2 {
+ return nil, time.Time{}, nil
+ }
+ switch parts[0] {
+ case "Joins:", "Parts:", "Quits:":
+ args := strings.SplitN(parts[1], " ", 3)
+ if len(args) < 2 {
+ return nil, time.Time{}, nil
+ }
+ nick := args[0]
+ mask := strings.TrimSuffix(strings.TrimPrefix(args[1], "("), ")")
+ maskParts := strings.SplitN(mask, "@", 2)
+ if len(maskParts) != 2 {
+ return nil, time.Time{}, nil
+ }
+ prefix = &irc.Prefix{
+ Name: nick,
+ User: maskParts[0],
+ Host: maskParts[1],
+ }
+ var reason string
+ if len(args) > 2 {
+ reason = strings.TrimSuffix(strings.TrimPrefix(args[2], "("), ")")
+ }
+ switch parts[0] {
+ case "Joins:":
+ cmd = "JOIN"
+ params = []string{entity}
+ case "Parts:":
+ cmd = "PART"
+ if reason != "" {
+ params = []string{entity, reason}
+ } else {
+ params = []string{entity}
+ }
+ case "Quits:":
+ cmd = "QUIT"
+ if reason != "" {
+ params = []string{reason}
+ }
+ }
+ default:
+ nick := parts[0]
+ rem := parts[1]
+ if r := strings.TrimPrefix(rem, "is now known as "); r != rem {
+ cmd = "NICK"
+ prefix = &irc.Prefix{
+ Name: nick,
+ }
+ params = []string{r}
+ } else if r := strings.TrimPrefix(rem, "was kicked by "); r != rem {
+ args := strings.SplitN(r, " ", 2)
+ if len(args) != 2 {
+ return nil, time.Time{}, nil
+ }
+ cmd = "KICK"
+ prefix = &irc.Prefix{
+ Name: args[0],
+ }
+ reason := strings.TrimSuffix(strings.TrimPrefix(args[1], "("), ")")
+ params = []string{entity, nick}
+ if reason != "" {
+ params = append(params, reason)
+ }
+ } else if r := strings.TrimPrefix(rem, "changes topic to "); r != rem {
+ cmd = "TOPIC"
+ prefix = &irc.Prefix{
+ Name: nick,
+ }
+ topic := strings.TrimSuffix(strings.TrimPrefix(r, "'"), "'")
+ params = []string{entity, topic}
+ } else if r := strings.TrimPrefix(rem, "sets mode: "); r != rem {
+ cmd = "MODE"
+ prefix = &irc.Prefix{
+ Name: nick,
+ }
+ params = append([]string{entity}, strings.Split(r, " ")...)
+ } else {
+ return nil, time.Time{}, nil
+ }
+ }
+ } else {
+ var sender, text string
+ if strings.HasPrefix(line, "<") {
+ cmd = "PRIVMSG"
+ parts := strings.SplitN(line[1:], "> ", 2)
+ if len(parts) != 2 {
+ return nil, time.Time{}, nil
+ }
+ sender, text = parts[0], parts[1]
+ } else if strings.HasPrefix(line, "-") {
+ cmd = "NOTICE"
+ parts := strings.SplitN(line[1:], "- ", 2)
+ if len(parts) != 2 {
+ return nil, time.Time{}, nil
+ }
+ sender, text = parts[0], parts[1]
+ } else if strings.HasPrefix(line, "* ") {
+ cmd = "PRIVMSG"
+ parts := strings.SplitN(line[2:], " ", 2)
+ if len(parts) != 2 {
+ return nil, time.Time{}, nil
+ }
+ sender, text = parts[0], "\x01ACTION "+parts[1]+"\x01"
+ } else {
+ return nil, time.Time{}, nil
+ }
+
+ prefix = &irc.Prefix{Name: sender}
+ if entity == sender {
+ // This is a direct message from a user to us. We don't store own
+ // our nickname in the logs, so grab it from the network settings.
+ // Not very accurate since this may not match our nick at the time
+ // the message was received, but we can't do a lot better.
+ entity = GetNick(ms.user, network)
+ }
+ params = []string{entity, text}
+ }
+
+ year, month, day := ref.Date()
+ t := time.Date(year, month, day, hour, minute, second, 0, time.Local)
+
+ msg := &irc.Message{
+ Tags: map[string]irc.TagValue{
+ "time": irc.TagValue(formatServerTime(t)),
+ },
+ Prefix: prefix,
+ Command: cmd,
+ Params: params,
+ }
+ return msg, t, nil
+}
+
+func (ms *fsMessageStore) parseMessagesBefore(network *Network, entity string, ref time.Time, end time.Time, events bool, limit int, afterOffset int64) ([]*irc.Message, error) {
+ path := ms.logPath(network, entity, ref)
+ f, err := os.Open(path)
+ if err != nil {
+ if os.IsNotExist(err) {
+ return nil, nil
+ }
+ return nil, fmt.Errorf("failed to parse messages before ref: %v", err)
+ }
+ defer f.Close()
+
+ historyRing := make([]*irc.Message, limit)
+ cur := 0
+
+ sc := bufio.NewScanner(f)
+
+ if afterOffset >= 0 {
+ if _, err := f.Seek(afterOffset, io.SeekStart); err != nil {
+ return nil, nil
+ }
+ sc.Scan() // skip till next newline
+ }
+
+ for sc.Scan() {
+ msg, t, err := ms.parseMessage(sc.Text(), network, entity, ref, events)
+ if err != nil {
+ return nil, err
+ } else if msg == nil || !t.After(end) {
+ continue
+ } else if !t.Before(ref) {
+ break
+ }
+
+ historyRing[cur%limit] = msg
+ cur++
+ }
+ if sc.Err() != nil {
+ return nil, fmt.Errorf("failed to parse messages before ref: scanner error: %v", sc.Err())
+ }
+
+ n := limit
+ if cur < limit {
+ n = cur
+ }
+ start := (cur - n + limit) % limit
+
+ if start+n <= limit { // ring doesnt wrap
+ return historyRing[start : start+n], nil
+ } else { // ring wraps
+ history := make([]*irc.Message, n)
+ r := copy(history, historyRing[start:])
+ copy(history[r:], historyRing[:n-r])
+ return history, nil
+ }
+}
+
+func (ms *fsMessageStore) parseMessagesAfter(network *Network, entity string, ref time.Time, end time.Time, events bool, limit int) ([]*irc.Message, error) {
+ path := ms.logPath(network, entity, ref)
+ f, err := os.Open(path)
+ if err != nil {
+ if os.IsNotExist(err) {
+ return nil, nil
+ }
+ return nil, fmt.Errorf("failed to parse messages after ref: %v", err)
+ }
+ defer f.Close()
+
+ var history []*irc.Message
+ sc := bufio.NewScanner(f)
+ for sc.Scan() && len(history) < limit {
+ msg, t, err := ms.parseMessage(sc.Text(), network, entity, ref, events)
+ if err != nil {
+ return nil, err
+ } else if msg == nil || !t.After(ref) {
+ continue
+ } else if !t.Before(end) {
+ break
+ }
+
+ history = append(history, msg)
+ }
+ if sc.Err() != nil {
+ return nil, fmt.Errorf("failed to parse messages after ref: scanner error: %v", sc.Err())
+ }
+
+ return history, nil
+}
+
+func (ms *fsMessageStore) LoadBeforeTime(ctx context.Context, network *Network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) {
+ start = start.In(time.Local)
+ end = end.In(time.Local)
+ history := make([]*irc.Message, limit)
+ remaining := limit
+ tries := 0
+ for remaining > 0 && tries < fsMessageStoreMaxTries && end.Before(start) {
+ buf, err := ms.parseMessagesBefore(network, entity, start, end, events, remaining, -1)
+ if err != nil {
+ return nil, err
+ }
+ if len(buf) == 0 {
+ tries++
+ } else {
+ tries = 0
+ }
+ copy(history[remaining-len(buf):], buf)
+ remaining -= len(buf)
+ year, month, day := start.Date()
+ start = time.Date(year, month, day, 0, 0, 0, 0, start.Location()).Add(-1)
+
+ if err := ctx.Err(); err != nil {
+ return nil, err
+ }
+ }
+
+ return history[remaining:], nil
+}
+
+func (ms *fsMessageStore) LoadAfterTime(ctx context.Context, network *Network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) {
+ start = start.In(time.Local)
+ end = end.In(time.Local)
+ var history []*irc.Message
+ remaining := limit
+ tries := 0
+ for remaining > 0 && tries < fsMessageStoreMaxTries && start.Before(end) {
+ buf, err := ms.parseMessagesAfter(network, entity, start, end, events, remaining)
+ if err != nil {
+ return nil, err
+ }
+ if len(buf) == 0 {
+ tries++
+ } else {
+ tries = 0
+ }
+ history = append(history, buf...)
+ remaining -= len(buf)
+ year, month, day := start.Date()
+ start = time.Date(year, month, day+1, 0, 0, 0, 0, start.Location())
+
+ if err := ctx.Err(); err != nil {
+ return nil, err
+ }
+ }
+ return history, nil
+}
+
+func (ms *fsMessageStore) LoadLatestID(ctx context.Context, network *Network, entity, id string, limit int) ([]*irc.Message, error) {
+ var afterTime time.Time
+ var afterOffset int64
+ if id != "" {
+ var idNet int64
+ var idEntity string
+ var err error
+ idNet, idEntity, afterTime, afterOffset, err = parseFSMsgID(id)
+ if err != nil {
+ return nil, err
+ }
+ if idNet != network.ID || idEntity != entity {
+ return nil, fmt.Errorf("cannot find message ID: message ID doesn't match network/entity")
+ }
+ }
+
+ history := make([]*irc.Message, limit)
+ t := time.Now()
+ remaining := limit
+ tries := 0
+ for remaining > 0 && tries < fsMessageStoreMaxTries && !truncateDay(t).Before(afterTime) {
+ var offset int64 = -1
+ if afterOffset >= 0 && truncateDay(t).Equal(afterTime) {
+ offset = afterOffset
+ }
+
+ buf, err := ms.parseMessagesBefore(network, entity, t, time.Time{}, false, remaining, offset)
+ if err != nil {
+ return nil, err
+ }
+ if len(buf) == 0 {
+ tries++
+ } else {
+ tries = 0
+ }
+ copy(history[remaining-len(buf):], buf)
+ remaining -= len(buf)
+ year, month, day := t.Date()
+ t = time.Date(year, month, day, 0, 0, 0, 0, t.Location()).Add(-1)
+
+ if err := ctx.Err(); err != nil {
+ return nil, err
+ }
+ }
+
+ return history[remaining:], nil
+}
+
+func (ms *fsMessageStore) ListTargets(ctx context.Context, network *Network, start, end time.Time, limit int, events bool) ([]chatHistoryTarget, error) {
+ start = start.In(time.Local)
+ end = end.In(time.Local)
+ rootPath := filepath.Join(ms.root, escapeFilename(network.GetName()))
+ root, err := os.Open(rootPath)
+ if os.IsNotExist(err) {
+ return nil, nil
+ } else if err != nil {
+ return nil, err
+ }
+
+ // The returned targets are escaped, and there is no way to un-escape
+ // TODO: switch to ReadDir (Go 1.16+)
+ targetNames, err := root.Readdirnames(0)
+ root.Close()
+ if err != nil {
+ return nil, err
+ }
+
+ var targets []chatHistoryTarget
+ for _, target := range targetNames {
+ // target is already escaped here
+ targetPath := filepath.Join(rootPath, target)
+ targetDir, err := os.Open(targetPath)
+ if err != nil {
+ return nil, err
+ }
+
+ entries, err := targetDir.Readdir(0)
+ targetDir.Close()
+ if err != nil {
+ return nil, err
+ }
+
+ // We use mtime here, which may give imprecise or incorrect results
+ var t time.Time
+ for _, entry := range entries {
+ if entry.ModTime().After(t) {
+ t = entry.ModTime()
+ }
+ }
+
+ // The timestamps we get from logs have second granularity
+ t = truncateSecond(t)
+
+ // Filter out targets that don't fullfil the time bounds
+ if !isTimeBetween(t, start, end) {
+ continue
+ }
+
+ targets = append(targets, chatHistoryTarget{
+ Name: target,
+ LatestMessage: t,
+ })
+
+ if err := ctx.Err(); err != nil {
+ return nil, err
+ }
+ }
+
+ // Sort targets by latest message time, backwards or forwards depending on
+ // the order of the time bounds
+ sort.Slice(targets, func(i, j int) bool {
+ t1, t2 := targets[i].LatestMessage, targets[j].LatestMessage
+ if start.Before(end) {
+ return t1.Before(t2)
+ } else {
+ return !t1.Before(t2)
+ }
+ })
+
+ // Truncate the result if necessary
+ if len(targets) > limit {
+ targets = targets[:limit]
+ }
+
+ return targets, nil
+}
+
+func (ms *fsMessageStore) RenameNetwork(oldNet, newNet *Network) error {
+ oldDir := filepath.Join(ms.root, escapeFilename(oldNet.GetName()))
+ newDir := filepath.Join(ms.root, escapeFilename(newNet.GetName()))
+ // Avoid loosing data by overwriting an existing directory
+ if _, err := os.Stat(newDir); err == nil {
+ return fmt.Errorf("destination %q already exists", newDir)
+ }
+ return os.Rename(oldDir, newDir)
+}
+
+func truncateDay(t time.Time) time.Time {
+ year, month, day := t.Date()
+ return time.Date(year, month, day, 0, 0, 0, 0, t.Location())
+}
+
+func truncateSecond(t time.Time) time.Time {
+ year, month, day := t.Date()
+ return time.Date(year, month, day, t.Hour(), t.Minute(), t.Second(), 0, t.Location())
+}
+
+func isTimeBetween(t, start, end time.Time) bool {
+ if end.Before(start) {
+ end, start = start, end
+ }
+ return start.Before(t) && t.Before(end)
+}
--- /dev/null
+package suika
+
+import (
+ "context"
+ "fmt"
+ "time"
+
+ "git.sr.ht/~sircmpwn/go-bare"
+ "gopkg.in/irc.v3"
+)
+
+const messageRingBufferCap = 4096
+
+type memoryMsgID struct {
+ Seq bare.Uint
+}
+
+func (memoryMsgID) msgIDType() msgIDType {
+ return msgIDMemory
+}
+
+func parseMemoryMsgID(s string) (netID int64, entity string, seq uint64, err error) {
+ var id memoryMsgID
+ netID, entity, err = parseMsgID(s, &id)
+ if err != nil {
+ return 0, "", 0, err
+ }
+ return netID, entity, uint64(id.Seq), nil
+}
+
+func formatMemoryMsgID(netID int64, entity string, seq uint64) string {
+ id := memoryMsgID{bare.Uint(seq)}
+ return formatMsgID(netID, entity, &id)
+}
+
+type ringBufferKey struct {
+ networkID int64
+ entity string
+}
+
+type memoryMessageStore struct {
+ buffers map[ringBufferKey]*messageRingBuffer
+}
+
+var _ messageStore = (*memoryMessageStore)(nil)
+
+func newMemoryMessageStore() *memoryMessageStore {
+ return &memoryMessageStore{
+ buffers: make(map[ringBufferKey]*messageRingBuffer),
+ }
+}
+
+func (ms *memoryMessageStore) Close() error {
+ ms.buffers = nil
+ return nil
+}
+
+func (ms *memoryMessageStore) get(network *Network, entity string) *messageRingBuffer {
+ k := ringBufferKey{networkID: network.ID, entity: entity}
+ if rb, ok := ms.buffers[k]; ok {
+ return rb
+ }
+ rb := newMessageRingBuffer(messageRingBufferCap)
+ ms.buffers[k] = rb
+ return rb
+}
+
+func (ms *memoryMessageStore) LastMsgID(network *Network, entity string, t time.Time) (string, error) {
+ var seq uint64
+ k := ringBufferKey{networkID: network.ID, entity: entity}
+ if rb, ok := ms.buffers[k]; ok {
+ seq = rb.cur
+ }
+ return formatMemoryMsgID(network.ID, entity, seq), nil
+}
+
+func (ms *memoryMessageStore) Append(network *Network, entity string, msg *irc.Message) (string, error) {
+ switch msg.Command {
+ case "PRIVMSG", "NOTICE":
+ // Only append these messages, because LoadLatestID shouldn't return
+ // other kinds of message.
+ default:
+ return "", nil
+ }
+
+ k := ringBufferKey{networkID: network.ID, entity: entity}
+ rb, ok := ms.buffers[k]
+ if !ok {
+ rb = newMessageRingBuffer(messageRingBufferCap)
+ ms.buffers[k] = rb
+ }
+
+ seq := rb.Append(msg)
+ return formatMemoryMsgID(network.ID, entity, seq), nil
+}
+
+func (ms *memoryMessageStore) LoadLatestID(ctx context.Context, network *Network, entity, id string, limit int) ([]*irc.Message, error) {
+ _, _, seq, err := parseMemoryMsgID(id)
+ if err != nil {
+ return nil, err
+ }
+
+ k := ringBufferKey{networkID: network.ID, entity: entity}
+ rb, ok := ms.buffers[k]
+ if !ok {
+ return nil, nil
+ }
+
+ return rb.LoadLatestSeq(seq, limit)
+}
+
+type messageRingBuffer struct {
+ buf []*irc.Message
+ cur uint64
+}
+
+func newMessageRingBuffer(capacity int) *messageRingBuffer {
+ return &messageRingBuffer{
+ buf: make([]*irc.Message, capacity),
+ cur: 1,
+ }
+}
+
+func (rb *messageRingBuffer) cap() uint64 {
+ return uint64(len(rb.buf))
+}
+
+func (rb *messageRingBuffer) Append(msg *irc.Message) uint64 {
+ seq := rb.cur
+ i := int(seq % rb.cap())
+ rb.buf[i] = msg
+ rb.cur++
+ return seq
+}
+
+func (rb *messageRingBuffer) LoadLatestSeq(seq uint64, limit int) ([]*irc.Message, error) {
+ if seq > rb.cur {
+ return nil, fmt.Errorf("loading messages from sequence number (%v) greater than current (%v)", seq, rb.cur)
+ } else if seq == rb.cur {
+ return nil, nil
+ }
+
+ // The query excludes the message with the sequence number seq
+ diff := rb.cur - seq - 1
+ if diff > rb.cap() {
+ // We dropped diff - cap entries
+ diff = rb.cap()
+ }
+ if int(diff) > limit {
+ diff = uint64(limit)
+ }
+
+ l := make([]*irc.Message, int(diff))
+ for i := 0; i < int(diff); i++ {
+ j := int((rb.cur - diff + uint64(i)) % rb.cap())
+ l[i] = rb.buf[j]
+ }
+
+ return l, nil
+}
--- /dev/null
+//go:build !go1.16
+// +build !go1.16
+
+package suika
+
+import (
+ "strings"
+)
+
+func isErrClosed(err error) bool {
+ return err != nil && strings.Contains(err.Error(), "use of closed network connection")
+}
--- /dev/null
+//go:build go1.16
+// +build go1.16
+
+package suika
+
+import (
+ "errors"
+ "net"
+)
+
+func isErrClosed(err error) bool {
+ return errors.Is(err, net.ErrClosed)
+}
--- /dev/null
+package suika
+
+import (
+ "math/rand"
+ "time"
+)
+
+// backoffer implements a simple exponential backoff.
+type backoffer struct {
+ min, max, jitter time.Duration
+ n int64
+}
+
+func newBackoffer(min, max, jitter time.Duration) *backoffer {
+ return &backoffer{min: min, max: max, jitter: jitter}
+}
+
+func (b *backoffer) Reset() {
+ b.n = 0
+}
+
+func (b *backoffer) Next() time.Duration {
+ if b.n == 0 {
+ b.n = 1
+ return 0
+ }
+
+ d := time.Duration(b.n) * b.min
+ if d > b.max {
+ d = b.max
+ } else {
+ b.n *= 2
+ }
+
+ if b.jitter != 0 {
+ d += time.Duration(rand.Int63n(int64(b.jitter)))
+ }
+
+ return d
+}
--- /dev/null
+#!/bin/sh
+# $TheSupernovaDuo$
+# vim: ft=sh
+
+# PROVIDE: suika
+# REQUIRE: DAEMON
+# BEFORE: LOGIN
+# KEYWORD: shutdown
+
+. /etc/rc.subr
+
+name="suika"
+desc="A drunk IRC bouncer"
+rcvar="suika_enable"
+
+: ${suika_user="ircd"}
+
+command="%%PREFIX%%/bin/suika"
+pidfile="/var/run/suika.pid"
+required_files="%%PREFIX%%/etc/suika/config"
+
+start_cmd="suika_start"
+
+suika_start() {
+ /usr/sbin/daemon -f -p ${pidfile} -u ${suika_user} -l daemon ${command} --config ${required_files}
+}
+
+load_rc_config "$name"
+run_rc_command "$1"
--- /dev/null
+# $TheSupernovaDuo$
+cmd: %%PREFIX%%/bin/suika --config %%PREFIX%%/etc/suika/config
+user: ircd
--- /dev/null
+#!/bin/sh
+# $TheSupernovaDuo$
+# vim: ft=sh
+
+# PROVIDE: suika
+# REQUIRE: DAEMON
+# BEFORE: LOGIN
+# KEYWORD: shutdown
+
+. /etc/rc.subr
+
+name="suika"
+rcvar="${name}"
+command="%%PREFIX/bin/${name}"
+command_args="--config %%PREFIX%%/etc/suika/config"
+pidfile="/var/run/${name}.pid"
+start_cmd="${name}_start"
+
+suika_start() {
+ printf "Starting %s..." "${name}"
+ ${command} ${command_args}
+ pgrep -n ${name} > ${pidfile}
+}
+
+load_rc_config ${name}
+run_rc_command "$1"
+
+
--- /dev/null
+#!/bin/ksh
+# $TheSupernovaDuo$
+# vim: ft=sh
+
+daemon="%%PREFIX%%/bin/suika"
+daemon_args="--config %%PREFIX%%/etc/suika/config"
+
+. /etc/rc.d/rc.subr
+
+rc_bg=YES
+
+rc_cmd "$1"
--- /dev/null
+# $TheSupernovaDuo$
+# vim: ft=confini
+[Unit]
+Description=A drunk IRC bouncer
+After=network.target
+Wants=network.target
+StartLimitBurst=5
+StartLimitIntervalSec=1
+[Service]
+Type=simple
+Restart=on-abnormal
+RestartSec=1
+User=suika
+ExecStart=%%PREFIX%%/bin/suika --config %%PREFIX%%/etc/suika/config
+[Install]
+WantedBy=multi-user.target
--- /dev/null
+package suika
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "log"
+ "net"
+ "runtime/debug"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "gopkg.in/irc.v3"
+)
+
+// TODO: make configurable
+var (
+ retryConnectMinDelay = time.Minute
+ retryConnectMaxDelay = 10 * time.Minute
+ retryConnectJitter = time.Minute
+ connectTimeout = 15 * time.Second
+ writeTimeout = 10 * time.Second
+ upstreamMessageDelay = 2 * time.Second
+ upstreamMessageBurst = 10
+ backlogTimeout = 10 * time.Second
+ handleDownstreamMessageTimeout = 10 * time.Second
+ downstreamRegisterTimeout = 30 * time.Second
+ chatHistoryLimit = 1000
+ backlogLimit = 4000
+)
+
+type Logger interface {
+ Printf(format string, v ...interface{})
+ Debugf(format string, v ...interface{})
+}
+
+type logger struct {
+ *log.Logger
+ debug bool
+}
+
+func (l logger) Debugf(format string, v ...interface{}) {
+ if !l.debug {
+ return
+ }
+ l.Logger.Printf(format, v...)
+}
+
+func NewLogger(out io.Writer, debug bool) Logger {
+ return logger{
+ Logger: log.New(log.Writer(), "", log.LstdFlags),
+ debug: debug,
+ }
+}
+
+type prefixLogger struct {
+ logger Logger
+ prefix string
+}
+
+var _ Logger = (*prefixLogger)(nil)
+
+func (l *prefixLogger) Printf(format string, v ...interface{}) {
+ v = append([]interface{}{l.prefix}, v...)
+ l.logger.Printf("%v"+format, v...)
+}
+
+func (l *prefixLogger) Debugf(format string, v ...interface{}) {
+ v = append([]interface{}{l.prefix}, v...)
+ l.logger.Debugf("%v"+format, v...)
+}
+
+type int64Gauge struct {
+ v int64 // atomic
+}
+
+func (g *int64Gauge) Add(delta int64) {
+ atomic.AddInt64(&g.v, delta)
+}
+
+func (g *int64Gauge) Value() int64 {
+ return atomic.LoadInt64(&g.v)
+}
+
+func (g *int64Gauge) Float64() float64 {
+ return float64(g.Value())
+}
+
+type retryListener struct {
+ net.Listener
+ Logger Logger
+
+ delay time.Duration
+}
+
+func (ln *retryListener) Accept() (net.Conn, error) {
+ for {
+ conn, err := ln.Listener.Accept()
+ if ne, ok := err.(net.Error); ok && ne.Temporary() {
+ if ln.delay == 0 {
+ ln.delay = 5 * time.Millisecond
+ } else {
+ ln.delay *= 2
+ }
+ if max := 1 * time.Second; ln.delay > max {
+ ln.delay = max
+ }
+ if ln.Logger != nil {
+ ln.Logger.Printf("accept error (retrying in %v): %v", ln.delay, err)
+ }
+ time.Sleep(ln.delay)
+ } else {
+ ln.delay = 0
+ return conn, err
+ }
+ }
+}
+
+type Config struct {
+ Hostname string
+ Title string
+ LogPath string
+ MaxUserNetworks int
+ MultiUpstream bool
+ MOTD string
+ UpstreamUserIPs []*net.IPNet
+}
+
+type Server struct {
+ Logger Logger
+
+ config atomic.Value // *Config
+ db Database
+ stopWG sync.WaitGroup
+
+ lock sync.Mutex
+ listeners map[net.Listener]struct{}
+ users map[string]*user
+}
+
+func NewServer(db Database) *Server {
+ srv := &Server{
+ Logger: NewLogger(log.Writer(), true),
+ db: db,
+ listeners: make(map[net.Listener]struct{}),
+ users: make(map[string]*user),
+ }
+ srv.config.Store(&Config{
+ Hostname: "localhost",
+ MaxUserNetworks: -1,
+ MultiUpstream: true,
+ })
+ return srv
+}
+
+func (s *Server) prefix() *irc.Prefix {
+ return &irc.Prefix{Name: s.Config().Hostname}
+}
+
+func (s *Server) Config() *Config {
+ return s.config.Load().(*Config)
+}
+
+func (s *Server) SetConfig(cfg *Config) {
+ s.config.Store(cfg)
+}
+
+func (s *Server) Start() error {
+ users, err := s.db.ListUsers(context.TODO())
+ if err != nil {
+ return err
+ }
+
+ s.lock.Lock()
+ for i := range users {
+ s.addUserLocked(&users[i])
+ }
+ s.lock.Unlock()
+
+ return nil
+}
+
+func (s *Server) Shutdown() {
+ s.lock.Lock()
+ for ln := range s.listeners {
+ if err := ln.Close(); err != nil {
+ s.Logger.Printf("failed to stop listener: %v", err)
+ }
+ }
+ for _, u := range s.users {
+ u.events <- eventStop{}
+ }
+ s.lock.Unlock()
+
+ s.stopWG.Wait()
+
+ if err := s.db.Close(); err != nil {
+ s.Logger.Printf("failed to close DB: %v", err)
+ }
+}
+
+func (s *Server) createUser(ctx context.Context, user *User) (*user, error) {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+
+ if _, ok := s.users[user.Username]; ok {
+ return nil, fmt.Errorf("user %q already exists", user.Username)
+ }
+
+ err := s.db.StoreUser(ctx, user)
+ if err != nil {
+ return nil, fmt.Errorf("could not create user in db: %v", err)
+ }
+
+ return s.addUserLocked(user), nil
+}
+
+func (s *Server) forEachUser(f func(*user)) {
+ s.lock.Lock()
+ for _, u := range s.users {
+ f(u)
+ }
+ s.lock.Unlock()
+}
+
+func (s *Server) getUser(name string) *user {
+ s.lock.Lock()
+ u := s.users[name]
+ s.lock.Unlock()
+ return u
+}
+
+func (s *Server) addUserLocked(user *User) *user {
+ s.Logger.Printf("starting bouncer for user %q", user.Username)
+ u := newUser(s, user)
+ s.users[u.Username] = u
+
+ s.stopWG.Add(1)
+
+ go func() {
+ defer func() {
+ if err := recover(); err != nil {
+ s.Logger.Printf("panic serving user %q: %v\n%v", user.Username, err, debug.Stack())
+ }
+
+ s.lock.Lock()
+ delete(s.users, u.Username)
+ s.lock.Unlock()
+
+ s.stopWG.Done()
+ }()
+
+ u.run()
+ }()
+
+ return u
+}
+
+var lastDownstreamID uint64 = 0
+
+func (s *Server) handle(ic ircConn) {
+ defer func() {
+ if err := recover(); err != nil {
+ s.Logger.Printf("panic serving downstream %q: %v\n%v", ic.RemoteAddr(), err, debug.Stack())
+ }
+ }()
+
+ id := atomic.AddUint64(&lastDownstreamID, 1)
+ dc := newDownstreamConn(s, ic, id)
+ if err := dc.runUntilRegistered(); err != nil {
+ if !errors.Is(err, io.EOF) {
+ dc.logger.Printf("%v", err)
+ }
+ } else {
+ dc.user.events <- eventDownstreamConnected{dc}
+ if err := dc.readMessages(dc.user.events); err != nil {
+ dc.logger.Printf("%v", err)
+ }
+ dc.user.events <- eventDownstreamDisconnected{dc}
+ }
+ dc.Close()
+}
+
+func (s *Server) Serve(ln net.Listener) error {
+ ln = &retryListener{
+ Listener: ln,
+ Logger: &prefixLogger{logger: s.Logger, prefix: fmt.Sprintf("listener %v: ", ln.Addr())},
+ }
+
+ s.lock.Lock()
+ s.listeners[ln] = struct{}{}
+ s.lock.Unlock()
+
+ s.stopWG.Add(1)
+
+ defer func() {
+ s.lock.Lock()
+ delete(s.listeners, ln)
+ s.lock.Unlock()
+
+ s.stopWG.Done()
+ }()
+
+ for {
+ conn, err := ln.Accept()
+ if isErrClosed(err) {
+ return nil
+ } else if err != nil {
+ return fmt.Errorf("failed to accept connection: %v", err)
+ }
+
+ go s.handle(newNetIRCConn(conn))
+ }
+}
+
+type ServerStats struct {
+ Users int
+ Downstreams int64
+ Upstreams int64
+}
+
+func (s *Server) Stats() *ServerStats {
+ var stats ServerStats
+ s.lock.Lock()
+ stats.Users = len(s.users)
+ s.lock.Unlock()
+ return &stats
+}
--- /dev/null
+package suika
+
+import (
+ "context"
+ "net"
+ "testing"
+
+ "golang.org/x/crypto/bcrypt"
+ "gopkg.in/irc.v3"
+)
+
+var testServerPrefix = &irc.Prefix{Name: "suika-test-server"}
+
+const (
+ testUsername = "suika-test-user"
+ testPassword = testUsername
+)
+
+func createTempSqliteDB(t *testing.T) Database {
+ db, err := OpenDB("sqlite3", ":memory:")
+ if err != nil {
+ t.Fatalf("failed to create temporary SQLite database: %v", err)
+ }
+ // :memory: will open a separate database for each new connection. Make
+ // sure the sql package only uses a single connection. An alternative
+ // solution is to use "file::memory:?cache=shared".
+ db.(*SqliteDB).db.SetMaxOpenConns(1)
+ return db
+}
+
+func createTempPostgresDB(t *testing.T) Database {
+ db := &PostgresDB{db: openTempPostgresDB(t)}
+ if err := db.upgrade(); err != nil {
+ t.Fatalf("failed to upgrade PostgreSQL database: %v", err)
+ }
+
+ return db
+}
+
+func createTestUser(t *testing.T, db Database) *User {
+ hashed, err := bcrypt.GenerateFromPassword([]byte(testPassword), bcrypt.DefaultCost)
+ if err != nil {
+ t.Fatalf("failed to generate bcrypt hash: %v", err)
+ }
+
+ record := &User{Username: testUsername, Password: string(hashed)}
+ if err := db.StoreUser(context.Background(), record); err != nil {
+ t.Fatalf("failed to store test user: %v", err)
+ }
+
+ return record
+}
+
+func createTestDownstream(t *testing.T, srv *Server) ircConn {
+ c1, c2 := net.Pipe()
+ go srv.handle(newNetIRCConn(c1))
+ return newNetIRCConn(c2)
+}
+
+func createTestUpstream(t *testing.T, db Database, user *User) (*Network, net.Listener) {
+ ln, err := net.Listen("tcp", "localhost:0")
+ if err != nil {
+ t.Fatalf("failed to create TCP listener: %v", err)
+ }
+
+ network := &Network{
+ Name: "testnet",
+ Addr: "irc://" + ln.Addr().String(),
+ Nick: user.Username,
+ Enabled: true,
+ }
+ if err := db.StoreNetwork(context.Background(), user.ID, network); err != nil {
+ t.Fatalf("failed to store test network: %v", err)
+ }
+
+ return network, ln
+}
+
+func mustAccept(t *testing.T, ln net.Listener) ircConn {
+ c, err := ln.Accept()
+ if err != nil {
+ t.Fatalf("failed accepting connection: %v", err)
+ }
+ return newNetIRCConn(c)
+}
+
+func expectMessage(t *testing.T, c ircConn, cmd string) *irc.Message {
+ msg, err := c.ReadMessage()
+ if err != nil {
+ t.Fatalf("failed to read IRC message (want %q): %v", cmd, err)
+ }
+ if msg.Command != cmd {
+ t.Fatalf("invalid message received: want %q, got: %v", cmd, msg)
+ }
+ return msg
+}
+
+func registerDownstreamConn(t *testing.T, c ircConn, network *Network) {
+ c.WriteMessage(&irc.Message{
+ Command: "PASS",
+ Params: []string{testPassword},
+ })
+ c.WriteMessage(&irc.Message{
+ Command: "NICK",
+ Params: []string{testUsername},
+ })
+ c.WriteMessage(&irc.Message{
+ Command: "USER",
+ Params: []string{testUsername + "/" + network.Name, "0", "*", testUsername},
+ })
+
+ expectMessage(t, c, irc.RPL_WELCOME)
+}
+
+func registerUpstreamConn(t *testing.T, c ircConn) {
+ msg := expectMessage(t, c, "CAP")
+ if msg.Params[0] != "LS" {
+ t.Fatalf("invalid CAP LS: got: %v", msg)
+ }
+ msg = expectMessage(t, c, "NICK")
+ nick := msg.Params[0]
+ if nick != testUsername {
+ t.Fatalf("invalid NICK: want %q, got: %v", testUsername, msg)
+ }
+ expectMessage(t, c, "USER")
+
+ c.WriteMessage(&irc.Message{
+ Prefix: testServerPrefix,
+ Command: irc.RPL_WELCOME,
+ Params: []string{nick, "Welcome!"},
+ })
+ c.WriteMessage(&irc.Message{
+ Prefix: testServerPrefix,
+ Command: irc.RPL_YOURHOST,
+ Params: []string{nick, "Your host is suika-test-server"},
+ })
+ c.WriteMessage(&irc.Message{
+ Prefix: testServerPrefix,
+ Command: irc.RPL_CREATED,
+ Params: []string{nick, "Who cares when the server was created?"},
+ })
+ c.WriteMessage(&irc.Message{
+ Prefix: testServerPrefix,
+ Command: irc.RPL_MYINFO,
+ Params: []string{nick, testServerPrefix.Name, "suika", "aiwroO", "OovaimnqpsrtklbeI"},
+ })
+ c.WriteMessage(&irc.Message{
+ Prefix: testServerPrefix,
+ Command: irc.ERR_NOMOTD,
+ Params: []string{nick, "No MOTD"},
+ })
+}
+
+func testServer(t *testing.T, db Database) {
+ user := createTestUser(t, db)
+ network, upstream := createTestUpstream(t, db, user)
+ defer upstream.Close()
+
+ srv := NewServer(db)
+ if err := srv.Start(); err != nil {
+ t.Fatalf("failed to start server: %v", err)
+ }
+ defer srv.Shutdown()
+
+ uc := mustAccept(t, upstream)
+ defer uc.Close()
+ registerUpstreamConn(t, uc)
+
+ dc := createTestDownstream(t, srv)
+ defer dc.Close()
+ registerDownstreamConn(t, dc, network)
+
+ noticeText := "This is a very important server notice."
+ uc.WriteMessage(&irc.Message{
+ Prefix: testServerPrefix,
+ Command: "NOTICE",
+ Params: []string{testUsername, noticeText},
+ })
+
+ var msg *irc.Message
+ for {
+ var err error
+ msg, err = dc.ReadMessage()
+ if err != nil {
+ t.Fatalf("failed to read IRC message: %v", err)
+ }
+ if msg.Command == "NOTICE" {
+ break
+ }
+ }
+
+ if msg.Params[1] != noticeText {
+ t.Fatalf("invalid NOTICE text: want %q, got: %v", noticeText, msg)
+ }
+}
+
+func TestServer(t *testing.T) {
+ t.Run("sqlite", func(t *testing.T) {
+ db := createTempSqliteDB(t)
+ testServer(t, db)
+ })
+
+ t.Run("postgres", func(t *testing.T) {
+ db := createTempPostgresDB(t)
+ testServer(t, db)
+ })
+}
--- /dev/null
+package suika
+
+import (
+ "context"
+ "crypto/sha1"
+ "crypto/sha256"
+ "crypto/sha512"
+ "encoding/hex"
+ "flag"
+ "fmt"
+ "io/ioutil"
+ "sort"
+ "strconv"
+ "strings"
+ "time"
+ "unicode"
+
+ "golang.org/x/crypto/bcrypt"
+ "gopkg.in/irc.v3"
+)
+
+const (
+ serviceNick = "BouncerServ"
+ serviceNickCM = "bouncerserv"
+ serviceRealname = "suika bouncer service"
+)
+
+// maxRSABits is the maximum number of RSA key bits used when generating a new
+// private key.
+const maxRSABits = 8192
+
+var servicePrefix = &irc.Prefix{
+ Name: serviceNick,
+ User: serviceNick,
+ Host: serviceNick,
+}
+
+type serviceCommandSet map[string]*serviceCommand
+
+type serviceCommand struct {
+ usage string
+ desc string
+ handle func(ctx context.Context, dc *downstreamConn, params []string) error
+ children serviceCommandSet
+ admin bool
+}
+
+func sendServiceNOTICE(dc *downstreamConn, text string) {
+ dc.SendMessage(&irc.Message{
+ Prefix: servicePrefix,
+ Command: "NOTICE",
+ Params: []string{dc.nick, text},
+ })
+}
+
+func sendServicePRIVMSG(dc *downstreamConn, text string) {
+ dc.SendMessage(&irc.Message{
+ Prefix: servicePrefix,
+ Command: "PRIVMSG",
+ Params: []string{dc.nick, text},
+ })
+}
+
+func splitWords(s string) ([]string, error) {
+ var words []string
+ var lastWord strings.Builder
+ escape := false
+ prev := ' '
+ wordDelim := ' '
+
+ for _, r := range s {
+ if escape {
+ // last char was a backslash, write the byte as-is.
+ lastWord.WriteRune(r)
+ escape = false
+ } else if r == '\\' {
+ escape = true
+ } else if wordDelim == ' ' && unicode.IsSpace(r) {
+ // end of last word
+ if !unicode.IsSpace(prev) {
+ words = append(words, lastWord.String())
+ lastWord.Reset()
+ }
+ } else if r == wordDelim {
+ // wordDelim is either " or ', switch back to
+ // space-delimited words.
+ wordDelim = ' '
+ } else if r == '"' || r == '\'' {
+ if wordDelim == ' ' {
+ // start of (double-)quoted word
+ wordDelim = r
+ } else {
+ // either wordDelim is " and r is ' or vice-versa
+ lastWord.WriteRune(r)
+ }
+ } else {
+ lastWord.WriteRune(r)
+ }
+
+ prev = r
+ }
+
+ if !unicode.IsSpace(prev) {
+ words = append(words, lastWord.String())
+ }
+
+ if wordDelim != ' ' {
+ return nil, fmt.Errorf("unterminated quoted string")
+ }
+ if escape {
+ return nil, fmt.Errorf("unterminated backslash sequence")
+ }
+
+ return words, nil
+}
+
+func handleServicePRIVMSG(ctx context.Context, dc *downstreamConn, text string) {
+ words, err := splitWords(text)
+ if err != nil {
+ sendServicePRIVMSG(dc, fmt.Sprintf(`error: failed to parse command: %v`, err))
+ return
+ }
+
+ cmd, params, err := serviceCommands.Get(words)
+ if err != nil {
+ sendServicePRIVMSG(dc, fmt.Sprintf(`error: %v (type "help" for a list of commands)`, err))
+ return
+ }
+ if cmd.admin && !dc.user.Admin {
+ sendServicePRIVMSG(dc, "error: you must be an admin to use this command")
+ return
+ }
+
+ if cmd.handle == nil {
+ if len(cmd.children) > 0 {
+ var l []string
+ appendServiceCommandSetHelp(cmd.children, words, dc.user.Admin, &l)
+ sendServicePRIVMSG(dc, "available commands: "+strings.Join(l, ", "))
+ } else {
+ // Pretend the command does not exist if it has neither children nor handler.
+ // This is obviously a bug but it is better to not die anyway.
+ dc.logger.Printf("command without handler and subcommands invoked:", words[0])
+ sendServicePRIVMSG(dc, fmt.Sprintf("command %q not found", words[0]))
+ }
+ return
+ }
+
+ if err := cmd.handle(ctx, dc, params); err != nil {
+ sendServicePRIVMSG(dc, fmt.Sprintf("error: %v", err))
+ }
+}
+
+func (cmds serviceCommandSet) Get(params []string) (*serviceCommand, []string, error) {
+ if len(params) == 0 {
+ return nil, nil, fmt.Errorf("no command specified")
+ }
+
+ name := params[0]
+ params = params[1:]
+
+ cmd, ok := cmds[name]
+ if !ok {
+ for k := range cmds {
+ if !strings.HasPrefix(k, name) {
+ continue
+ }
+ if cmd != nil {
+ return nil, params, fmt.Errorf("command %q is ambiguous", name)
+ }
+ cmd = cmds[k]
+ }
+ }
+ if cmd == nil {
+ return nil, params, fmt.Errorf("command %q not found", name)
+ }
+
+ if len(params) == 0 || len(cmd.children) == 0 {
+ return cmd, params, nil
+ }
+ return cmd.children.Get(params)
+}
+
+func (cmds serviceCommandSet) Names() []string {
+ l := make([]string, 0, len(cmds))
+ for name := range cmds {
+ l = append(l, name)
+ }
+ sort.Strings(l)
+ return l
+}
+
+var serviceCommands serviceCommandSet
+
+func init() {
+ serviceCommands = serviceCommandSet{
+ "help": {
+ usage: "[command]",
+ desc: "print help message",
+ handle: handleServiceHelp,
+ },
+ "network": {
+ children: serviceCommandSet{
+ "create": {
+ usage: "-addr <addr> [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick] [-enabled enabled] [-connect-command command]...",
+ desc: "add a new network",
+ handle: handleServiceNetworkCreate,
+ },
+ "status": {
+ desc: "show a list of saved networks and their current status",
+ handle: handleServiceNetworkStatus,
+ },
+ "update": {
+ usage: "[name] [-addr addr] [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick] [-enabled enabled] [-connect-command command]...",
+ desc: "update a network",
+ handle: handleServiceNetworkUpdate,
+ },
+ "delete": {
+ usage: "[name]",
+ desc: "delete a network",
+ handle: handleServiceNetworkDelete,
+ },
+ "quote": {
+ usage: "[name] <command>",
+ desc: "send a raw line to a network",
+ handle: handleServiceNetworkQuote,
+ },
+ },
+ },
+ "certfp": {
+ children: serviceCommandSet{
+ "generate": {
+ usage: "[-key-type rsa|ecdsa|ed25519] [-bits N] [-network name]",
+ desc: "generate a new self-signed certificate, defaults to using RSA-3072 key",
+ handle: handleServiceCertFPGenerate,
+ },
+ "fingerprint": {
+ usage: "[-network name]",
+ desc: "show fingerprints of certificate",
+ handle: handleServiceCertFPFingerprints,
+ },
+ },
+ },
+ "sasl": {
+ children: serviceCommandSet{
+ "status": {
+ usage: "[-network name]",
+ desc: "show SASL status",
+ handle: handleServiceSASLStatus,
+ },
+ "set-plain": {
+ usage: "[-network name] <username> <password>",
+ desc: "set SASL PLAIN credentials",
+ handle: handleServiceSASLSetPlain,
+ },
+ "reset": {
+ usage: "[-network name]",
+ desc: "disable SASL authentication and remove stored credentials",
+ handle: handleServiceSASLReset,
+ },
+ },
+ },
+ "user": {
+ children: serviceCommandSet{
+ "create": {
+ usage: "-username <username> -password <password> [-realname <realname>] [-admin]",
+ desc: "create a new suika user",
+ handle: handleUserCreate,
+ admin: true,
+ },
+ "update": {
+ usage: "[-password <password>] [-realname <realname>]",
+ desc: "update the current user",
+ handle: handleUserUpdate,
+ },
+ "delete": {
+ usage: "<username>",
+ desc: "delete a user",
+ handle: handleUserDelete,
+ admin: true,
+ },
+ },
+ },
+ "channel": {
+ children: serviceCommandSet{
+ "status": {
+ usage: "[-network name]",
+ desc: "show a list of saved channels and their current status",
+ handle: handleServiceChannelStatus,
+ },
+ "update": {
+ usage: "<name> [-relay-detached <default|none|highlight|message>] [-reattach-on <default|none|highlight|message>] [-detach-after <duration>] [-detach-on <default|none|highlight|message>]",
+ desc: "update a channel",
+ handle: handleServiceChannelUpdate,
+ },
+ },
+ },
+ "server": {
+ children: serviceCommandSet{
+ "status": {
+ desc: "show server statistics",
+ handle: handleServiceServerStatus,
+ admin: true,
+ },
+ "notice": {
+ desc: "broadcast a notice to all connected bouncer users",
+ handle: handleServiceServerNotice,
+ admin: true,
+ },
+ },
+ admin: true,
+ },
+ }
+}
+
+func appendServiceCommandSetHelp(cmds serviceCommandSet, prefix []string, admin bool, l *[]string) {
+ for _, name := range cmds.Names() {
+ cmd := cmds[name]
+ if cmd.admin && !admin {
+ continue
+ }
+ words := append(prefix, name)
+ if len(cmd.children) == 0 {
+ s := strings.Join(words, " ")
+ *l = append(*l, s)
+ } else {
+ appendServiceCommandSetHelp(cmd.children, words, admin, l)
+ }
+ }
+}
+
+func handleServiceHelp(ctx context.Context, dc *downstreamConn, params []string) error {
+ if len(params) > 0 {
+ cmd, rest, err := serviceCommands.Get(params)
+ if err != nil {
+ return err
+ }
+ words := params[:len(params)-len(rest)]
+
+ if len(cmd.children) > 0 {
+ var l []string
+ appendServiceCommandSetHelp(cmd.children, words, dc.user.Admin, &l)
+ sendServicePRIVMSG(dc, "available commands: "+strings.Join(l, ", "))
+ } else {
+ text := strings.Join(words, " ")
+ if cmd.usage != "" {
+ text += " " + cmd.usage
+ }
+ text += ": " + cmd.desc
+
+ sendServicePRIVMSG(dc, text)
+ }
+ } else {
+ var l []string
+ appendServiceCommandSetHelp(serviceCommands, nil, dc.user.Admin, &l)
+ sendServicePRIVMSG(dc, "available commands: "+strings.Join(l, ", "))
+ }
+ return nil
+}
+
+func newFlagSet() *flag.FlagSet {
+ fs := flag.NewFlagSet("", flag.ContinueOnError)
+ fs.SetOutput(ioutil.Discard)
+ return fs
+}
+
+type stringSliceFlag []string
+
+func (v *stringSliceFlag) String() string {
+ return fmt.Sprint([]string(*v))
+}
+
+func (v *stringSliceFlag) Set(s string) error {
+ *v = append(*v, s)
+ return nil
+}
+
+// stringPtrFlag is a flag value populating a string pointer. This allows to
+// disambiguate between a flag that hasn't been set and a flag that has been
+// set to an empty string.
+type stringPtrFlag struct {
+ ptr **string
+}
+
+func (f stringPtrFlag) String() string {
+ if f.ptr == nil || *f.ptr == nil {
+ return ""
+ }
+ return **f.ptr
+}
+
+func (f stringPtrFlag) Set(s string) error {
+ *f.ptr = &s
+ return nil
+}
+
+type boolPtrFlag struct {
+ ptr **bool
+}
+
+func (f boolPtrFlag) String() string {
+ if f.ptr == nil || *f.ptr == nil {
+ return "<nil>"
+ }
+ return strconv.FormatBool(**f.ptr)
+}
+
+func (f boolPtrFlag) Set(s string) error {
+ v, err := strconv.ParseBool(s)
+ if err != nil {
+ return err
+ }
+ *f.ptr = &v
+ return nil
+}
+
+func getNetworkFromArg(dc *downstreamConn, params []string) (*network, []string, error) {
+ name, params := popArg(params)
+ if name == "" {
+ if dc.network == nil {
+ return nil, params, fmt.Errorf("no network selected, a name argument is required")
+ }
+ return dc.network, params, nil
+ } else {
+ net := dc.user.getNetwork(name)
+ if net == nil {
+ return nil, params, fmt.Errorf("unknown network %q", name)
+ }
+ return net, params, nil
+ }
+}
+
+type networkFlagSet struct {
+ *flag.FlagSet
+ Addr, Name, Nick, Username, Pass, Realname *string
+ Enabled *bool
+ ConnectCommands []string
+}
+
+func newNetworkFlagSet() *networkFlagSet {
+ fs := &networkFlagSet{FlagSet: newFlagSet()}
+ fs.Var(stringPtrFlag{&fs.Addr}, "addr", "")
+ fs.Var(stringPtrFlag{&fs.Name}, "name", "")
+ fs.Var(stringPtrFlag{&fs.Nick}, "nick", "")
+ fs.Var(stringPtrFlag{&fs.Username}, "username", "")
+ fs.Var(stringPtrFlag{&fs.Pass}, "pass", "")
+ fs.Var(stringPtrFlag{&fs.Realname}, "realname", "")
+ fs.Var(boolPtrFlag{&fs.Enabled}, "enabled", "")
+ fs.Var((*stringSliceFlag)(&fs.ConnectCommands), "connect-command", "")
+ return fs
+}
+
+func (fs *networkFlagSet) update(network *Network) error {
+ if fs.Addr != nil {
+ if addrParts := strings.SplitN(*fs.Addr, "://", 2); len(addrParts) == 2 {
+ scheme := addrParts[0]
+ switch scheme {
+ case "ircs", "irc", "unix":
+ default:
+ return fmt.Errorf("unknown scheme %q (supported schemes: ircs, irc, unix)", scheme)
+ }
+ }
+ network.Addr = *fs.Addr
+ }
+ if fs.Name != nil {
+ network.Name = *fs.Name
+ }
+ if fs.Nick != nil {
+ network.Nick = *fs.Nick
+ }
+ if fs.Username != nil {
+ network.Username = *fs.Username
+ }
+ if fs.Pass != nil {
+ network.Pass = *fs.Pass
+ }
+ if fs.Realname != nil {
+ network.Realname = *fs.Realname
+ }
+ if fs.Enabled != nil {
+ network.Enabled = *fs.Enabled
+ }
+ if fs.ConnectCommands != nil {
+ if len(fs.ConnectCommands) == 1 && fs.ConnectCommands[0] == "" {
+ network.ConnectCommands = nil
+ } else {
+ for _, command := range fs.ConnectCommands {
+ _, err := irc.ParseMessage(command)
+ if err != nil {
+ return fmt.Errorf("flag -connect-command must be a valid raw irc command string: %q: %v", command, err)
+ }
+ }
+ network.ConnectCommands = fs.ConnectCommands
+ }
+ }
+ return nil
+}
+
+func handleServiceNetworkCreate(ctx context.Context, dc *downstreamConn, params []string) error {
+ fs := newNetworkFlagSet()
+ if err := fs.Parse(params); err != nil {
+ return err
+ }
+ if fs.Addr == nil {
+ return fmt.Errorf("flag -addr is required")
+ }
+
+ record := &Network{
+ Addr: *fs.Addr,
+ Enabled: true,
+ }
+ if err := fs.update(record); err != nil {
+ return err
+ }
+
+ network, err := dc.user.createNetwork(ctx, record)
+ if err != nil {
+ return fmt.Errorf("could not create network: %v", err)
+ }
+
+ sendServicePRIVMSG(dc, fmt.Sprintf("created network %q", network.GetName()))
+ return nil
+}
+
+func handleServiceNetworkStatus(ctx context.Context, dc *downstreamConn, params []string) error {
+ n := 0
+ for _, net := range dc.user.networks {
+ var statuses []string
+ var details string
+ if uc := net.conn; uc != nil {
+ if dc.nick != uc.nick {
+ statuses = append(statuses, "connected as "+uc.nick)
+ } else {
+ statuses = append(statuses, "connected")
+ }
+ details = fmt.Sprintf("%v channels", uc.channels.Len())
+ } else if !net.Enabled {
+ statuses = append(statuses, "disabled")
+ } else {
+ statuses = append(statuses, "disconnected")
+ if net.lastError != nil {
+ details = net.lastError.Error()
+ }
+ }
+
+ if net == dc.network {
+ statuses = append(statuses, "current")
+ }
+
+ name := net.GetName()
+ if name != net.Addr {
+ name = fmt.Sprintf("%v (%v)", name, net.Addr)
+ }
+
+ s := fmt.Sprintf("%v [%v]", name, strings.Join(statuses, ", "))
+ if details != "" {
+ s += ": " + details
+ }
+ sendServicePRIVMSG(dc, s)
+
+ n++
+ }
+
+ if n == 0 {
+ sendServicePRIVMSG(dc, `No network configured, add one with "network create".`)
+ }
+
+ return nil
+}
+
+func handleServiceNetworkUpdate(ctx context.Context, dc *downstreamConn, params []string) error {
+ net, params, err := getNetworkFromArg(dc, params)
+ if err != nil {
+ return err
+ }
+
+ fs := newNetworkFlagSet()
+ if err := fs.Parse(params); err != nil {
+ return err
+ }
+
+ record := net.Network // copy network record because we'll mutate it
+ if err := fs.update(&record); err != nil {
+ return err
+ }
+
+ network, err := dc.user.updateNetwork(ctx, &record)
+ if err != nil {
+ return fmt.Errorf("could not update network: %v", err)
+ }
+
+ sendServicePRIVMSG(dc, fmt.Sprintf("updated network %q", network.GetName()))
+ return nil
+}
+
+func handleServiceNetworkDelete(ctx context.Context, dc *downstreamConn, params []string) error {
+ net, params, err := getNetworkFromArg(dc, params)
+ if err != nil {
+ return err
+ }
+
+ if err := dc.user.deleteNetwork(ctx, net.ID); err != nil {
+ return err
+ }
+
+ sendServicePRIVMSG(dc, fmt.Sprintf("deleted network %q", net.GetName()))
+ return nil
+}
+
+func handleServiceNetworkQuote(ctx context.Context, dc *downstreamConn, params []string) error {
+ if len(params) != 1 && len(params) != 2 {
+ return fmt.Errorf("expected one or two arguments")
+ }
+
+ raw := params[len(params)-1]
+ params = params[:len(params)-1]
+
+ net, params, err := getNetworkFromArg(dc, params)
+ if err != nil {
+ return err
+ }
+
+ uc := net.conn
+ if uc == nil {
+ return fmt.Errorf("network %q is not currently connected", net.GetName())
+ }
+
+ m, err := irc.ParseMessage(raw)
+ if err != nil {
+ return fmt.Errorf("failed to parse command %q: %v", raw, err)
+ }
+ uc.SendMessage(ctx, m)
+
+ sendServicePRIVMSG(dc, fmt.Sprintf("sent command to %q", net.GetName()))
+ return nil
+}
+
+func sendCertfpFingerprints(dc *downstreamConn, cert []byte) {
+ sha1Sum := sha1.Sum(cert)
+ sendServicePRIVMSG(dc, "SHA-1 fingerprint: "+hex.EncodeToString(sha1Sum[:]))
+ sha256Sum := sha256.Sum256(cert)
+ sendServicePRIVMSG(dc, "SHA-256 fingerprint: "+hex.EncodeToString(sha256Sum[:]))
+ sha512Sum := sha512.Sum512(cert)
+ sendServicePRIVMSG(dc, "SHA-512 fingerprint: "+hex.EncodeToString(sha512Sum[:]))
+}
+
+func getNetworkFromFlag(dc *downstreamConn, name string) (*network, error) {
+ if name == "" {
+ if dc.network == nil {
+ return nil, fmt.Errorf("no network selected, -network is required")
+ }
+ return dc.network, nil
+ } else {
+ net := dc.user.getNetwork(name)
+ if net == nil {
+ return nil, fmt.Errorf("unknown network %q", name)
+ }
+ return net, nil
+ }
+}
+
+func handleServiceCertFPGenerate(ctx context.Context, dc *downstreamConn, params []string) error {
+ fs := newFlagSet()
+ netName := fs.String("network", "", "select a network")
+ keyType := fs.String("key-type", "rsa", "key type to generate (rsa, ecdsa, ed25519)")
+ bits := fs.Int("bits", 3072, "size of key to generate, meaningful only for RSA")
+
+ if err := fs.Parse(params); err != nil {
+ return err
+ }
+
+ if *bits <= 0 || *bits > maxRSABits {
+ return fmt.Errorf("invalid value for -bits")
+ }
+
+ net, err := getNetworkFromFlag(dc, *netName)
+ if err != nil {
+ return err
+ }
+
+ privKey, cert, err := generateCertFP(*keyType, *bits)
+ if err != nil {
+ return err
+ }
+
+ net.SASL.External.CertBlob = cert
+ net.SASL.External.PrivKeyBlob = privKey
+ net.SASL.Mechanism = "EXTERNAL"
+
+ if err := dc.srv.db.StoreNetwork(ctx, dc.user.ID, &net.Network); err != nil {
+ return err
+ }
+
+ sendServicePRIVMSG(dc, "certificate generated")
+ sendCertfpFingerprints(dc, cert)
+ return nil
+}
+
+func handleServiceCertFPFingerprints(ctx context.Context, dc *downstreamConn, params []string) error {
+ fs := newFlagSet()
+ netName := fs.String("network", "", "select a network")
+
+ if err := fs.Parse(params); err != nil {
+ return err
+ }
+
+ net, err := getNetworkFromFlag(dc, *netName)
+ if err != nil {
+ return err
+ }
+
+ if net.SASL.Mechanism != "EXTERNAL" {
+ return fmt.Errorf("CertFP not set up")
+ }
+
+ sendCertfpFingerprints(dc, net.SASL.External.CertBlob)
+ return nil
+}
+
+func handleServiceSASLStatus(ctx context.Context, dc *downstreamConn, params []string) error {
+ fs := newFlagSet()
+ netName := fs.String("network", "", "select a network")
+
+ if err := fs.Parse(params); err != nil {
+ return err
+ }
+
+ net, err := getNetworkFromFlag(dc, *netName)
+ if err != nil {
+ return err
+ }
+
+ switch net.SASL.Mechanism {
+ case "PLAIN":
+ sendServicePRIVMSG(dc, fmt.Sprintf("SASL PLAIN enabled with username %q", net.SASL.Plain.Username))
+ case "EXTERNAL":
+ sendServicePRIVMSG(dc, "SASL EXTERNAL (CertFP) enabled")
+ case "":
+ sendServicePRIVMSG(dc, "SASL is disabled")
+ }
+
+ if uc := net.conn; uc != nil {
+ if uc.account != "" {
+ sendServicePRIVMSG(dc, fmt.Sprintf("Authenticated on upstream network with account %q", uc.account))
+ } else {
+ sendServicePRIVMSG(dc, "Unauthenticated on upstream network")
+ }
+ } else {
+ sendServicePRIVMSG(dc, "Disconnected from upstream network")
+ }
+
+ return nil
+}
+
+func handleServiceSASLSetPlain(ctx context.Context, dc *downstreamConn, params []string) error {
+ fs := newFlagSet()
+ netName := fs.String("network", "", "select a network")
+
+ if err := fs.Parse(params); err != nil {
+ return err
+ }
+
+ if len(fs.Args()) != 2 {
+ return fmt.Errorf("expected exactly 2 arguments")
+ }
+
+ net, err := getNetworkFromFlag(dc, *netName)
+ if err != nil {
+ return err
+ }
+
+ net.SASL.Plain.Username = fs.Arg(0)
+ net.SASL.Plain.Password = fs.Arg(1)
+ net.SASL.Mechanism = "PLAIN"
+
+ if err := dc.srv.db.StoreNetwork(ctx, dc.user.ID, &net.Network); err != nil {
+ return err
+ }
+
+ sendServicePRIVMSG(dc, "credentials saved")
+ return nil
+}
+
+func handleServiceSASLReset(ctx context.Context, dc *downstreamConn, params []string) error {
+ fs := newFlagSet()
+ netName := fs.String("network", "", "select a network")
+
+ if err := fs.Parse(params); err != nil {
+ return err
+ }
+
+ net, err := getNetworkFromFlag(dc, *netName)
+ if err != nil {
+ return err
+ }
+
+ net.SASL.Plain.Username = ""
+ net.SASL.Plain.Password = ""
+ net.SASL.External.CertBlob = nil
+ net.SASL.External.PrivKeyBlob = nil
+ net.SASL.Mechanism = ""
+
+ if err := dc.srv.db.StoreNetwork(ctx, dc.user.ID, &net.Network); err != nil {
+ return err
+ }
+
+ sendServicePRIVMSG(dc, "credentials reset")
+ return nil
+}
+
+func handleUserCreate(ctx context.Context, dc *downstreamConn, params []string) error {
+ fs := newFlagSet()
+ username := fs.String("username", "", "")
+ password := fs.String("password", "", "")
+ realname := fs.String("realname", "", "")
+ admin := fs.Bool("admin", false, "")
+
+ if err := fs.Parse(params); err != nil {
+ return err
+ }
+ if *username == "" {
+ return fmt.Errorf("flag -username is required")
+ }
+ if *password == "" {
+ return fmt.Errorf("flag -password is required")
+ }
+
+ hashed, err := bcrypt.GenerateFromPassword([]byte(*password), bcrypt.DefaultCost)
+ if err != nil {
+ return fmt.Errorf("failed to hash password: %v", err)
+ }
+
+ user := &User{
+ Username: *username,
+ Password: string(hashed),
+ Realname: *realname,
+ Admin: *admin,
+ }
+ if _, err := dc.srv.createUser(ctx, user); err != nil {
+ return fmt.Errorf("could not create user: %v", err)
+ }
+
+ sendServicePRIVMSG(dc, fmt.Sprintf("created user %q", *username))
+ return nil
+}
+
+func popArg(params []string) (string, []string) {
+ if len(params) > 0 && !strings.HasPrefix(params[0], "-") {
+ return params[0], params[1:]
+ }
+ return "", params
+}
+
+func handleUserUpdate(ctx context.Context, dc *downstreamConn, params []string) error {
+ var password, realname *string
+ var admin *bool
+ fs := newFlagSet()
+ fs.Var(stringPtrFlag{&password}, "password", "")
+ fs.Var(stringPtrFlag{&realname}, "realname", "")
+ fs.Var(boolPtrFlag{&admin}, "admin", "")
+
+ username, params := popArg(params)
+ if err := fs.Parse(params); err != nil {
+ return err
+ }
+ if len(fs.Args()) > 0 {
+ return fmt.Errorf("unexpected argument")
+ }
+
+ var hashed *string
+ if password != nil {
+ hashedBytes, err := bcrypt.GenerateFromPassword([]byte(*password), bcrypt.DefaultCost)
+ if err != nil {
+ return fmt.Errorf("failed to hash password: %v", err)
+ }
+ hashedStr := string(hashedBytes)
+ hashed = &hashedStr
+ }
+
+ if username != "" && username != dc.user.Username {
+ if !dc.user.Admin {
+ return fmt.Errorf("you must be an admin to update other users")
+ }
+ if realname != nil {
+ return fmt.Errorf("cannot update -realname of other user")
+ }
+
+ u := dc.srv.getUser(username)
+ if u == nil {
+ return fmt.Errorf("unknown username %q", username)
+ }
+
+ done := make(chan error, 1)
+ event := eventUserUpdate{
+ password: hashed,
+ admin: admin,
+ done: done,
+ }
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ case u.events <- event:
+ }
+ // TODO: send context to the other side
+ if err := <-done; err != nil {
+ return err
+ }
+
+ sendServicePRIVMSG(dc, fmt.Sprintf("updated user %q", username))
+ } else {
+ // copy the user record because we'll mutate it
+ record := dc.user.User
+
+ if hashed != nil {
+ record.Password = *hashed
+ }
+ if realname != nil {
+ record.Realname = *realname
+ }
+ if admin != nil {
+ return fmt.Errorf("cannot update -admin of own user")
+ }
+
+ if err := dc.user.updateUser(ctx, &record); err != nil {
+ return err
+ }
+
+ sendServicePRIVMSG(dc, fmt.Sprintf("updated user %q", dc.user.Username))
+ }
+
+ return nil
+}
+
+func handleUserDelete(ctx context.Context, dc *downstreamConn, params []string) error {
+ if len(params) != 1 {
+ return fmt.Errorf("expected exactly one argument")
+ }
+ username := params[0]
+
+ u := dc.srv.getUser(username)
+ if u == nil {
+ return fmt.Errorf("unknown username %q", username)
+ }
+
+ u.stop()
+
+ if err := dc.srv.db.DeleteUser(ctx, u.ID); err != nil {
+ return fmt.Errorf("failed to delete user: %v", err)
+ }
+
+ sendServicePRIVMSG(dc, fmt.Sprintf("deleted user %q", username))
+ return nil
+}
+
+func handleServiceChannelStatus(ctx context.Context, dc *downstreamConn, params []string) error {
+ var defaultNetworkName string
+ if dc.network != nil {
+ defaultNetworkName = dc.network.GetName()
+ }
+
+ fs := newFlagSet()
+ networkName := fs.String("network", defaultNetworkName, "")
+
+ if err := fs.Parse(params); err != nil {
+ return err
+ }
+
+ n := 0
+
+ sendNetwork := func(net *network) {
+ var channels []*Channel
+ for _, entry := range net.channels.innerMap {
+ channels = append(channels, entry.value.(*Channel))
+ }
+
+ sort.Slice(channels, func(i, j int) bool {
+ return strings.ReplaceAll(channels[i].Name, "#", "") <
+ strings.ReplaceAll(channels[j].Name, "#", "")
+ })
+
+ for _, ch := range channels {
+ var uch *upstreamChannel
+ if net.conn != nil {
+ uch = net.conn.channels.Value(ch.Name)
+ }
+
+ name := ch.Name
+ if *networkName == "" {
+ name += "/" + net.GetName()
+ }
+
+ var status string
+ if uch != nil {
+ status = "joined"
+ } else if net.conn != nil {
+ status = "parted"
+ } else {
+ status = "disconnected"
+ }
+
+ if ch.Detached {
+ status += ", detached"
+ }
+
+ s := fmt.Sprintf("%v [%v]", name, status)
+ sendServicePRIVMSG(dc, s)
+
+ n++
+ }
+ }
+
+ if *networkName == "" {
+ for _, net := range dc.user.networks {
+ sendNetwork(net)
+ }
+ } else {
+ net := dc.user.getNetwork(*networkName)
+ if net == nil {
+ return fmt.Errorf("unknown network %q", *networkName)
+ }
+ sendNetwork(net)
+ }
+
+ if n == 0 {
+ sendServicePRIVMSG(dc, "No channel configured.")
+ }
+
+ return nil
+}
+
+type channelFlagSet struct {
+ *flag.FlagSet
+ RelayDetached, ReattachOn, DetachAfter, DetachOn *string
+}
+
+func newChannelFlagSet() *channelFlagSet {
+ fs := &channelFlagSet{FlagSet: newFlagSet()}
+ fs.Var(stringPtrFlag{&fs.RelayDetached}, "relay-detached", "")
+ fs.Var(stringPtrFlag{&fs.ReattachOn}, "reattach-on", "")
+ fs.Var(stringPtrFlag{&fs.DetachAfter}, "detach-after", "")
+ fs.Var(stringPtrFlag{&fs.DetachOn}, "detach-on", "")
+ return fs
+}
+
+func (fs *channelFlagSet) update(channel *Channel) error {
+ if fs.RelayDetached != nil {
+ filter, err := parseFilter(*fs.RelayDetached)
+ if err != nil {
+ return err
+ }
+ channel.RelayDetached = filter
+ }
+ if fs.ReattachOn != nil {
+ filter, err := parseFilter(*fs.ReattachOn)
+ if err != nil {
+ return err
+ }
+ channel.ReattachOn = filter
+ }
+ if fs.DetachAfter != nil {
+ dur, err := time.ParseDuration(*fs.DetachAfter)
+ if err != nil || dur < 0 {
+ return fmt.Errorf("unknown duration for -detach-after %q (duration format: 0, 300s, 22h30m, ...)", *fs.DetachAfter)
+ }
+ channel.DetachAfter = dur
+ }
+ if fs.DetachOn != nil {
+ filter, err := parseFilter(*fs.DetachOn)
+ if err != nil {
+ return err
+ }
+ channel.DetachOn = filter
+ }
+ return nil
+}
+
+func handleServiceChannelUpdate(ctx context.Context, dc *downstreamConn, params []string) error {
+ if len(params) < 1 {
+ return fmt.Errorf("expected at least one argument")
+ }
+ name := params[0]
+
+ fs := newChannelFlagSet()
+ if err := fs.Parse(params[1:]); err != nil {
+ return err
+ }
+
+ uc, upstreamName, err := dc.unmarshalEntity(name)
+ if err != nil {
+ return fmt.Errorf("unknown channel %q", name)
+ }
+
+ ch := uc.network.channels.Value(upstreamName)
+ if ch == nil {
+ return fmt.Errorf("unknown channel %q", name)
+ }
+
+ if err := fs.update(ch); err != nil {
+ return err
+ }
+
+ uc.updateChannelAutoDetach(upstreamName)
+
+ if err := dc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil {
+ return fmt.Errorf("failed to update channel: %v", err)
+ }
+
+ sendServicePRIVMSG(dc, fmt.Sprintf("updated channel %q", name))
+ return nil
+}
+func handleServiceServerStatus(ctx context.Context, dc *downstreamConn, params []string) error {
+ dbStats, err := dc.user.srv.db.Stats(ctx)
+ if err != nil {
+ return err
+ }
+ serverStats := dc.user.srv.Stats()
+ sendServicePRIVMSG(dc, fmt.Sprintf("%v/%v users, %v downstreams, %v upstreams, %v networks, %v channels", serverStats.Users, dbStats.Users, serverStats.Downstreams, serverStats.Upstreams, dbStats.Networks, dbStats.Channels))
+ return nil
+}
+
+func handleServiceServerNotice(ctx context.Context, dc *downstreamConn, params []string) error {
+ if len(params) != 1 {
+ return fmt.Errorf("expected exactly one argument")
+ }
+ text := params[0]
+
+ dc.logger.Printf("broadcasting bouncer-wide NOTICE: %v", text)
+
+ broadcastMsg := &irc.Message{
+ Prefix: servicePrefix,
+ Command: "NOTICE",
+ Params: []string{"$" + dc.srv.Config().Hostname, text},
+ }
+ var err error
+ sent := 0
+ total := 0
+ dc.srv.forEachUser(func(u *user) {
+ total++
+ select {
+ case <-ctx.Done():
+ err = ctx.Err()
+ case u.events <- eventBroadcast{broadcastMsg}:
+ sent++
+ }
+ })
+
+ dc.logger.Printf("broadcast bouncer-wide NOTICE to %v/%v downstreams", sent, total)
+ sendServicePRIVMSG(dc, fmt.Sprintf("sent to %v/%v downstream connections", sent, total))
+
+ return err
+}
--- /dev/null
+package suika
+
+import (
+ "testing"
+)
+
+func assertSplit(t *testing.T, input string, expected []string) {
+ actual, err := splitWords(input)
+ if err != nil {
+ t.Errorf("%q: %v", input, err)
+ return
+ }
+ if len(actual) != len(expected) {
+ t.Errorf("%q: expected %d words, got %d\nexpected: %v\ngot: %v", input, len(expected), len(actual), expected, actual)
+ return
+ }
+ for i := 0; i < len(actual); i++ {
+ if actual[i] != expected[i] {
+ t.Errorf("%q: expected word #%d to be %q, got %q\nexpected: %v\ngot: %v", input, i, expected[i], actual[i], expected, actual)
+ }
+ }
+}
+
+func TestSplit(t *testing.T) {
+ assertSplit(t, " ch 'up' #suika 'relay'-det\"ache\"d message ", []string{
+ "ch",
+ "up",
+ "#suika",
+ "relay-detached",
+ "message",
+ })
+ assertSplit(t, "net update \\\"free\\\"node -pass 'political \"stance\" desu!' -realname '' -nick lee", []string{
+ "net",
+ "update",
+ "\"free\"node",
+ "-pass",
+ "political \"stance\" desu!",
+ "-realname",
+ "",
+ "-nick",
+ "lee",
+ })
+ assertSplit(t, "Omedeto,\\ Yui! ''", []string{
+ "Omedeto, Yui!",
+ "",
+ })
+
+ if _, err := splitWords("end of 'file"); err == nil {
+ t.Errorf("expected error on unterminated single quote")
+ }
+ if _, err := splitWords("end of backquote \\"); err == nil {
+ t.Errorf("expected error on unterminated backquote sequence")
+ }
+}
--- /dev/null
+CREATE TABLE IF NOT EXISTS "User" (
+ id SERIAL PRIMARY KEY,
+ username VARCHAR(255) NOT NULL UNIQUE,
+ password VARCHAR(255),
+ admin BOOLEAN NOT NULL DEFAULT FALSE,
+ realname VARCHAR(255)
+);
+
+CREATE TYPE sasl_mechanism AS ENUM ('PLAIN', 'EXTERNAL');
+
+CREATE TABLE IF NOT EXISTS "Network" (
+ id SERIAL PRIMARY KEY,
+ name VARCHAR(255),
+ "user" INTEGER NOT NULL REFERENCES "User"(id) ON DELETE CASCADE,
+ addr VARCHAR(255) NOT NULL,
+ nick VARCHAR(255),
+ username VARCHAR(255),
+ realname VARCHAR(255),
+ pass VARCHAR(255),
+ connect_commands VARCHAR(1023),
+ sasl_mechanism sasl_mechanism,
+ sasl_plain_username VARCHAR(255),
+ sasl_plain_password VARCHAR(255),
+ sasl_external_cert BYTEA,
+ sasl_external_key BYTEA,
+ enabled BOOLEAN NOT NULL DEFAULT TRUE,
+ UNIQUE("user", addr, nick),
+ UNIQUE("user", name)
+);
+CREATE TABLE IF NOT EXISTS "Channel" (
+ id SERIAL PRIMARY KEY,
+ network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
+ name VARCHAR(255) NOT NULL,
+ key VARCHAR(255),
+ detached BOOLEAN NOT NULL DEFAULT FALSE,
+ detached_internal_msgid VARCHAR(255),
+ relay_detached INTEGER NOT NULL DEFAULT 0,
+ reattach_on INTEGER NOT NULL DEFAULT 0,
+ detach_after INTEGER NOT NULL DEFAULT 0,
+ detach_on INTEGER NOT NULL DEFAULT 0,
+ UNIQUE(network, name)
+);
+CREATE TABLE IF NOT EXISTS "DeliveryReceipt" (
+ id SERIAL PRIMARY KEY,
+ network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
+ target VARCHAR(255) NOT NULL,
+ client VARCHAR(255) NOT NULL DEFAULT '',
+ internal_msgid VARCHAR(255) NOT NULL,
+ UNIQUE(network, target, client)
+);
+CREATE TABLE IF NOT EXISTS "ReadReceipt" (
+ id SERIAL PRIMARY KEY,
+ network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
+ target VARCHAR(255) NOT NULL,
+ timestamp TIMESTAMP WITH TIME ZONE NOT NULL,
+ UNIQUE(network, target)
+);
+
--- /dev/null
+CREATE TABLE IF NOT EXISTS User (
+ id INTEGER PRIMARY KEY,
+ username TEXT NOT NULL UNIQUE,
+ password TEXT,
+ admin INTEGER NOT NULL DEFAULT 0,
+ realname TEXT
+);
+CREATE TABLE IF NOT EXISTS Network (
+ id INTEGER PRIMARY KEY,
+ name TEXT,
+ user INTEGER NOT NULL,
+ addr TEXT NOT NULL,
+ nick TEXT,
+ username TEXT,
+ realname TEXT,
+ pass TEXT,
+ connect_commands TEXT,
+ sasl_mechanism TEXT,
+ sasl_plain_username TEXT,
+ sasl_plain_password TEXT,
+ sasl_external_cert BLOB,
+ sasl_external_key BLOB,
+ enabled INTEGER NOT NULL DEFAULT 1,
+ FOREIGN KEY(user) REFERENCES User(id),
+ UNIQUE(user, addr, nick),
+ UNIQUE(user, name)
+);
+CREATE TABLE IF NOT EXISTS Channel (
+ id INTEGER PRIMARY KEY,
+ network INTEGER NOT NULL,
+ name TEXT NOT NULL,
+ key TEXT,
+ detached INTEGER NOT NULL DEFAULT 0,
+ detached_internal_msgid TEXT,
+ relay_detached INTEGER NOT NULL DEFAULT 0,
+ reattach_on INTEGER NOT NULL DEFAULT 0,
+ detach_after INTEGER NOT NULL DEFAULT 0,
+ detach_on INTEGER NOT NULL DEFAULT 0,
+ FOREIGN KEY(network) REFERENCES Network(id),
+ UNIQUE(network, name)
+);
+
+CREATE TABLE IF NOT EXISTS DeliveryReceipt (
+ id INTEGER PRIMARY KEY,
+ network INTEGER NOT NULL,
+ target TEXT NOT NULL,
+ client TEXT,
+ internal_msgid TEXT NOT NULL,
+ FOREIGN KEY(network) REFERENCES Network(id),
+ UNIQUE(network, target, client)
+);
+
+CREATE TABLE IF NOT EXISTS ReadReceipt (
+ id INTEGER PRIMARY KEY,
+ network INTEGER NOT NULL,
+ target TEXT NOT NULL,
+ timestamp TEXT NOT NULL,
+ FOREIGN KEY(network) REFERENCES Network(id),
+ UNIQUE(network, target)
+);
+
--- /dev/null
+package suika
+
+import (
+ "context"
+ "crypto"
+ "crypto/sha256"
+ "crypto/tls"
+ "crypto/x509"
+ "encoding/base64"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/emersion/go-sasl"
+ "gopkg.in/irc.v3"
+)
+
+// permanentUpstreamCaps is the static list of upstream capabilities always
+// requested when supported.
+var permanentUpstreamCaps = map[string]bool{
+ "account-notify": true,
+ "account-tag": true,
+ "away-notify": true,
+ "batch": true,
+ "extended-join": true,
+ "invite-notify": true,
+ "labeled-response": true,
+ "message-tags": true,
+ "multi-prefix": true,
+ "sasl": true,
+ "server-time": true,
+ "setname": true,
+
+ "draft/account-registration": true,
+ "draft/extended-monitor": true,
+}
+
+type registrationError struct {
+ *irc.Message
+}
+
+func (err registrationError) Error() string {
+ return fmt.Sprintf("registration error (%v): %v", err.Command, err.Reason())
+}
+
+func (err registrationError) Reason() string {
+ if len(err.Params) > 0 {
+ return err.Params[len(err.Params)-1]
+ }
+ return err.Command
+}
+
+func (err registrationError) Temporary() bool {
+ // Only return false if we're 100% sure that fixing the error requires a
+ // network configuration change
+ switch err.Command {
+ case irc.ERR_PASSWDMISMATCH, irc.ERR_ERRONEUSNICKNAME:
+ return false
+ case "FAIL":
+ return err.Params[1] != "ACCOUNT_REQUIRED"
+ default:
+ return true
+ }
+}
+
+type upstreamChannel struct {
+ Name string
+ conn *upstreamConn
+ Topic string
+ TopicWho *irc.Prefix
+ TopicTime time.Time
+ Status channelStatus
+ modes channelModes
+ creationTime string
+ Members membershipsCasemapMap
+ complete bool
+ detachTimer *time.Timer
+}
+
+func (uc *upstreamChannel) updateAutoDetach(dur time.Duration) {
+ if uc.detachTimer != nil {
+ uc.detachTimer.Stop()
+ uc.detachTimer = nil
+ }
+
+ if dur == 0 {
+ return
+ }
+
+ uc.detachTimer = time.AfterFunc(dur, func() {
+ uc.conn.network.user.events <- eventChannelDetach{
+ uc: uc.conn,
+ name: uc.Name,
+ }
+ })
+}
+
+type pendingUpstreamCommand struct {
+ downstreamID uint64
+ msg *irc.Message
+}
+
+type upstreamConn struct {
+ conn
+
+ network *network
+ user *user
+
+ serverName string
+ availableUserModes string
+ availableChannelModes map[byte]channelModeType
+ availableChannelTypes string
+ availableMemberships []membership
+ isupport map[string]*string
+
+ registered bool
+ nick string
+ nickCM string
+ username string
+ realname string
+ modes userModes
+ channels upstreamChannelCasemapMap
+ supportedCaps map[string]string
+ caps map[string]bool
+ batches map[string]batch
+ away bool
+ account string
+ nextLabelID uint64
+ monitored monitorCasemapMap
+
+ saslClient sasl.Client
+ saslStarted bool
+
+ casemapIsSet bool
+
+ // Queue of commands in progress, indexed by type. The first entry has been
+ // sent to the server and is awaiting reply. The following entries have not
+ // been sent yet.
+ pendingCmds map[string][]pendingUpstreamCommand
+
+ gotMotd bool
+}
+
+func connectToUpstream(ctx context.Context, network *network) (*upstreamConn, error) {
+ logger := &prefixLogger{network.user.logger, fmt.Sprintf("upstream %q: ", network.GetName())}
+
+ dialer := net.Dialer{Timeout: connectTimeout}
+
+ u, err := network.URL()
+ if err != nil {
+ return nil, err
+ }
+
+ var netConn net.Conn
+ switch u.Scheme {
+ case "ircs":
+ addr := u.Host
+ host, _, err := net.SplitHostPort(u.Host)
+ if err != nil {
+ host = u.Host
+ addr = u.Host + ":6697"
+ }
+
+ dialer.LocalAddr, err = network.user.localTCPAddrForHost(ctx, host)
+ if err != nil {
+ return nil, fmt.Errorf("failed to pick local IP for remote host %q: %v", host, err)
+ }
+
+ logger.Printf("connecting to TLS server at address %q", addr)
+
+ tlsConfig := &tls.Config{ServerName: host, NextProtos: []string{"irc"}}
+ if network.SASL.Mechanism == "EXTERNAL" {
+ if network.SASL.External.CertBlob == nil {
+ return nil, fmt.Errorf("missing certificate for authentication")
+ }
+ if network.SASL.External.PrivKeyBlob == nil {
+ return nil, fmt.Errorf("missing private key for authentication")
+ }
+ key, err := x509.ParsePKCS8PrivateKey(network.SASL.External.PrivKeyBlob)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse private key: %v", err)
+ }
+ tlsConfig.Certificates = []tls.Certificate{
+ {
+ Certificate: [][]byte{network.SASL.External.CertBlob},
+ PrivateKey: key.(crypto.PrivateKey),
+ },
+ }
+ logger.Printf("using TLS client certificate %x", sha256.Sum256(network.SASL.External.CertBlob))
+ }
+
+ netConn, err = dialer.DialContext(ctx, "tcp", addr)
+ if err != nil {
+ return nil, fmt.Errorf("failed to dial %q: %v", addr, err)
+ }
+
+ // Don't do the TLS handshake immediately, because we need to register
+ // the new connection with identd ASAP.
+ netConn = tls.Client(netConn, tlsConfig)
+ case "irc":
+ addr := u.Host
+ host, _, err := net.SplitHostPort(addr)
+ if err != nil {
+ host = u.Host
+ addr = u.Host + ":6667"
+ }
+
+ dialer.LocalAddr, err = network.user.localTCPAddrForHost(ctx, host)
+ if err != nil {
+ return nil, fmt.Errorf("failed to pick local IP for remote host %q: %v", host, err)
+ }
+
+ logger.Printf("connecting to plain-text server at address %q", addr)
+ netConn, err = dialer.DialContext(ctx, "tcp", addr)
+ if err != nil {
+ return nil, fmt.Errorf("failed to dial %q: %v", addr, err)
+ }
+ case "irc+unix", "unix":
+ logger.Printf("connecting to Unix socket at path %q", u.Path)
+ netConn, err = dialer.DialContext(ctx, "unix", u.Path)
+ if err != nil {
+ return nil, fmt.Errorf("failed to connect to Unix socket %q: %v", u.Path, err)
+ }
+ default:
+ return nil, fmt.Errorf("failed to dial %q: unknown scheme: %v", network.Addr, u.Scheme)
+ }
+
+ options := connOptions{
+ Logger: logger,
+ RateLimitDelay: upstreamMessageDelay,
+ RateLimitBurst: upstreamMessageBurst,
+ }
+
+ uc := &upstreamConn{
+ conn: *newConn(network.user.srv, newNetIRCConn(netConn), &options),
+ network: network,
+ user: network.user,
+ channels: upstreamChannelCasemapMap{newCasemapMap(0)},
+ supportedCaps: make(map[string]string),
+ caps: make(map[string]bool),
+ batches: make(map[string]batch),
+ availableChannelTypes: stdChannelTypes,
+ availableChannelModes: stdChannelModes,
+ availableMemberships: stdMemberships,
+ isupport: make(map[string]*string),
+ pendingCmds: make(map[string][]pendingUpstreamCommand),
+ monitored: monitorCasemapMap{newCasemapMap(0)},
+ }
+ return uc, nil
+}
+
+func (uc *upstreamConn) forEachDownstream(f func(*downstreamConn)) {
+ uc.network.forEachDownstream(f)
+}
+
+func (uc *upstreamConn) forEachDownstreamByID(id uint64, f func(*downstreamConn)) {
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ if id != 0 && id != dc.id {
+ return
+ }
+ f(dc)
+ })
+}
+
+func (uc *upstreamConn) downstreamByID(id uint64) *downstreamConn {
+ for _, dc := range uc.user.downstreamConns {
+ if dc.id == id {
+ return dc
+ }
+ }
+ return nil
+}
+
+func (uc *upstreamConn) getChannel(name string) (*upstreamChannel, error) {
+ ch := uc.channels.Value(name)
+ if ch == nil {
+ return nil, fmt.Errorf("unknown channel %q", name)
+ }
+ return ch, nil
+}
+
+func (uc *upstreamConn) isChannel(entity string) bool {
+ return strings.ContainsRune(uc.availableChannelTypes, rune(entity[0]))
+}
+
+func (uc *upstreamConn) isOurNick(nick string) bool {
+ return uc.nickCM == uc.network.casemap(nick)
+}
+
+func (uc *upstreamConn) abortPendingCommands() {
+ for _, l := range uc.pendingCmds {
+ for _, pendingCmd := range l {
+ dc := uc.downstreamByID(pendingCmd.downstreamID)
+ if dc == nil {
+ continue
+ }
+
+ switch pendingCmd.msg.Command {
+ case "LIST":
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_LISTEND,
+ Params: []string{dc.nick, "Command aborted"},
+ })
+ case "WHO":
+ mask := "*"
+ if len(pendingCmd.msg.Params) > 0 {
+ mask = pendingCmd.msg.Params[0]
+ }
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_ENDOFWHO,
+ Params: []string{dc.nick, mask, "Command aborted"},
+ })
+ case "AUTHENTICATE":
+ dc.endSASL(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.ERR_SASLABORTED,
+ Params: []string{dc.nick, "SASL authentication aborted"},
+ })
+ case "REGISTER", "VERIFY":
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "FAIL",
+ Params: []string{pendingCmd.msg.Command, "TEMPORARILY_UNAVAILABLE", pendingCmd.msg.Params[0], "Command aborted"},
+ })
+ default:
+ panic(fmt.Errorf("Unsupported pending command %q", pendingCmd.msg.Command))
+ }
+ }
+ }
+
+ uc.pendingCmds = make(map[string][]pendingUpstreamCommand)
+}
+
+func (uc *upstreamConn) sendNextPendingCommand(cmd string) {
+ if len(uc.pendingCmds[cmd]) == 0 {
+ return
+ }
+ uc.SendMessage(context.TODO(), uc.pendingCmds[cmd][0].msg)
+}
+
+func (uc *upstreamConn) enqueueCommand(dc *downstreamConn, msg *irc.Message) {
+ switch msg.Command {
+ case "LIST", "WHO", "AUTHENTICATE", "REGISTER", "VERIFY":
+ // Supported
+ default:
+ panic(fmt.Errorf("Unsupported pending command %q", msg.Command))
+ }
+
+ uc.pendingCmds[msg.Command] = append(uc.pendingCmds[msg.Command], pendingUpstreamCommand{
+ downstreamID: dc.id,
+ msg: msg,
+ })
+
+ if len(uc.pendingCmds[msg.Command]) == 1 {
+ uc.sendNextPendingCommand(msg.Command)
+ }
+}
+
+func (uc *upstreamConn) currentPendingCommand(cmd string) (*downstreamConn, *irc.Message) {
+ if len(uc.pendingCmds[cmd]) == 0 {
+ return nil, nil
+ }
+
+ pendingCmd := uc.pendingCmds[cmd][0]
+ return uc.downstreamByID(pendingCmd.downstreamID), pendingCmd.msg
+}
+
+func (uc *upstreamConn) dequeueCommand(cmd string) (*downstreamConn, *irc.Message) {
+ dc, msg := uc.currentPendingCommand(cmd)
+
+ if len(uc.pendingCmds[cmd]) > 0 {
+ copy(uc.pendingCmds[cmd], uc.pendingCmds[cmd][1:])
+ uc.pendingCmds[cmd] = uc.pendingCmds[cmd][:len(uc.pendingCmds[cmd])-1]
+ }
+
+ uc.sendNextPendingCommand(cmd)
+
+ return dc, msg
+}
+
+func (uc *upstreamConn) cancelPendingCommandsByDownstreamID(downstreamID uint64) {
+ for cmd := range uc.pendingCmds {
+ // We can't cancel the currently running command stored in
+ // uc.pendingCmds[cmd][0]
+ for i := len(uc.pendingCmds[cmd]) - 1; i >= 1; i-- {
+ if uc.pendingCmds[cmd][i].downstreamID == downstreamID {
+ uc.pendingCmds[cmd] = append(uc.pendingCmds[cmd][:i], uc.pendingCmds[cmd][i+1:]...)
+ }
+ }
+ }
+}
+
+func (uc *upstreamConn) parseMembershipPrefix(s string) (ms *memberships, nick string) {
+ memberships := make(memberships, 0, 4)
+ i := 0
+ for _, m := range uc.availableMemberships {
+ if i >= len(s) {
+ break
+ }
+ if s[i] == m.Prefix {
+ memberships = append(memberships, m)
+ i++
+ }
+ }
+ return &memberships, s[i:]
+}
+
+func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) error {
+ var label string
+ if l, ok := msg.GetTag("label"); ok {
+ label = l
+ delete(msg.Tags, "label")
+ }
+
+ var msgBatch *batch
+ if batchName, ok := msg.GetTag("batch"); ok {
+ b, ok := uc.batches[batchName]
+ if !ok {
+ return fmt.Errorf("unexpected batch reference: batch was not defined: %q", batchName)
+ }
+ msgBatch = &b
+ if label == "" {
+ label = msgBatch.Label
+ }
+ delete(msg.Tags, "batch")
+ }
+
+ var downstreamID uint64 = 0
+ if label != "" {
+ var labelOffset uint64
+ n, err := fmt.Sscanf(label, "sd-%d-%d", &downstreamID, &labelOffset)
+ if err == nil && n < 2 {
+ err = errors.New("not enough arguments")
+ }
+ if err != nil {
+ return fmt.Errorf("unexpected message label: invalid downstream reference for label %q: %v", label, err)
+ }
+ }
+
+ if _, ok := msg.Tags["time"]; !ok {
+ msg.Tags["time"] = irc.TagValue(formatServerTime(time.Now()))
+ }
+
+ switch msg.Command {
+ case "PING":
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "PONG",
+ Params: msg.Params,
+ })
+ return nil
+ case "NOTICE", "PRIVMSG", "TAGMSG":
+ if msg.Prefix == nil {
+ return fmt.Errorf("expected a prefix")
+ }
+
+ var entity, text string
+ if msg.Command != "TAGMSG" {
+ if err := parseMessageParams(msg, &entity, &text); err != nil {
+ return err
+ }
+ } else {
+ if err := parseMessageParams(msg, &entity); err != nil {
+ return err
+ }
+ }
+
+ if msg.Prefix.Name == serviceNick {
+ uc.logger.Printf("skipping %v from suika's service: %v", msg.Command, msg)
+ break
+ }
+ if entity == serviceNick {
+ uc.logger.Printf("skipping %v to suika's service: %v", msg.Command, msg)
+ break
+ }
+
+ if msg.Prefix.User == "" && msg.Prefix.Host == "" { // server message
+ uc.produce("", msg, nil)
+ } else { // regular user message
+ target := entity
+ if uc.isOurNick(target) {
+ target = msg.Prefix.Name
+ }
+
+ ch := uc.network.channels.Value(target)
+ if ch != nil && msg.Command != "TAGMSG" {
+ if ch.Detached {
+ uc.handleDetachedMessage(ctx, ch, msg)
+ }
+
+ highlight := uc.network.isHighlight(msg)
+ if ch.DetachOn == FilterMessage || ch.DetachOn == FilterDefault || (ch.DetachOn == FilterHighlight && highlight) {
+ uc.updateChannelAutoDetach(target)
+ }
+ }
+
+ uc.produce(target, msg, nil)
+ }
+ case "CAP":
+ var subCmd string
+ if err := parseMessageParams(msg, nil, &subCmd); err != nil {
+ return err
+ }
+ subCmd = strings.ToUpper(subCmd)
+ subParams := msg.Params[2:]
+ switch subCmd {
+ case "LS":
+ if len(subParams) < 1 {
+ return newNeedMoreParamsError(msg.Command)
+ }
+ caps := subParams[len(subParams)-1]
+ more := len(subParams) >= 2 && msg.Params[len(subParams)-2] == "*"
+
+ uc.handleSupportedCaps(caps)
+
+ if more {
+ break // wait to receive all capabilities
+ }
+
+ uc.requestCaps()
+
+ if uc.requestSASL() {
+ break // we'll send CAP END after authentication is completed
+ }
+
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "CAP",
+ Params: []string{"END"},
+ })
+ case "ACK", "NAK":
+ if len(subParams) < 1 {
+ return newNeedMoreParamsError(msg.Command)
+ }
+ caps := strings.Fields(subParams[0])
+
+ for _, name := range caps {
+ if err := uc.handleCapAck(ctx, strings.ToLower(name), subCmd == "ACK"); err != nil {
+ return err
+ }
+ }
+
+ if uc.registered {
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.updateSupportedCaps()
+ })
+ }
+ case "NEW":
+ if len(subParams) < 1 {
+ return newNeedMoreParamsError(msg.Command)
+ }
+ uc.handleSupportedCaps(subParams[0])
+ uc.requestCaps()
+ case "DEL":
+ if len(subParams) < 1 {
+ return newNeedMoreParamsError(msg.Command)
+ }
+ caps := strings.Fields(subParams[0])
+
+ for _, c := range caps {
+ delete(uc.supportedCaps, c)
+ delete(uc.caps, c)
+ }
+
+ if uc.registered {
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.updateSupportedCaps()
+ })
+ }
+ default:
+ uc.logger.Printf("unhandled message: %v", msg)
+ }
+ case "AUTHENTICATE":
+ if uc.saslClient == nil {
+ return fmt.Errorf("received unexpected AUTHENTICATE message")
+ }
+
+ // TODO: if a challenge is 400 bytes long, buffer it
+ var challengeStr string
+ if err := parseMessageParams(msg, &challengeStr); err != nil {
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "AUTHENTICATE",
+ Params: []string{"*"},
+ })
+ return err
+ }
+
+ var challenge []byte
+ if challengeStr != "+" {
+ var err error
+ challenge, err = base64.StdEncoding.DecodeString(challengeStr)
+ if err != nil {
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "AUTHENTICATE",
+ Params: []string{"*"},
+ })
+ return err
+ }
+ }
+
+ var resp []byte
+ var err error
+ if !uc.saslStarted {
+ _, resp, err = uc.saslClient.Start()
+ uc.saslStarted = true
+ } else {
+ resp, err = uc.saslClient.Next(challenge)
+ }
+ if err != nil {
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "AUTHENTICATE",
+ Params: []string{"*"},
+ })
+ return err
+ }
+
+ // <= instead of < because we need to send a final empty response if
+ // the last chunk is exactly 400 bytes long
+ for i := 0; i <= len(resp); i += maxSASLLength {
+ j := i + maxSASLLength
+ if j > len(resp) {
+ j = len(resp)
+ }
+
+ chunk := resp[i:j]
+
+ var respStr = "+"
+ if len(chunk) != 0 {
+ respStr = base64.StdEncoding.EncodeToString(chunk)
+ }
+
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "AUTHENTICATE",
+ Params: []string{respStr},
+ })
+ }
+ case irc.RPL_LOGGEDIN:
+ if err := parseMessageParams(msg, nil, nil, &uc.account); err != nil {
+ return err
+ }
+ uc.logger.Printf("logged in with account %q", uc.account)
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.updateAccount()
+ })
+ case irc.RPL_LOGGEDOUT:
+ uc.account = ""
+ uc.logger.Printf("logged out")
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.updateAccount()
+ })
+ case irc.ERR_NICKLOCKED, irc.RPL_SASLSUCCESS, irc.ERR_SASLFAIL, irc.ERR_SASLTOOLONG, irc.ERR_SASLABORTED:
+ var info string
+ if err := parseMessageParams(msg, nil, &info); err != nil {
+ return err
+ }
+ switch msg.Command {
+ case irc.ERR_NICKLOCKED:
+ uc.logger.Printf("invalid nick used with SASL authentication: %v", info)
+ case irc.ERR_SASLFAIL:
+ uc.logger.Printf("SASL authentication failed: %v", info)
+ case irc.ERR_SASLTOOLONG:
+ uc.logger.Printf("SASL message too long: %v", info)
+ }
+
+ uc.saslClient = nil
+ uc.saslStarted = false
+
+ if dc, _ := uc.dequeueCommand("AUTHENTICATE"); dc != nil && dc.sasl != nil {
+ if msg.Command == irc.RPL_SASLSUCCESS {
+ uc.network.autoSaveSASLPlain(ctx, dc.sasl.plainUsername, dc.sasl.plainPassword)
+ }
+
+ dc.endSASL(msg)
+ }
+
+ if !uc.registered {
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "CAP",
+ Params: []string{"END"},
+ })
+ }
+ case "REGISTER", "VERIFY":
+ if dc, cmd := uc.dequeueCommand(msg.Command); dc != nil {
+ if msg.Command == "REGISTER" {
+ var account, password string
+ if err := parseMessageParams(msg, nil, &account); err != nil {
+ return err
+ }
+ if err := parseMessageParams(cmd, nil, nil, &password); err != nil {
+ return err
+ }
+ uc.network.autoSaveSASLPlain(ctx, account, password)
+ }
+
+ dc.SendMessage(msg)
+ }
+ case irc.RPL_WELCOME:
+ if err := parseMessageParams(msg, &uc.nick); err != nil {
+ return err
+ }
+
+ uc.registered = true
+ uc.nickCM = uc.network.casemap(uc.nick)
+ uc.logger.Printf("connection registered with nick %q", uc.nick)
+
+ if uc.network.channels.Len() > 0 {
+ var channels, keys []string
+ for _, entry := range uc.network.channels.innerMap {
+ ch := entry.value.(*Channel)
+ channels = append(channels, ch.Name)
+ keys = append(keys, ch.Key)
+ }
+
+ for _, msg := range join(channels, keys) {
+ uc.SendMessage(ctx, msg)
+ }
+ }
+ case irc.RPL_MYINFO:
+ if err := parseMessageParams(msg, nil, &uc.serverName, nil, &uc.availableUserModes, nil); err != nil {
+ return err
+ }
+ case irc.RPL_ISUPPORT:
+ if err := parseMessageParams(msg, nil, nil); err != nil {
+ return err
+ }
+
+ var downstreamIsupport []string
+ for _, token := range msg.Params[1 : len(msg.Params)-1] {
+ parameter := token
+ var negate, hasValue bool
+ var value string
+ if strings.HasPrefix(token, "-") {
+ negate = true
+ token = token[1:]
+ } else if i := strings.IndexByte(token, '='); i >= 0 {
+ parameter = token[:i]
+ value = token[i+1:]
+ hasValue = true
+ }
+
+ if hasValue {
+ uc.isupport[parameter] = &value
+ } else if !negate {
+ uc.isupport[parameter] = nil
+ } else {
+ delete(uc.isupport, parameter)
+ }
+
+ var err error
+ switch parameter {
+ case "CASEMAPPING":
+ casemap, ok := parseCasemappingToken(value)
+ if !ok {
+ casemap = casemapRFC1459
+ }
+ uc.network.updateCasemapping(casemap)
+ uc.nickCM = uc.network.casemap(uc.nick)
+ uc.casemapIsSet = true
+ case "CHANMODES":
+ if !negate {
+ err = uc.handleChanModes(value)
+ } else {
+ uc.availableChannelModes = stdChannelModes
+ }
+ case "CHANTYPES":
+ if !negate {
+ uc.availableChannelTypes = value
+ } else {
+ uc.availableChannelTypes = stdChannelTypes
+ }
+ case "PREFIX":
+ if !negate {
+ err = uc.handleMemberships(value)
+ } else {
+ uc.availableMemberships = stdMemberships
+ }
+ }
+ if err != nil {
+ return err
+ }
+
+ if passthroughIsupport[parameter] {
+ downstreamIsupport = append(downstreamIsupport, token)
+ }
+ }
+
+ uc.updateMonitor()
+
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ if dc.network == nil {
+ return
+ }
+ msgs := generateIsupport(dc.srv.prefix(), dc.nick, downstreamIsupport)
+ for _, msg := range msgs {
+ dc.SendMessage(msg)
+ }
+ })
+ case irc.ERR_NOMOTD, irc.RPL_ENDOFMOTD:
+ if !uc.casemapIsSet {
+ // upstream did not send any CASEMAPPING token, thus
+ // we assume it implements the old RFCs with rfc1459.
+ uc.casemapIsSet = true
+ uc.network.updateCasemapping(casemapRFC1459)
+ uc.nickCM = uc.network.casemap(uc.nick)
+ }
+
+ if !uc.gotMotd {
+ // Ignore the initial MOTD upon connection, but forward
+ // subsequent MOTD messages downstream
+ uc.gotMotd = true
+ return nil
+ }
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ dc.SendMessage(&irc.Message{
+ Prefix: uc.srv.prefix(),
+ Command: msg.Command,
+ Params: msg.Params,
+ })
+ })
+ case "BATCH":
+ var tag string
+ if err := parseMessageParams(msg, &tag); err != nil {
+ return err
+ }
+
+ if strings.HasPrefix(tag, "+") {
+ tag = tag[1:]
+ if _, ok := uc.batches[tag]; ok {
+ return fmt.Errorf("unexpected BATCH reference tag: batch was already defined: %q", tag)
+ }
+ var batchType string
+ if err := parseMessageParams(msg, nil, &batchType); err != nil {
+ return err
+ }
+ label := label
+ if label == "" && msgBatch != nil {
+ label = msgBatch.Label
+ }
+ uc.batches[tag] = batch{
+ Type: batchType,
+ Params: msg.Params[2:],
+ Outer: msgBatch,
+ Label: label,
+ }
+ } else if strings.HasPrefix(tag, "-") {
+ tag = tag[1:]
+ if _, ok := uc.batches[tag]; !ok {
+ return fmt.Errorf("unknown BATCH reference tag: %q", tag)
+ }
+ delete(uc.batches, tag)
+ } else {
+ return fmt.Errorf("unexpected BATCH reference tag: missing +/- prefix: %q", tag)
+ }
+ case "NICK":
+ if msg.Prefix == nil {
+ return fmt.Errorf("expected a prefix")
+ }
+
+ var newNick string
+ if err := parseMessageParams(msg, &newNick); err != nil {
+ return err
+ }
+
+ me := false
+ if uc.isOurNick(msg.Prefix.Name) {
+ uc.logger.Printf("changed nick from %q to %q", uc.nick, newNick)
+ me = true
+ uc.nick = newNick
+ uc.nickCM = uc.network.casemap(uc.nick)
+ }
+
+ for _, entry := range uc.channels.innerMap {
+ ch := entry.value.(*upstreamChannel)
+ memberships := ch.Members.Value(msg.Prefix.Name)
+ if memberships != nil {
+ ch.Members.Delete(msg.Prefix.Name)
+ ch.Members.SetValue(newNick, memberships)
+ uc.appendLog(ch.Name, msg)
+ }
+ }
+
+ if !me {
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.SendMessage(dc.marshalMessage(msg, uc.network))
+ })
+ } else {
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.updateNick()
+ })
+ uc.updateMonitor()
+ }
+ case "SETNAME":
+ if msg.Prefix == nil {
+ return fmt.Errorf("expected a prefix")
+ }
+
+ var newRealname string
+ if err := parseMessageParams(msg, &newRealname); err != nil {
+ return err
+ }
+
+ // TODO: consider appending this message to logs
+
+ if uc.isOurNick(msg.Prefix.Name) {
+ uc.logger.Printf("changed realname from %q to %q", uc.realname, newRealname)
+ uc.realname = newRealname
+
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.updateRealname()
+ })
+ } else {
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.SendMessage(dc.marshalMessage(msg, uc.network))
+ })
+ }
+ case "JOIN":
+ if msg.Prefix == nil {
+ return fmt.Errorf("expected a prefix")
+ }
+
+ var channels string
+ if err := parseMessageParams(msg, &channels); err != nil {
+ return err
+ }
+
+ for _, ch := range strings.Split(channels, ",") {
+ if uc.isOurNick(msg.Prefix.Name) {
+ uc.logger.Printf("joined channel %q", ch)
+ members := membershipsCasemapMap{newCasemapMap(0)}
+ members.casemap = uc.network.casemap
+ uc.channels.SetValue(ch, &upstreamChannel{
+ Name: ch,
+ conn: uc,
+ Members: members,
+ })
+ uc.updateChannelAutoDetach(ch)
+
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "MODE",
+ Params: []string{ch},
+ })
+ } else {
+ ch, err := uc.getChannel(ch)
+ if err != nil {
+ return err
+ }
+ ch.Members.SetValue(msg.Prefix.Name, &memberships{})
+ }
+
+ chMsg := msg.Copy()
+ chMsg.Params[0] = ch
+ uc.produce(ch, chMsg, nil)
+ }
+ case "PART":
+ if msg.Prefix == nil {
+ return fmt.Errorf("expected a prefix")
+ }
+
+ var channels string
+ if err := parseMessageParams(msg, &channels); err != nil {
+ return err
+ }
+
+ for _, ch := range strings.Split(channels, ",") {
+ if uc.isOurNick(msg.Prefix.Name) {
+ uc.logger.Printf("parted channel %q", ch)
+ uch := uc.channels.Value(ch)
+ if uch != nil {
+ uc.channels.Delete(ch)
+ uch.updateAutoDetach(0)
+ }
+ } else {
+ ch, err := uc.getChannel(ch)
+ if err != nil {
+ return err
+ }
+ ch.Members.Delete(msg.Prefix.Name)
+ }
+
+ chMsg := msg.Copy()
+ chMsg.Params[0] = ch
+ uc.produce(ch, chMsg, nil)
+ }
+ case "KICK":
+ if msg.Prefix == nil {
+ return fmt.Errorf("expected a prefix")
+ }
+
+ var channel, user string
+ if err := parseMessageParams(msg, &channel, &user); err != nil {
+ return err
+ }
+
+ if uc.isOurNick(user) {
+ uc.logger.Printf("kicked from channel %q by %s", channel, msg.Prefix.Name)
+ uc.channels.Delete(channel)
+ } else {
+ ch, err := uc.getChannel(channel)
+ if err != nil {
+ return err
+ }
+ ch.Members.Delete(user)
+ }
+
+ uc.produce(channel, msg, nil)
+ case "QUIT":
+ if msg.Prefix == nil {
+ return fmt.Errorf("expected a prefix")
+ }
+
+ if uc.isOurNick(msg.Prefix.Name) {
+ uc.logger.Printf("quit")
+ }
+
+ for _, entry := range uc.channels.innerMap {
+ ch := entry.value.(*upstreamChannel)
+ if ch.Members.Has(msg.Prefix.Name) {
+ ch.Members.Delete(msg.Prefix.Name)
+
+ uc.appendLog(ch.Name, msg)
+ }
+ }
+
+ if msg.Prefix.Name != uc.nick {
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.SendMessage(dc.marshalMessage(msg, uc.network))
+ })
+ }
+ case irc.RPL_TOPIC, irc.RPL_NOTOPIC:
+ var name, topic string
+ if err := parseMessageParams(msg, nil, &name, &topic); err != nil {
+ return err
+ }
+ ch, err := uc.getChannel(name)
+ if err != nil {
+ return err
+ }
+ if msg.Command == irc.RPL_TOPIC {
+ ch.Topic = topic
+ } else {
+ ch.Topic = ""
+ }
+ case "TOPIC":
+ if msg.Prefix == nil {
+ return fmt.Errorf("expected a prefix")
+ }
+
+ var name string
+ if err := parseMessageParams(msg, &name); err != nil {
+ return err
+ }
+ ch, err := uc.getChannel(name)
+ if err != nil {
+ return err
+ }
+ if len(msg.Params) > 1 {
+ ch.Topic = msg.Params[1]
+ ch.TopicWho = msg.Prefix.Copy()
+ ch.TopicTime = time.Now() // TODO use msg.Tags["time"]
+ } else {
+ ch.Topic = ""
+ }
+ uc.produce(ch.Name, msg, nil)
+ case "MODE":
+ var name, modeStr string
+ if err := parseMessageParams(msg, &name, &modeStr); err != nil {
+ return err
+ }
+
+ if !uc.isChannel(name) { // user mode change
+ if name != uc.nick {
+ return fmt.Errorf("received MODE message for unknown nick %q", name)
+ }
+
+ if err := uc.modes.Apply(modeStr); err != nil {
+ return err
+ }
+
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ if dc.upstream() == nil {
+ return
+ }
+
+ dc.SendMessage(msg)
+ })
+ } else { // channel mode change
+ ch, err := uc.getChannel(name)
+ if err != nil {
+ return err
+ }
+
+ needMarshaling, err := applyChannelModes(ch, modeStr, msg.Params[2:])
+ if err != nil {
+ return err
+ }
+
+ uc.appendLog(ch.Name, msg)
+
+ c := uc.network.channels.Value(name)
+ if c == nil || !c.Detached {
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ params := make([]string, len(msg.Params))
+ params[0] = dc.marshalEntity(uc.network, name)
+ params[1] = modeStr
+
+ copy(params[2:], msg.Params[2:])
+ for i, modeParam := range params[2:] {
+ if _, ok := needMarshaling[i]; ok {
+ params[2+i] = dc.marshalEntity(uc.network, modeParam)
+ }
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.marshalUserPrefix(uc.network, msg.Prefix),
+ Command: "MODE",
+ Params: params,
+ })
+ })
+ }
+ }
+ case irc.RPL_UMODEIS:
+ if err := parseMessageParams(msg, nil); err != nil {
+ return err
+ }
+ modeStr := ""
+ if len(msg.Params) > 1 {
+ modeStr = msg.Params[1]
+ }
+
+ uc.modes = ""
+ if err := uc.modes.Apply(modeStr); err != nil {
+ return err
+ }
+
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ if dc.upstream() == nil {
+ return
+ }
+
+ dc.SendMessage(msg)
+ })
+ case irc.RPL_CHANNELMODEIS:
+ var channel string
+ if err := parseMessageParams(msg, nil, &channel); err != nil {
+ return err
+ }
+ modeStr := ""
+ if len(msg.Params) > 2 {
+ modeStr = msg.Params[2]
+ }
+
+ ch, err := uc.getChannel(channel)
+ if err != nil {
+ return err
+ }
+
+ firstMode := ch.modes == nil
+ ch.modes = make(map[byte]string)
+ if _, err := applyChannelModes(ch, modeStr, msg.Params[3:]); err != nil {
+ return err
+ }
+
+ c := uc.network.channels.Value(channel)
+ if firstMode && (c == nil || !c.Detached) {
+ modeStr, modeParams := ch.modes.Format()
+
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ params := []string{dc.nick, dc.marshalEntity(uc.network, channel), modeStr}
+ params = append(params, modeParams...)
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_CHANNELMODEIS,
+ Params: params,
+ })
+ })
+ }
+ case rpl_creationtime:
+ var channel, creationTime string
+ if err := parseMessageParams(msg, nil, &channel, &creationTime); err != nil {
+ return err
+ }
+
+ ch, err := uc.getChannel(channel)
+ if err != nil {
+ return err
+ }
+
+ firstCreationTime := ch.creationTime == ""
+ ch.creationTime = creationTime
+
+ c := uc.network.channels.Value(channel)
+ if firstCreationTime && (c == nil || !c.Detached) {
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: rpl_creationtime,
+ Params: []string{dc.nick, dc.marshalEntity(uc.network, ch.Name), creationTime},
+ })
+ })
+ }
+ case rpl_topicwhotime:
+ var channel, who, timeStr string
+ if err := parseMessageParams(msg, nil, &channel, &who, &timeStr); err != nil {
+ return err
+ }
+
+ ch, err := uc.getChannel(channel)
+ if err != nil {
+ return err
+ }
+
+ firstTopicWhoTime := ch.TopicWho == nil
+ ch.TopicWho = irc.ParsePrefix(who)
+ sec, err := strconv.ParseInt(timeStr, 10, 64)
+ if err != nil {
+ return fmt.Errorf("failed to parse topic time: %v", err)
+ }
+ ch.TopicTime = time.Unix(sec, 0)
+
+ c := uc.network.channels.Value(channel)
+ if firstTopicWhoTime && (c == nil || !c.Detached) {
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ topicWho := dc.marshalUserPrefix(uc.network, ch.TopicWho)
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: rpl_topicwhotime,
+ Params: []string{
+ dc.nick,
+ dc.marshalEntity(uc.network, ch.Name),
+ topicWho.String(),
+ timeStr,
+ },
+ })
+ })
+ }
+ case irc.RPL_LIST:
+ var channel, clients, topic string
+ if err := parseMessageParams(msg, nil, &channel, &clients, &topic); err != nil {
+ return err
+ }
+
+ dc, cmd := uc.currentPendingCommand("LIST")
+ if cmd == nil {
+ return fmt.Errorf("unexpected RPL_LIST: no matching pending LIST")
+ } else if dc == nil {
+ return nil
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_LIST,
+ Params: []string{dc.nick, dc.marshalEntity(uc.network, channel), clients, topic},
+ })
+ case irc.RPL_LISTEND:
+ dc, cmd := uc.dequeueCommand("LIST")
+ if cmd == nil {
+ return fmt.Errorf("unexpected RPL_LISTEND: no matching pending LIST")
+ } else if dc == nil {
+ return nil
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_LISTEND,
+ Params: []string{dc.nick, "End of /LIST"},
+ })
+ case irc.RPL_NAMREPLY:
+ var name, statusStr, members string
+ if err := parseMessageParams(msg, nil, &statusStr, &name, &members); err != nil {
+ return err
+ }
+
+ ch := uc.channels.Value(name)
+ if ch == nil {
+ // NAMES on a channel we have not joined, forward to downstream
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ channel := dc.marshalEntity(uc.network, name)
+ members := splitSpace(members)
+ for i, member := range members {
+ memberships, nick := uc.parseMembershipPrefix(member)
+ members[i] = memberships.Format(dc) + dc.marshalEntity(uc.network, nick)
+ }
+ memberStr := strings.Join(members, " ")
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_NAMREPLY,
+ Params: []string{dc.nick, statusStr, channel, memberStr},
+ })
+ })
+ return nil
+ }
+
+ status, err := parseChannelStatus(statusStr)
+ if err != nil {
+ return err
+ }
+ ch.Status = status
+
+ for _, s := range splitSpace(members) {
+ memberships, nick := uc.parseMembershipPrefix(s)
+ ch.Members.SetValue(nick, memberships)
+ }
+ case irc.RPL_ENDOFNAMES:
+ var name string
+ if err := parseMessageParams(msg, nil, &name); err != nil {
+ return err
+ }
+
+ ch := uc.channels.Value(name)
+ if ch == nil {
+ // NAMES on a channel we have not joined, forward to downstream
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ channel := dc.marshalEntity(uc.network, name)
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_ENDOFNAMES,
+ Params: []string{dc.nick, channel, "End of /NAMES list"},
+ })
+ })
+ return nil
+ }
+
+ if ch.complete {
+ return fmt.Errorf("received unexpected RPL_ENDOFNAMES")
+ }
+ ch.complete = true
+
+ c := uc.network.channels.Value(name)
+ if c == nil || !c.Detached {
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ forwardChannel(ctx, dc, ch)
+ })
+ }
+ case irc.RPL_WHOREPLY:
+ var channel, username, host, server, nick, flags, trailing string
+ if err := parseMessageParams(msg, nil, &channel, &username, &host, &server, &nick, &flags, &trailing); err != nil {
+ return err
+ }
+
+ dc, cmd := uc.currentPendingCommand("WHO")
+ if cmd == nil {
+ return fmt.Errorf("unexpected RPL_WHOREPLY: no matching pending WHO")
+ } else if dc == nil {
+ return nil
+ }
+
+ if channel != "*" {
+ channel = dc.marshalEntity(uc.network, channel)
+ }
+ nick = dc.marshalEntity(uc.network, nick)
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_WHOREPLY,
+ Params: []string{dc.nick, channel, username, host, server, nick, flags, trailing},
+ })
+ case rpl_whospcrpl:
+ dc, cmd := uc.currentPendingCommand("WHO")
+ if cmd == nil {
+ return fmt.Errorf("unexpected RPL_WHOSPCRPL: no matching pending WHO")
+ } else if dc == nil {
+ return nil
+ }
+
+ // Only supported in single-upstream mode, so forward as-is
+ dc.SendMessage(msg)
+ case irc.RPL_ENDOFWHO:
+ var name string
+ if err := parseMessageParams(msg, nil, &name); err != nil {
+ return err
+ }
+
+ dc, cmd := uc.dequeueCommand("WHO")
+ if cmd == nil {
+ return fmt.Errorf("unexpected RPL_ENDOFWHO: no matching pending WHO")
+ } else if dc == nil {
+ return nil
+ }
+
+ mask := "*"
+ if len(cmd.Params) > 0 {
+ mask = cmd.Params[0]
+ }
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_ENDOFWHO,
+ Params: []string{dc.nick, mask, "End of /WHO list"},
+ })
+ case irc.RPL_WHOISUSER:
+ var nick, username, host, realname string
+ if err := parseMessageParams(msg, nil, &nick, &username, &host, nil, &realname); err != nil {
+ return err
+ }
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ nick := dc.marshalEntity(uc.network, nick)
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_WHOISUSER,
+ Params: []string{dc.nick, nick, username, host, "*", realname},
+ })
+ })
+ case irc.RPL_WHOISSERVER:
+ var nick, server, serverInfo string
+ if err := parseMessageParams(msg, nil, &nick, &server, &serverInfo); err != nil {
+ return err
+ }
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ nick := dc.marshalEntity(uc.network, nick)
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_WHOISSERVER,
+ Params: []string{dc.nick, nick, server, serverInfo},
+ })
+ })
+ case irc.RPL_WHOISOPERATOR:
+ var nick string
+ if err := parseMessageParams(msg, nil, &nick); err != nil {
+ return err
+ }
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ nick := dc.marshalEntity(uc.network, nick)
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_WHOISOPERATOR,
+ Params: []string{dc.nick, nick, "is an IRC operator"},
+ })
+ })
+ case irc.RPL_WHOISIDLE:
+ var nick string
+ if err := parseMessageParams(msg, nil, &nick, nil); err != nil {
+ return err
+ }
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ nick := dc.marshalEntity(uc.network, nick)
+ params := []string{dc.nick, nick}
+ params = append(params, msg.Params[2:]...)
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_WHOISIDLE,
+ Params: params,
+ })
+ })
+ case irc.RPL_WHOISCHANNELS:
+ var nick, channelList string
+ if err := parseMessageParams(msg, nil, &nick, &channelList); err != nil {
+ return err
+ }
+ channels := splitSpace(channelList)
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ nick := dc.marshalEntity(uc.network, nick)
+ channelList := make([]string, len(channels))
+ for i, channel := range channels {
+ prefix, channel := uc.parseMembershipPrefix(channel)
+ channel = dc.marshalEntity(uc.network, channel)
+ channelList[i] = prefix.Format(dc) + channel
+ }
+ channels := strings.Join(channelList, " ")
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_WHOISCHANNELS,
+ Params: []string{dc.nick, nick, channels},
+ })
+ })
+ case irc.RPL_ENDOFWHOIS:
+ var nick string
+ if err := parseMessageParams(msg, nil, &nick); err != nil {
+ return err
+ }
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ nick := dc.marshalEntity(uc.network, nick)
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_ENDOFWHOIS,
+ Params: []string{dc.nick, nick, "End of /WHOIS list"},
+ })
+ })
+ case "INVITE":
+ var nick, channel string
+ if err := parseMessageParams(msg, &nick, &channel); err != nil {
+ return err
+ }
+
+ weAreInvited := uc.isOurNick(nick)
+
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ if !weAreInvited && !dc.caps["invite-notify"] {
+ return
+ }
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.marshalUserPrefix(uc.network, msg.Prefix),
+ Command: "INVITE",
+ Params: []string{dc.marshalEntity(uc.network, nick), dc.marshalEntity(uc.network, channel)},
+ })
+ })
+ case irc.RPL_INVITING:
+ var nick, channel string
+ if err := parseMessageParams(msg, nil, &nick, &channel); err != nil {
+ return err
+ }
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_INVITING,
+ Params: []string{dc.nick, dc.marshalEntity(uc.network, nick), dc.marshalEntity(uc.network, channel)},
+ })
+ })
+ case irc.RPL_MONONLINE, irc.RPL_MONOFFLINE:
+ var targetsStr string
+ if err := parseMessageParams(msg, nil, &targetsStr); err != nil {
+ return err
+ }
+ targets := strings.Split(targetsStr, ",")
+
+ online := msg.Command == irc.RPL_MONONLINE
+ for _, target := range targets {
+ prefix := irc.ParsePrefix(target)
+ uc.monitored.SetValue(prefix.Name, online)
+ }
+
+ // Check if the nick we want is now free
+ wantNick := GetNick(&uc.user.User, &uc.network.Network)
+ wantNickCM := uc.network.casemap(wantNick)
+ if !online && uc.nickCM != wantNickCM {
+ found := false
+ for _, target := range targets {
+ prefix := irc.ParsePrefix(target)
+ if uc.network.casemap(prefix.Name) == wantNickCM {
+ found = true
+ break
+ }
+ }
+ if found {
+ uc.logger.Printf("desired nick %q is now available", wantNick)
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "NICK",
+ Params: []string{wantNick},
+ })
+ }
+ }
+
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ for _, target := range targets {
+ prefix := irc.ParsePrefix(target)
+ if dc.monitored.Has(prefix.Name) {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: msg.Command,
+ Params: []string{dc.nick, target},
+ })
+ }
+ }
+ })
+ case irc.ERR_MONLISTFULL:
+ var limit, targetsStr string
+ if err := parseMessageParams(msg, nil, &limit, &targetsStr); err != nil {
+ return err
+ }
+
+ targets := strings.Split(targetsStr, ",")
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ for _, target := range targets {
+ if dc.monitored.Has(target) {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: msg.Command,
+ Params: []string{dc.nick, limit, target},
+ })
+ }
+ }
+ })
+ case irc.RPL_AWAY:
+ var nick, reason string
+ if err := parseMessageParams(msg, nil, &nick, &reason); err != nil {
+ return err
+ }
+
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: irc.RPL_AWAY,
+ Params: []string{dc.nick, dc.marshalEntity(uc.network, nick), reason},
+ })
+ })
+ case "AWAY", "ACCOUNT":
+ if msg.Prefix == nil {
+ return fmt.Errorf("expected a prefix")
+ }
+
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.marshalUserPrefix(uc.network, msg.Prefix),
+ Command: msg.Command,
+ Params: msg.Params,
+ })
+ })
+ case irc.RPL_BANLIST, irc.RPL_INVITELIST, irc.RPL_EXCEPTLIST:
+ var channel, mask string
+ if err := parseMessageParams(msg, nil, &channel, &mask); err != nil {
+ return err
+ }
+ var addNick, addTime string
+ if len(msg.Params) >= 5 {
+ addNick = msg.Params[3]
+ addTime = msg.Params[4]
+ }
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ channel := dc.marshalEntity(uc.network, channel)
+
+ var params []string
+ if addNick != "" && addTime != "" {
+ addNick := dc.marshalEntity(uc.network, addNick)
+ params = []string{dc.nick, channel, mask, addNick, addTime}
+ } else {
+ params = []string{dc.nick, channel, mask}
+ }
+
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: msg.Command,
+ Params: params,
+ })
+ })
+ case irc.RPL_ENDOFBANLIST, irc.RPL_ENDOFINVITELIST, irc.RPL_ENDOFEXCEPTLIST:
+ var channel, trailing string
+ if err := parseMessageParams(msg, nil, &channel, &trailing); err != nil {
+ return err
+ }
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ upstreamChannel := dc.marshalEntity(uc.network, channel)
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: msg.Command,
+ Params: []string{dc.nick, upstreamChannel, trailing},
+ })
+ })
+ case irc.ERR_UNKNOWNCOMMAND, irc.RPL_TRYAGAIN:
+ var command, reason string
+ if err := parseMessageParams(msg, nil, &command, &reason); err != nil {
+ return err
+ }
+
+ if dc, _ := uc.dequeueCommand(command); dc != nil && downstreamID == 0 {
+ downstreamID = dc.id
+ }
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ dc.SendMessage(&irc.Message{
+ Prefix: uc.srv.prefix(),
+ Command: msg.Command,
+ Params: []string{dc.nick, command, reason},
+ })
+ })
+ case "FAIL":
+ var command, code string
+ if err := parseMessageParams(msg, &command, &code); err != nil {
+ return err
+ }
+
+ if !uc.registered && command == "*" && code == "ACCOUNT_REQUIRED" {
+ return registrationError{msg}
+ }
+
+ if dc, _ := uc.dequeueCommand(command); dc != nil && downstreamID == 0 {
+ downstreamID = dc.id
+ }
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ dc.SendMessage(msg)
+ })
+ case "ACK":
+ // Ignore
+ case irc.RPL_NOWAWAY, irc.RPL_UNAWAY:
+ // Ignore
+ case irc.RPL_YOURHOST, irc.RPL_CREATED:
+ // Ignore
+ case irc.RPL_LUSERCLIENT, irc.RPL_LUSEROP, irc.RPL_LUSERUNKNOWN, irc.RPL_LUSERCHANNELS, irc.RPL_LUSERME:
+ fallthrough
+ case irc.RPL_STATSVLINE, rpl_statsping, irc.RPL_STATSBLINE, irc.RPL_STATSDLINE:
+ fallthrough
+ case rpl_localusers, rpl_globalusers:
+ fallthrough
+ case irc.RPL_MOTDSTART, irc.RPL_MOTD:
+ // Ignore these messages if they're part of the initial registration
+ // message burst. Forward them if the user explicitly asked for them.
+ if !uc.gotMotd {
+ return nil
+ }
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ dc.SendMessage(&irc.Message{
+ Prefix: uc.srv.prefix(),
+ Command: msg.Command,
+ Params: msg.Params,
+ })
+ })
+ case irc.RPL_LISTSTART:
+ // Ignore
+ case "ERROR":
+ var text string
+ if err := parseMessageParams(msg, &text); err != nil {
+ return err
+ }
+ return fmt.Errorf("fatal server error: %v", text)
+ case irc.ERR_NICKNAMEINUSE:
+ // At this point, we haven't received ISUPPORT so we don't know the
+ // maximum nickname length or whether the server supports MONITOR. Many
+ // servers have NICKLEN=30 so let's just use that.
+ if !uc.registered && len(uc.nick)+1 < 30 {
+ uc.nick = uc.nick + "_"
+ uc.nickCM = uc.network.casemap(uc.nick)
+ uc.logger.Printf("desired nick is not available, falling back to %q", uc.nick)
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "NICK",
+ Params: []string{uc.nick},
+ })
+ return nil
+ }
+ fallthrough
+ case irc.ERR_PASSWDMISMATCH, irc.ERR_ERRONEUSNICKNAME, irc.ERR_NICKCOLLISION, irc.ERR_UNAVAILRESOURCE, irc.ERR_NOPERMFORHOST, irc.ERR_YOUREBANNEDCREEP:
+ if !uc.registered {
+ return registrationError{msg}
+ }
+ fallthrough
+ default:
+ uc.logger.Printf("unhandled message: %v", msg)
+
+ uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
+ // best effort marshaling for unknown messages, replies and errors:
+ // most numerics start with the user nick, marshal it if that's the case
+ // otherwise, conservately keep the params without marshaling
+ params := msg.Params
+ if _, err := strconv.Atoi(msg.Command); err == nil { // numeric
+ if len(msg.Params) > 0 && isOurNick(uc.network, msg.Params[0]) {
+ params[0] = dc.nick
+ }
+ }
+ dc.SendMessage(&irc.Message{
+ Prefix: uc.srv.prefix(),
+ Command: msg.Command,
+ Params: params,
+ })
+ })
+ }
+ return nil
+}
+
+func (uc *upstreamConn) handleDetachedMessage(ctx context.Context, ch *Channel, msg *irc.Message) {
+ if uc.network.detachedMessageNeedsRelay(ch, msg) {
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.relayDetachedMessage(uc.network, msg)
+ })
+ }
+ if ch.ReattachOn == FilterMessage || (ch.ReattachOn == FilterHighlight && uc.network.isHighlight(msg)) {
+ uc.network.attach(ctx, ch)
+ if err := uc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil {
+ uc.logger.Printf("failed to update channel %q: %v", ch.Name, err)
+ }
+ }
+}
+
+func (uc *upstreamConn) handleChanModes(s string) error {
+ parts := strings.SplitN(s, ",", 5)
+ if len(parts) < 4 {
+ return fmt.Errorf("malformed ISUPPORT CHANMODES value: %v", s)
+ }
+ modes := make(map[byte]channelModeType)
+ for i, mt := range []channelModeType{modeTypeA, modeTypeB, modeTypeC, modeTypeD} {
+ for j := 0; j < len(parts[i]); j++ {
+ mode := parts[i][j]
+ modes[mode] = mt
+ }
+ }
+ uc.availableChannelModes = modes
+ return nil
+}
+
+func (uc *upstreamConn) handleMemberships(s string) error {
+ if s == "" {
+ uc.availableMemberships = nil
+ return nil
+ }
+
+ if s[0] != '(' {
+ return fmt.Errorf("malformed ISUPPORT PREFIX value: %v", s)
+ }
+ sep := strings.IndexByte(s, ')')
+ if sep < 0 || len(s) != sep*2 {
+ return fmt.Errorf("malformed ISUPPORT PREFIX value: %v", s)
+ }
+ memberships := make([]membership, len(s)/2-1)
+ for i := range memberships {
+ memberships[i] = membership{
+ Mode: s[i+1],
+ Prefix: s[sep+i+1],
+ }
+ }
+ uc.availableMemberships = memberships
+ return nil
+}
+
+func (uc *upstreamConn) handleSupportedCaps(capsStr string) {
+ caps := strings.Fields(capsStr)
+ for _, s := range caps {
+ kv := strings.SplitN(s, "=", 2)
+ k := strings.ToLower(kv[0])
+ var v string
+ if len(kv) == 2 {
+ v = kv[1]
+ }
+ uc.supportedCaps[k] = v
+ }
+}
+
+func (uc *upstreamConn) requestCaps() {
+ var requestCaps []string
+ for c := range permanentUpstreamCaps {
+ if _, ok := uc.supportedCaps[c]; ok && !uc.caps[c] {
+ requestCaps = append(requestCaps, c)
+ }
+ }
+
+ if len(requestCaps) == 0 {
+ return
+ }
+
+ uc.SendMessage(context.TODO(), &irc.Message{
+ Command: "CAP",
+ Params: []string{"REQ", strings.Join(requestCaps, " ")},
+ })
+}
+
+func (uc *upstreamConn) supportsSASL(mech string) bool {
+ v, ok := uc.supportedCaps["sasl"]
+ if !ok {
+ return false
+ }
+
+ if v == "" {
+ return true
+ }
+
+ mechanisms := strings.Split(v, ",")
+ for _, mech := range mechanisms {
+ if strings.EqualFold(mech, mech) {
+ return true
+ }
+ }
+ return false
+}
+
+func (uc *upstreamConn) requestSASL() bool {
+ if uc.network.SASL.Mechanism == "" {
+ return false
+ }
+ return uc.supportsSASL(uc.network.SASL.Mechanism)
+}
+
+func (uc *upstreamConn) handleCapAck(ctx context.Context, name string, ok bool) error {
+ uc.caps[name] = ok
+
+ switch name {
+ case "sasl":
+ if !uc.requestSASL() {
+ return nil
+ }
+ if !ok {
+ uc.logger.Printf("server refused to acknowledge the SASL capability")
+ return nil
+ }
+
+ auth := &uc.network.SASL
+ switch auth.Mechanism {
+ case "PLAIN":
+ uc.logger.Printf("starting SASL PLAIN authentication with username %q", auth.Plain.Username)
+ uc.saslClient = sasl.NewPlainClient("", auth.Plain.Username, auth.Plain.Password)
+ case "EXTERNAL":
+ uc.logger.Printf("starting SASL EXTERNAL authentication")
+ uc.saslClient = sasl.NewExternalClient("")
+ default:
+ return fmt.Errorf("unsupported SASL mechanism %q", name)
+ }
+
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "AUTHENTICATE",
+ Params: []string{auth.Mechanism},
+ })
+ default:
+ if permanentUpstreamCaps[name] {
+ break
+ }
+ uc.logger.Printf("received CAP ACK/NAK for a cap we don't support: %v", name)
+ }
+ return nil
+}
+
+func splitSpace(s string) []string {
+ return strings.FieldsFunc(s, func(r rune) bool {
+ return r == ' '
+ })
+}
+
+func (uc *upstreamConn) register(ctx context.Context) {
+ uc.nick = GetNick(&uc.user.User, &uc.network.Network)
+ uc.nickCM = uc.network.casemap(uc.nick)
+ uc.username = GetUsername(&uc.user.User, &uc.network.Network)
+ uc.realname = GetRealname(&uc.user.User, &uc.network.Network)
+
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "CAP",
+ Params: []string{"LS", "302"},
+ })
+
+ if uc.network.Pass != "" {
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "PASS",
+ Params: []string{uc.network.Pass},
+ })
+ }
+
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "NICK",
+ Params: []string{uc.nick},
+ })
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "USER",
+ Params: []string{uc.username, "0", "*", uc.realname},
+ })
+}
+
+func (uc *upstreamConn) ReadMessage() (*irc.Message, error) {
+ msg, err := uc.conn.ReadMessage()
+ if err != nil {
+ return nil, err
+ }
+ return msg, nil
+}
+
+func (uc *upstreamConn) runUntilRegistered(ctx context.Context) error {
+ for !uc.registered {
+ msg, err := uc.ReadMessage()
+ if err != nil {
+ return fmt.Errorf("failed to read message: %v", err)
+ }
+
+ if err := uc.handleMessage(ctx, msg); err != nil {
+ if _, ok := err.(registrationError); ok {
+ return err
+ } else {
+ msg.Tags = nil // prevent message tags from cluttering logs
+ return fmt.Errorf("failed to handle message %q: %v", msg, err)
+ }
+ }
+ }
+
+ for _, command := range uc.network.ConnectCommands {
+ m, err := irc.ParseMessage(command)
+ if err != nil {
+ uc.logger.Printf("failed to parse connect command %q: %v", command, err)
+ } else {
+ uc.SendMessage(ctx, m)
+ }
+ }
+
+ return nil
+}
+
+func (uc *upstreamConn) readMessages(ch chan<- event) error {
+ for {
+ msg, err := uc.ReadMessage()
+ if errors.Is(err, io.EOF) {
+ break
+ } else if err != nil {
+ return fmt.Errorf("failed to read IRC command: %v", err)
+ }
+
+ ch <- eventUpstreamMessage{msg, uc}
+ }
+
+ return nil
+}
+
+func (uc *upstreamConn) SendMessage(ctx context.Context, msg *irc.Message) {
+ if !uc.caps["message-tags"] {
+ msg = msg.Copy()
+ msg.Tags = nil
+ }
+
+ uc.conn.SendMessage(ctx, msg)
+}
+
+func (uc *upstreamConn) SendMessageLabeled(ctx context.Context, downstreamID uint64, msg *irc.Message) {
+ if uc.caps["labeled-response"] {
+ if msg.Tags == nil {
+ msg.Tags = make(map[string]irc.TagValue)
+ }
+ msg.Tags["label"] = irc.TagValue(fmt.Sprintf("sd-%d-%d", downstreamID, uc.nextLabelID))
+ uc.nextLabelID++
+ }
+ uc.SendMessage(ctx, msg)
+}
+
+// appendLog appends a message to the log file.
+//
+// The internal message ID is returned. If the message isn't recorded in the
+// log file, an empty string is returned.
+func (uc *upstreamConn) appendLog(entity string, msg *irc.Message) (msgID string) {
+ if uc.user.msgStore == nil {
+ return ""
+ }
+
+ // Don't store messages with a server mask target
+ if strings.HasPrefix(entity, "$") {
+ return ""
+ }
+
+ entityCM := uc.network.casemap(entity)
+ if entityCM == "nickserv" {
+ // The messages sent/received from NickServ may contain
+ // security-related information (like passwords). Don't store these.
+ return ""
+ }
+
+ if !uc.network.delivered.HasTarget(entity) {
+ // This is the first message we receive from this target. Save the last
+ // message ID in delivery receipts, so that we can send the new message
+ // in the backlog if an offline client reconnects.
+ lastID, err := uc.user.msgStore.LastMsgID(&uc.network.Network, entityCM, time.Now())
+ if err != nil {
+ uc.logger.Printf("failed to log message: failed to get last message ID: %v", err)
+ return ""
+ }
+
+ uc.network.delivered.ForEachClient(func(clientName string) {
+ uc.network.delivered.StoreID(entity, clientName, lastID)
+ })
+ }
+
+ msgID, err := uc.user.msgStore.Append(&uc.network.Network, entityCM, msg)
+ if err != nil {
+ uc.logger.Printf("failed to append message to store: %v", err)
+ return ""
+ }
+
+ return msgID
+}
+
+// produce appends a message to the logs and forwards it to connected downstream
+// connections.
+//
+// If origin is not nil and origin doesn't support echo-message, the message is
+// forwarded to all connections except origin.
+func (uc *upstreamConn) produce(target string, msg *irc.Message, origin *downstreamConn) {
+ var msgID string
+ if target != "" {
+ msgID = uc.appendLog(target, msg)
+ }
+
+ // Don't forward messages if it's a detached channel
+ ch := uc.network.channels.Value(target)
+ detached := ch != nil && ch.Detached
+
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ if !detached && (dc != origin || dc.caps["echo-message"]) {
+ dc.sendMessageWithID(dc.marshalMessage(msg, uc.network), msgID)
+ } else {
+ dc.advanceMessageWithID(msg, msgID)
+ }
+ })
+}
+
+func (uc *upstreamConn) updateAway() {
+ ctx := context.TODO()
+
+ away := true
+ uc.forEachDownstream(func(*downstreamConn) {
+ away = false
+ })
+ if away == uc.away {
+ return
+ }
+ if away {
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "AWAY",
+ Params: []string{"Auto away"},
+ })
+ } else {
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "AWAY",
+ })
+ }
+ uc.away = away
+}
+
+func (uc *upstreamConn) updateChannelAutoDetach(name string) {
+ uch := uc.channels.Value(name)
+ if uch == nil {
+ return
+ }
+ ch := uc.network.channels.Value(name)
+ if ch == nil || ch.Detached {
+ return
+ }
+ uch.updateAutoDetach(ch.DetachAfter)
+}
+
+func (uc *upstreamConn) updateMonitor() {
+ if _, ok := uc.isupport["MONITOR"]; !ok {
+ return
+ }
+
+ ctx := context.TODO()
+
+ add := make(map[string]struct{})
+ var addList []string
+ seen := make(map[string]struct{})
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ for targetCM := range dc.monitored.innerMap {
+ if !uc.monitored.Has(targetCM) {
+ if _, ok := add[targetCM]; !ok {
+ addList = append(addList, targetCM)
+ add[targetCM] = struct{}{}
+ }
+ } else {
+ seen[targetCM] = struct{}{}
+ }
+ }
+ })
+
+ wantNick := GetNick(&uc.user.User, &uc.network.Network)
+ wantNickCM := uc.network.casemap(wantNick)
+ if _, ok := add[wantNickCM]; !ok && !uc.monitored.Has(wantNick) && !uc.isOurNick(wantNick) {
+ addList = append(addList, wantNickCM)
+ add[wantNickCM] = struct{}{}
+ }
+
+ removeAll := true
+ var removeList []string
+ for targetCM, entry := range uc.monitored.innerMap {
+ if _, ok := seen[targetCM]; ok {
+ removeAll = false
+ } else {
+ removeList = append(removeList, entry.originalKey)
+ }
+ }
+
+ // TODO: better handle the case where len(uc.monitored) + len(addList)
+ // exceeds the limit, probably by immediately sending ERR_MONLISTFULL?
+
+ if removeAll && len(addList) == 0 && len(removeList) > 0 {
+ // Optimization when the last MONITOR-aware downstream disconnects
+ uc.SendMessage(ctx, &irc.Message{
+ Command: "MONITOR",
+ Params: []string{"C"},
+ })
+ } else {
+ msgs := generateMonitor("-", removeList)
+ msgs = append(msgs, generateMonitor("+", addList)...)
+ for _, msg := range msgs {
+ uc.SendMessage(ctx, msg)
+ }
+ }
+
+ for _, target := range removeList {
+ uc.monitored.Delete(target)
+ }
+}
--- /dev/null
+package suika
+
+import (
+ "context"
+ "crypto/sha256"
+ "encoding/binary"
+ "encoding/hex"
+ "fmt"
+ "math/big"
+ "net"
+ "sort"
+ "strings"
+ "time"
+
+ "gopkg.in/irc.v3"
+)
+
+type event interface{}
+
+type eventUpstreamMessage struct {
+ msg *irc.Message
+ uc *upstreamConn
+}
+
+type eventUpstreamConnectionError struct {
+ net *network
+ err error
+}
+
+type eventUpstreamConnected struct {
+ uc *upstreamConn
+}
+
+type eventUpstreamDisconnected struct {
+ uc *upstreamConn
+}
+
+type eventUpstreamError struct {
+ uc *upstreamConn
+ err error
+}
+
+type eventDownstreamMessage struct {
+ msg *irc.Message
+ dc *downstreamConn
+}
+
+type eventDownstreamConnected struct {
+ dc *downstreamConn
+}
+
+type eventDownstreamDisconnected struct {
+ dc *downstreamConn
+}
+
+type eventChannelDetach struct {
+ uc *upstreamConn
+ name string
+}
+
+type eventBroadcast struct {
+ msg *irc.Message
+}
+
+type eventStop struct{}
+
+type eventUserUpdate struct {
+ password *string
+ admin *bool
+ done chan error
+}
+
+type deliveredClientMap map[string]string // client name -> msg ID
+
+type deliveredStore struct {
+ m deliveredCasemapMap
+}
+
+func newDeliveredStore() deliveredStore {
+ return deliveredStore{deliveredCasemapMap{newCasemapMap(0)}}
+}
+
+func (ds deliveredStore) HasTarget(target string) bool {
+ return ds.m.Value(target) != nil
+}
+
+func (ds deliveredStore) LoadID(target, clientName string) string {
+ clients := ds.m.Value(target)
+ if clients == nil {
+ return ""
+ }
+ return clients[clientName]
+}
+
+func (ds deliveredStore) StoreID(target, clientName, msgID string) {
+ clients := ds.m.Value(target)
+ if clients == nil {
+ clients = make(deliveredClientMap)
+ ds.m.SetValue(target, clients)
+ }
+ clients[clientName] = msgID
+}
+
+func (ds deliveredStore) ForEachTarget(f func(target string)) {
+ for _, entry := range ds.m.innerMap {
+ f(entry.originalKey)
+ }
+}
+
+func (ds deliveredStore) ForEachClient(f func(clientName string)) {
+ clients := make(map[string]struct{})
+ for _, entry := range ds.m.innerMap {
+ delivered := entry.value.(deliveredClientMap)
+ for clientName := range delivered {
+ clients[clientName] = struct{}{}
+ }
+ }
+
+ for clientName := range clients {
+ f(clientName)
+ }
+}
+
+type network struct {
+ Network
+ user *user
+ logger Logger
+ stopped chan struct{}
+
+ conn *upstreamConn
+ channels channelCasemapMap
+ delivered deliveredStore
+ lastError error
+ casemap casemapping
+}
+
+func newNetwork(user *user, record *Network, channels []Channel) *network {
+ logger := &prefixLogger{user.logger, fmt.Sprintf("network %q: ", record.GetName())}
+
+ m := channelCasemapMap{newCasemapMap(0)}
+ for _, ch := range channels {
+ ch := ch
+ m.SetValue(ch.Name, &ch)
+ }
+
+ return &network{
+ Network: *record,
+ user: user,
+ logger: logger,
+ stopped: make(chan struct{}),
+ channels: m,
+ delivered: newDeliveredStore(),
+ casemap: casemapRFC1459,
+ }
+}
+
+func (net *network) forEachDownstream(f func(*downstreamConn)) {
+ net.user.forEachDownstream(func(dc *downstreamConn) {
+ if dc.network == nil && !dc.isMultiUpstream {
+ return
+ }
+ if dc.network != nil && dc.network != net {
+ return
+ }
+ f(dc)
+ })
+}
+
+func (net *network) isStopped() bool {
+ select {
+ case <-net.stopped:
+ return true
+ default:
+ return false
+ }
+}
+
+func userIdent(u *User) string {
+ // The ident is a string we will send to upstream servers in clear-text.
+ // For privacy reasons, make sure it doesn't expose any meaningful user
+ // metadata. We just use the base64-encoded hashed ID, so that people don't
+ // start relying on the string being an integer or following a pattern.
+ var b [64]byte
+ binary.LittleEndian.PutUint64(b[:], uint64(u.ID))
+ h := sha256.Sum256(b[:])
+ return hex.EncodeToString(h[:16])
+}
+
+func (net *network) run() {
+ if !net.Enabled {
+ return
+ }
+
+ var lastTry time.Time
+ backoff := newBackoffer(retryConnectMinDelay, retryConnectMaxDelay, retryConnectJitter)
+ for {
+ if net.isStopped() {
+ return
+ }
+
+ delay := backoff.Next() - time.Now().Sub(lastTry)
+ if delay > 0 {
+ net.logger.Printf("waiting %v before trying to reconnect to %q", delay.Truncate(time.Second), net.Addr)
+ time.Sleep(delay)
+ }
+ lastTry = time.Now()
+
+
+ uc, err := connectToUpstream(context.TODO(), net)
+ if err != nil {
+ net.logger.Printf("failed to connect to upstream server %q: %v", net.Addr, err)
+ net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("failed to connect: %v", err)}
+ continue
+ }
+
+ uc.register(context.TODO())
+ if err := uc.runUntilRegistered(context.TODO()); err != nil {
+ text := err.Error()
+ temp := true
+ if regErr, ok := err.(registrationError); ok {
+ text = regErr.Reason()
+ temp = regErr.Temporary()
+ }
+ uc.logger.Printf("failed to register: %v", text)
+ net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("failed to register: %v", text)}
+ uc.Close()
+ if !temp {
+ return
+ }
+ continue
+ }
+
+ // TODO: this is racy with net.stopped. If the network is stopped
+ // before the user goroutine receives eventUpstreamConnected, the
+ // connection won't be closed.
+ net.user.events <- eventUpstreamConnected{uc}
+ if err := uc.readMessages(net.user.events); err != nil {
+ uc.logger.Printf("failed to handle messages: %v", err)
+ net.user.events <- eventUpstreamError{uc, fmt.Errorf("failed to handle messages: %v", err)}
+ }
+ uc.Close()
+ net.user.events <- eventUpstreamDisconnected{uc}
+
+ backoff.Reset()
+ }
+}
+
+func (net *network) stop() {
+ if !net.isStopped() {
+ close(net.stopped)
+ }
+
+ if net.conn != nil {
+ net.conn.Close()
+ }
+}
+
+func (net *network) detach(ch *Channel) {
+ if ch.Detached {
+ return
+ }
+
+ net.logger.Printf("detaching channel %q", ch.Name)
+
+ ch.Detached = true
+
+ if net.user.msgStore != nil {
+ nameCM := net.casemap(ch.Name)
+ lastID, err := net.user.msgStore.LastMsgID(&net.Network, nameCM, time.Now())
+ if err != nil {
+ net.logger.Printf("failed to get last message ID for channel %q: %v", ch.Name, err)
+ }
+ ch.DetachedInternalMsgID = lastID
+ }
+
+ if net.conn != nil {
+ uch := net.conn.channels.Value(ch.Name)
+ if uch != nil {
+ uch.updateAutoDetach(0)
+ }
+ }
+
+ net.forEachDownstream(func(dc *downstreamConn) {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.prefix(),
+ Command: "PART",
+ Params: []string{dc.marshalEntity(net, ch.Name), "Detach"},
+ })
+ })
+}
+
+func (net *network) attach(ctx context.Context, ch *Channel) {
+ if !ch.Detached {
+ return
+ }
+
+ net.logger.Printf("attaching channel %q", ch.Name)
+
+ detachedMsgID := ch.DetachedInternalMsgID
+ ch.Detached = false
+ ch.DetachedInternalMsgID = ""
+
+ var uch *upstreamChannel
+ if net.conn != nil {
+ uch = net.conn.channels.Value(ch.Name)
+
+ net.conn.updateChannelAutoDetach(ch.Name)
+ }
+
+ net.forEachDownstream(func(dc *downstreamConn) {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.prefix(),
+ Command: "JOIN",
+ Params: []string{dc.marshalEntity(net, ch.Name)},
+ })
+
+ if uch != nil {
+ forwardChannel(ctx, dc, uch)
+ }
+
+ if detachedMsgID != "" {
+ dc.sendTargetBacklog(ctx, net, ch.Name, detachedMsgID)
+ }
+ })
+}
+
+func (net *network) deleteChannel(ctx context.Context, name string) error {
+ ch := net.channels.Value(name)
+ if ch == nil {
+ return fmt.Errorf("unknown channel %q", name)
+ }
+ if net.conn != nil {
+ uch := net.conn.channels.Value(ch.Name)
+ if uch != nil {
+ uch.updateAutoDetach(0)
+ }
+ }
+
+ if err := net.user.srv.db.DeleteChannel(ctx, ch.ID); err != nil {
+ return err
+ }
+ net.channels.Delete(name)
+ return nil
+}
+
+func (net *network) updateCasemapping(newCasemap casemapping) {
+ net.casemap = newCasemap
+ net.channels.SetCasemapping(newCasemap)
+ net.delivered.m.SetCasemapping(newCasemap)
+ if uc := net.conn; uc != nil {
+ uc.channels.SetCasemapping(newCasemap)
+ for _, entry := range uc.channels.innerMap {
+ uch := entry.value.(*upstreamChannel)
+ uch.Members.SetCasemapping(newCasemap)
+ }
+ uc.monitored.SetCasemapping(newCasemap)
+ }
+ net.forEachDownstream(func(dc *downstreamConn) {
+ dc.monitored.SetCasemapping(newCasemap)
+ })
+}
+
+func (net *network) storeClientDeliveryReceipts(ctx context.Context, clientName string) {
+ if !net.user.hasPersistentMsgStore() {
+ return
+ }
+
+ var receipts []DeliveryReceipt
+ net.delivered.ForEachTarget(func(target string) {
+ msgID := net.delivered.LoadID(target, clientName)
+ if msgID == "" {
+ return
+ }
+ receipts = append(receipts, DeliveryReceipt{
+ Target: target,
+ InternalMsgID: msgID,
+ })
+ })
+
+ if err := net.user.srv.db.StoreClientDeliveryReceipts(ctx, net.ID, clientName, receipts); err != nil {
+ net.logger.Printf("failed to store delivery receipts for client %q: %v", clientName, err)
+ }
+}
+
+func (net *network) isHighlight(msg *irc.Message) bool {
+ if msg.Command != "PRIVMSG" && msg.Command != "NOTICE" {
+ return false
+ }
+
+ text := msg.Params[1]
+
+ nick := net.Nick
+ if net.conn != nil {
+ nick = net.conn.nick
+ }
+
+ // TODO: use case-mapping aware comparison here
+ return msg.Prefix.Name != nick && isHighlight(text, nick)
+}
+
+func (net *network) detachedMessageNeedsRelay(ch *Channel, msg *irc.Message) bool {
+ highlight := net.isHighlight(msg)
+ return ch.RelayDetached == FilterMessage || ((ch.RelayDetached == FilterHighlight || ch.RelayDetached == FilterDefault) && highlight)
+}
+
+func (net *network) autoSaveSASLPlain(ctx context.Context, username, password string) {
+ // User may have e.g. EXTERNAL mechanism configured. We do not want to
+ // automatically erase the key pair or any other credentials.
+ if net.SASL.Mechanism != "" && net.SASL.Mechanism != "PLAIN" {
+ return
+ }
+
+ net.logger.Printf("auto-saving SASL PLAIN credentials with username %q", username)
+ net.SASL.Mechanism = "PLAIN"
+ net.SASL.Plain.Username = username
+ net.SASL.Plain.Password = password
+ if err := net.user.srv.db.StoreNetwork(ctx, net.user.ID, &net.Network); err != nil {
+ net.logger.Printf("failed to save SASL PLAIN credentials: %v", err)
+ }
+}
+
+type user struct {
+ User
+ srv *Server
+ logger Logger
+
+ events chan event
+ done chan struct{}
+
+ networks []*network
+ downstreamConns []*downstreamConn
+ msgStore messageStore
+}
+
+func newUser(srv *Server, record *User) *user {
+ logger := &prefixLogger{srv.Logger, fmt.Sprintf("user %q: ", record.Username)}
+
+ var msgStore messageStore
+ if logPath := srv.Config().LogPath; logPath != "" {
+ msgStore = newFSMessageStore(logPath, record)
+ } else {
+ msgStore = newMemoryMessageStore()
+ }
+
+ return &user{
+ User: *record,
+ srv: srv,
+ logger: logger,
+ events: make(chan event, 64),
+ done: make(chan struct{}),
+ msgStore: msgStore,
+ }
+}
+
+func (u *user) forEachUpstream(f func(uc *upstreamConn)) {
+ for _, network := range u.networks {
+ if network.conn == nil {
+ continue
+ }
+ f(network.conn)
+ }
+}
+
+func (u *user) forEachDownstream(f func(dc *downstreamConn)) {
+ for _, dc := range u.downstreamConns {
+ f(dc)
+ }
+}
+
+func (u *user) getNetwork(name string) *network {
+ for _, network := range u.networks {
+ if network.Addr == name {
+ return network
+ }
+ if network.Name != "" && network.Name == name {
+ return network
+ }
+ }
+ return nil
+}
+
+func (u *user) getNetworkByID(id int64) *network {
+ for _, net := range u.networks {
+ if net.ID == id {
+ return net
+ }
+ }
+ return nil
+}
+
+func (u *user) run() {
+ defer func() {
+ if u.msgStore != nil {
+ if err := u.msgStore.Close(); err != nil {
+ u.logger.Printf("failed to close message store for user %q: %v", u.Username, err)
+ }
+ }
+ close(u.done)
+ }()
+
+ networks, err := u.srv.db.ListNetworks(context.TODO(), u.ID)
+ if err != nil {
+ u.logger.Printf("failed to list networks for user %q: %v", u.Username, err)
+ return
+ }
+
+ sort.Slice(networks, func(i, j int) bool {
+ return networks[i].ID < networks[j].ID
+ })
+
+ for _, record := range networks {
+ record := record
+ channels, err := u.srv.db.ListChannels(context.TODO(), record.ID)
+ if err != nil {
+ u.logger.Printf("failed to list channels for user %q, network %q: %v", u.Username, record.GetName(), err)
+ continue
+ }
+
+ network := newNetwork(u, &record, channels)
+ u.networks = append(u.networks, network)
+
+ if u.hasPersistentMsgStore() {
+ receipts, err := u.srv.db.ListDeliveryReceipts(context.TODO(), record.ID)
+ if err != nil {
+ u.logger.Printf("failed to load delivery receipts for user %q, network %q: %v", u.Username, network.GetName(), err)
+ return
+ }
+
+ for _, rcpt := range receipts {
+ network.delivered.StoreID(rcpt.Target, rcpt.Client, rcpt.InternalMsgID)
+ }
+ }
+
+ go network.run()
+ }
+
+ for e := range u.events {
+ switch e := e.(type) {
+ case eventUpstreamConnected:
+ uc := e.uc
+
+ uc.network.conn = uc
+
+ uc.updateAway()
+ uc.updateMonitor()
+
+ netIDStr := fmt.Sprintf("%v", uc.network.ID)
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.updateSupportedCaps()
+
+ if !dc.caps["soju.im/bouncer-networks"] {
+ sendServiceNOTICE(dc, fmt.Sprintf("connected to %s", uc.network.GetName()))
+ }
+
+ dc.updateNick()
+ dc.updateRealname()
+ dc.updateAccount()
+ })
+ u.forEachDownstream(func(dc *downstreamConn) {
+ if dc.caps["soju.im/bouncer-networks-notify"] {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "BOUNCER",
+ Params: []string{"NETWORK", netIDStr, "state=connected"},
+ })
+ }
+ })
+ uc.network.lastError = nil
+ case eventUpstreamDisconnected:
+ u.handleUpstreamDisconnected(e.uc)
+ case eventUpstreamConnectionError:
+ net := e.net
+
+ stopped := false
+ select {
+ case <-net.stopped:
+ stopped = true
+ default:
+ }
+
+ if !stopped && (net.lastError == nil || net.lastError.Error() != e.err.Error()) {
+ net.forEachDownstream(func(dc *downstreamConn) {
+ sendServiceNOTICE(dc, fmt.Sprintf("failed connecting/registering to %s: %v", net.GetName(), e.err))
+ })
+ }
+ net.lastError = e.err
+ case eventUpstreamError:
+ uc := e.uc
+
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s: %v", uc.network.GetName(), e.err))
+ })
+ uc.network.lastError = e.err
+ case eventUpstreamMessage:
+ msg, uc := e.msg, e.uc
+ if uc.isClosed() {
+ uc.logger.Printf("ignoring message on closed connection: %v", msg)
+ break
+ }
+ if err := uc.handleMessage(context.TODO(), msg); err != nil {
+ uc.logger.Printf("failed to handle message %q: %v", msg, err)
+ }
+ case eventChannelDetach:
+ uc, name := e.uc, e.name
+ c := uc.network.channels.Value(name)
+ if c == nil || c.Detached {
+ continue
+ }
+ uc.network.detach(c)
+ if err := uc.srv.db.StoreChannel(context.TODO(), uc.network.ID, c); err != nil {
+ u.logger.Printf("failed to store updated detached channel %q: %v", c.Name, err)
+ }
+ case eventDownstreamConnected:
+ dc := e.dc
+
+ if dc.network != nil {
+ dc.monitored.SetCasemapping(dc.network.casemap)
+ }
+
+ if err := dc.welcome(context.TODO()); err != nil {
+ dc.logger.Printf("failed to handle new registered connection: %v", err)
+ break
+ }
+
+ u.downstreamConns = append(u.downstreamConns, dc)
+
+ dc.forEachNetwork(func(network *network) {
+ if network.lastError != nil {
+ sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s: %v", network.GetName(), network.lastError))
+ }
+ })
+
+ u.forEachUpstream(func(uc *upstreamConn) {
+ uc.updateAway()
+ })
+ case eventDownstreamDisconnected:
+ dc := e.dc
+
+ for i := range u.downstreamConns {
+ if u.downstreamConns[i] == dc {
+ u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...)
+ break
+ }
+ }
+
+ dc.forEachNetwork(func(net *network) {
+ net.storeClientDeliveryReceipts(context.TODO(), dc.clientName)
+ })
+
+ u.forEachUpstream(func(uc *upstreamConn) {
+ uc.cancelPendingCommandsByDownstreamID(dc.id)
+ uc.updateAway()
+ uc.updateMonitor()
+ })
+ case eventDownstreamMessage:
+ msg, dc := e.msg, e.dc
+ if dc.isClosed() {
+ dc.logger.Printf("ignoring message on closed connection: %v", msg)
+ break
+ }
+ err := dc.handleMessage(context.TODO(), msg)
+ if ircErr, ok := err.(ircError); ok {
+ ircErr.Message.Prefix = dc.srv.prefix()
+ dc.SendMessage(ircErr.Message)
+ } else if err != nil {
+ dc.logger.Printf("failed to handle message %q: %v", msg, err)
+ dc.Close()
+ }
+ case eventBroadcast:
+ msg := e.msg
+ u.forEachDownstream(func(dc *downstreamConn) {
+ dc.SendMessage(msg)
+ })
+ case eventUserUpdate:
+ // copy the user record because we'll mutate it
+ record := u.User
+
+ if e.password != nil {
+ record.Password = *e.password
+ }
+ if e.admin != nil {
+ record.Admin = *e.admin
+ }
+
+ e.done <- u.updateUser(context.TODO(), &record)
+
+ // If the password was updated, kill all downstream connections to
+ // force them to re-authenticate with the new credentials.
+ if e.password != nil {
+ u.forEachDownstream(func(dc *downstreamConn) {
+ dc.Close()
+ })
+ }
+ case eventStop:
+ u.forEachDownstream(func(dc *downstreamConn) {
+ dc.Close()
+ })
+ for _, n := range u.networks {
+ n.stop()
+
+ n.delivered.ForEachClient(func(clientName string) {
+ n.storeClientDeliveryReceipts(context.TODO(), clientName)
+ })
+ }
+ return
+ default:
+ panic(fmt.Sprintf("received unknown event type: %T", e))
+ }
+ }
+}
+
+func (u *user) handleUpstreamDisconnected(uc *upstreamConn) {
+ uc.network.conn = nil
+
+ uc.abortPendingCommands()
+
+ for _, entry := range uc.channels.innerMap {
+ uch := entry.value.(*upstreamChannel)
+ uch.updateAutoDetach(0)
+ }
+
+ netIDStr := fmt.Sprintf("%v", uc.network.ID)
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ dc.updateSupportedCaps()
+ })
+
+ // If the network has been removed, don't send a state change notification
+ found := false
+ for _, net := range u.networks {
+ if net == uc.network {
+ found = true
+ break
+ }
+ }
+ if !found {
+ return
+ }
+
+ u.forEachDownstream(func(dc *downstreamConn) {
+ if dc.caps["soju.im/bouncer-networks-notify"] {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "BOUNCER",
+ Params: []string{"NETWORK", netIDStr, "state=disconnected"},
+ })
+ }
+ })
+
+ if uc.network.lastError == nil {
+ uc.forEachDownstream(func(dc *downstreamConn) {
+ if !dc.caps["soju.im/bouncer-networks"] {
+ sendServiceNOTICE(dc, fmt.Sprintf("disconnected from %s", uc.network.GetName()))
+ }
+ })
+ }
+}
+
+func (u *user) addNetwork(network *network) {
+ u.networks = append(u.networks, network)
+
+ sort.Slice(u.networks, func(i, j int) bool {
+ return u.networks[i].ID < u.networks[j].ID
+ })
+
+ go network.run()
+}
+
+func (u *user) removeNetwork(network *network) {
+ network.stop()
+
+ u.forEachDownstream(func(dc *downstreamConn) {
+ if dc.network != nil && dc.network == network {
+ dc.Close()
+ }
+ })
+
+ for i, net := range u.networks {
+ if net == network {
+ u.networks = append(u.networks[:i], u.networks[i+1:]...)
+ return
+ }
+ }
+
+ panic("tried to remove a non-existing network")
+}
+
+func (u *user) checkNetwork(record *Network) error {
+ url, err := record.URL()
+ if err != nil {
+ return err
+ }
+ if url.User != nil {
+ return fmt.Errorf("%v:// URL must not have username and password information", url.Scheme)
+ }
+ if url.RawQuery != "" {
+ return fmt.Errorf("%v:// URL must not have query values", url.Scheme)
+ }
+ if url.Fragment != "" {
+ return fmt.Errorf("%v:// URL must not have a fragment", url.Scheme)
+ }
+ switch url.Scheme {
+ case "ircs", "irc":
+ if url.Host == "" {
+ return fmt.Errorf("%v:// URL must have a host", url.Scheme)
+ }
+ if url.Path != "" {
+ return fmt.Errorf("%v:// URL must not have a path", url.Scheme)
+ }
+ case "irc+unix", "unix":
+ if url.Host != "" {
+ return fmt.Errorf("%v:// URL must not have a host", url.Scheme)
+ }
+ if url.Path == "" {
+ return fmt.Errorf("%v:// URL must have a path", url.Scheme)
+ }
+ default:
+ return fmt.Errorf("unknown URL scheme %q", url.Scheme)
+ }
+
+ if record.GetName() == "" {
+ return fmt.Errorf("network name cannot be empty")
+ }
+ if strings.HasPrefix(record.GetName(), "-") {
+ // Can be mixed up with flags when sending commands to the service
+ return fmt.Errorf("network name cannot start with a dash character")
+ }
+
+ for _, net := range u.networks {
+ if net.GetName() == record.GetName() && net.ID != record.ID {
+ return fmt.Errorf("a network with the name %q already exists", record.GetName())
+ }
+ }
+
+ return nil
+}
+
+func (u *user) createNetwork(ctx context.Context, record *Network) (*network, error) {
+ if record.ID != 0 {
+ panic("tried creating an already-existing network")
+ }
+
+ if err := u.checkNetwork(record); err != nil {
+ return nil, err
+ }
+
+ if max := u.srv.Config().MaxUserNetworks; max >= 0 && len(u.networks) >= max {
+ return nil, fmt.Errorf("maximum number of networks reached")
+ }
+
+ network := newNetwork(u, record, nil)
+ err := u.srv.db.StoreNetwork(ctx, u.ID, &network.Network)
+ if err != nil {
+ return nil, err
+ }
+
+ u.addNetwork(network)
+
+ idStr := fmt.Sprintf("%v", network.ID)
+ attrs := getNetworkAttrs(network)
+ u.forEachDownstream(func(dc *downstreamConn) {
+ if dc.caps["soju.im/bouncer-networks-notify"] {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "BOUNCER",
+ Params: []string{"NETWORK", idStr, attrs.String()},
+ })
+ }
+ })
+
+ return network, nil
+}
+
+func (u *user) updateNetwork(ctx context.Context, record *Network) (*network, error) {
+ if record.ID == 0 {
+ panic("tried updating a new network")
+ }
+
+ // If the realname is reset to the default, just wipe the per-network
+ // setting
+ if record.Realname == u.Realname {
+ record.Realname = ""
+ }
+
+ if err := u.checkNetwork(record); err != nil {
+ return nil, err
+ }
+
+ network := u.getNetworkByID(record.ID)
+ if network == nil {
+ panic("tried updating a non-existing network")
+ }
+
+ if err := u.srv.db.StoreNetwork(ctx, u.ID, record); err != nil {
+ return nil, err
+ }
+
+ // Most network changes require us to re-connect to the upstream server
+
+ channels := make([]Channel, 0, network.channels.Len())
+ for _, entry := range network.channels.innerMap {
+ ch := entry.value.(*Channel)
+ channels = append(channels, *ch)
+ }
+
+ updatedNetwork := newNetwork(u, record, channels)
+
+ // If we're currently connected, disconnect and perform the necessary
+ // bookkeeping
+ if network.conn != nil {
+ network.stop()
+ // Note: this will set network.conn to nil
+ u.handleUpstreamDisconnected(network.conn)
+ }
+
+ // Patch downstream connections to use our fresh updated network
+ u.forEachDownstream(func(dc *downstreamConn) {
+ if dc.network != nil && dc.network == network {
+ dc.network = updatedNetwork
+ }
+ })
+
+ // We need to remove the network after patching downstream connections,
+ // otherwise they'll get closed
+ u.removeNetwork(network)
+
+ // The filesystem message store needs to be notified whenever the network
+ // is renamed
+ fsMsgStore, isFS := u.msgStore.(*fsMessageStore)
+ if isFS && updatedNetwork.GetName() != network.GetName() {
+ if err := fsMsgStore.RenameNetwork(&network.Network, &updatedNetwork.Network); err != nil {
+ network.logger.Printf("failed to update FS message store network name to %q: %v", updatedNetwork.GetName(), err)
+ }
+ }
+
+ // This will re-connect to the upstream server
+ u.addNetwork(updatedNetwork)
+
+ // TODO: only broadcast attributes that have changed
+ idStr := fmt.Sprintf("%v", updatedNetwork.ID)
+ attrs := getNetworkAttrs(updatedNetwork)
+ u.forEachDownstream(func(dc *downstreamConn) {
+ if dc.caps["soju.im/bouncer-networks-notify"] {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "BOUNCER",
+ Params: []string{"NETWORK", idStr, attrs.String()},
+ })
+ }
+ })
+
+ return updatedNetwork, nil
+}
+
+func (u *user) deleteNetwork(ctx context.Context, id int64) error {
+ network := u.getNetworkByID(id)
+ if network == nil {
+ panic("tried deleting a non-existing network")
+ }
+
+ if err := u.srv.db.DeleteNetwork(ctx, network.ID); err != nil {
+ return err
+ }
+
+ u.removeNetwork(network)
+
+ idStr := fmt.Sprintf("%v", network.ID)
+ u.forEachDownstream(func(dc *downstreamConn) {
+ if dc.caps["soju.im/bouncer-networks-notify"] {
+ dc.SendMessage(&irc.Message{
+ Prefix: dc.srv.prefix(),
+ Command: "BOUNCER",
+ Params: []string{"NETWORK", idStr, "*"},
+ })
+ }
+ })
+
+ return nil
+}
+
+func (u *user) updateUser(ctx context.Context, record *User) error {
+ if u.ID != record.ID {
+ panic("ID mismatch when updating user")
+ }
+
+ realnameUpdated := u.Realname != record.Realname
+ if err := u.srv.db.StoreUser(ctx, record); err != nil {
+ return fmt.Errorf("failed to update user %q: %v", u.Username, err)
+ }
+ u.User = *record
+
+ if realnameUpdated {
+ // Re-connect to networks which use the default realname
+ var needUpdate []Network
+ for _, net := range u.networks {
+ if net.Realname == "" {
+ needUpdate = append(needUpdate, net.Network)
+ }
+ }
+
+ var netErr error
+ for _, net := range needUpdate {
+ if _, err := u.updateNetwork(ctx, &net); err != nil {
+ netErr = err
+ }
+ }
+ if netErr != nil {
+ return netErr
+ }
+ }
+
+ return nil
+}
+
+func (u *user) stop() {
+ u.events <- eventStop{}
+ <-u.done
+}
+
+func (u *user) hasPersistentMsgStore() bool {
+ if u.msgStore == nil {
+ return false
+ }
+ _, isMem := u.msgStore.(*memoryMessageStore)
+ return !isMem
+}
+
+// localAddrForHost returns the local address to use when connecting to host.
+// A nil address is returned when the OS should automatically pick one.
+func (u *user) localTCPAddrForHost(ctx context.Context, host string) (*net.TCPAddr, error) {
+ upstreamUserIPs := u.srv.Config().UpstreamUserIPs
+ if len(upstreamUserIPs) == 0 {
+ return nil, nil
+ }
+
+ ips, err := net.DefaultResolver.LookupIP(ctx, "ip", host)
+ if err != nil {
+ return nil, err
+ }
+
+ wantIPv6 := false
+ for _, ip := range ips {
+ if ip.To4() == nil {
+ wantIPv6 = true
+ break
+ }
+ }
+
+ var ipNet *net.IPNet
+ for _, in := range upstreamUserIPs {
+ if wantIPv6 == (in.IP.To4() == nil) {
+ ipNet = in
+ break
+ }
+ }
+ if ipNet == nil {
+ return nil, nil
+ }
+
+ var ipInt big.Int
+ ipInt.SetBytes(ipNet.IP)
+ ipInt.Add(&ipInt, big.NewInt(u.ID+1))
+ ip := net.IP(ipInt.Bytes())
+ if !ipNet.Contains(ip) {
+ return nil, fmt.Errorf("IP network %v too small", ipNet)
+ }
+
+ return &net.TCPAddr{IP: ip}, nil
+}
--- /dev/null
+package suika
+
+import (
+ "fmt"
+ "runtime/debug"
+ "strings"
+)
+
+const (
+ defaultVersion = "0.0.0"
+ defaultCommit = "HEAD"
+ defaultBuild = "0000-01-01:00:00+00:00"
+)
+
+var (
+ // Version is the tagged release version in the form <major>.<minor>.<patch>
+ // following semantic versioning and is overwritten by the build system.
+ Version = defaultVersion
+
+ // Commit is the commit sha of the build (normally from Git) and is overwritten
+ // by the build system.
+ Commit = defaultCommit
+
+ // Build is the date and time of the build as an RFC3339 formatted string
+ // and is overwritten by the build system.
+ Build = defaultBuild
+)
+
+// FullVersion display the full version and build
+func FullVersion() string {
+ var sb strings.Builder
+
+ isDefault := Version == defaultVersion && Commit == defaultCommit && Build == defaultBuild
+
+ if !isDefault {
+ sb.WriteString(fmt.Sprintf("%s@%s %s", Version, Commit, Build))
+ }
+
+ if info, ok := debug.ReadBuildInfo(); ok {
+ if isDefault {
+ sb.WriteString(fmt.Sprintf(" %s", info.Main.Version))
+ }
+ sb.WriteString(fmt.Sprintf(" %s", info.GoVersion))
+ if info.Main.Sum != "" {
+ sb.WriteString(fmt.Sprintf(" %s", info.Main.Sum))
+ }
+ }
+
+ return sb.String()
+}