001/*
002 * The MIT License (MIT)
003 *
004 * Copyright (c) 2015-2023 decimal4j (tools4j), Marco Terzer
005 *
006 * Permission is hereby granted, free of charge, to any person obtaining a copy
007 * of this software and associated documentation files (the "Software"), to deal
008 * in the Software without restriction, including without limitation the rights
009 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
010 * copies of the Software, and to permit persons to whom the Software is
011 * furnished to do so, subject to the following conditions:
012 *
013 * The above copyright notice and this permission notice shall be included in all
014 * copies or substantial portions of the Software.
015 *
016 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
017 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
018 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
019 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
020 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
021 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
022 * SOFTWARE.
023 */
024package org.decimal4j.arithmetic;
025
026import org.decimal4j.api.DecimalArithmetic;
027import org.decimal4j.scale.Scale9f;
028import org.decimal4j.scale.ScaleMetrics;
029import org.decimal4j.scale.Scales;
030import org.decimal4j.truncate.DecimalRounding;
031
032/**
033 * Provides methods to calculate squares.
034 */
035final class Square {
036
037        private static final Scale9f SCALE9F = Scale9f.INSTANCE;
038        
039        /**
040         * Value representing: <code>floor(sqrt(Long.MAX_VALUE))</code>
041         */
042        static final long SQRT_MAX_VALUE = 3037000499L;
043
044        // necessary and sufficient condition that square fits in long
045        private static final boolean doesSquareFitInLong(long uDecimal) {
046                return -SQRT_MAX_VALUE <= uDecimal & uDecimal <= SQRT_MAX_VALUE;
047        }
048
049        /**
050         * Calculates the square {@code uDecimal^2 / scaleFactor} without rounding.
051         * Overflows are silently truncated.
052         * 
053         * @param scaleMetrics
054         *            the scale metrics defining the scale
055         * @param uDecimal
056         *            the unscaled decimal value to square
057         * @return the square result without rounding
058         */
059        public static final long square(ScaleMetrics scaleMetrics, long uDecimal) {
060                if (doesSquareFitInLong(uDecimal)) {
061                        // square fits in long, just do it
062                        return scaleMetrics.divideByScaleFactor(uDecimal * uDecimal);
063                }
064                final int scale = scaleMetrics.getScale();
065                if (scale <= 9) {
066                        // use scale to split into 2 parts: i (integral) and f (fractional)
067                        // with this scale, the low order product f*f fits in a long
068                        final long i = scaleMetrics.divideByScaleFactor(uDecimal);
069                        final long f = uDecimal - scaleMetrics.multiplyByScaleFactor(i);
070                        return scaleMetrics.multiplyByScaleFactor(i * i) + ((i * f) << 1) + scaleMetrics.divideByScaleFactor(f * f);
071                } else {
072                        // use scale9 to split into 2 parts: h (high) and l (low)
073                        final ScaleMetrics scaleDiff09 = Scales.getScaleMetrics(scale - 9);
074                        final ScaleMetrics scaleDiff18 = Scales.getScaleMetrics(18 - scale);
075                        final long h = SCALE9F.divideByScaleFactor(uDecimal);
076                        final long l = uDecimal - SCALE9F.multiplyByScaleFactor(h);
077                        final long hxl = h * l;
078                        final long lxld = SCALE9F.divideByScaleFactor(l * l);
079                        final long hxld = scaleDiff09.divideByScaleFactor(hxl);
080                        final long hxlr = hxl - scaleDiff09.multiplyByScaleFactor(hxld);
081                        return scaleDiff18.multiplyByScaleFactor(h * h) + (hxld << 1)
082                                        + scaleDiff09.divideByScaleFactor((hxlr << 1) + lxld);
083                }
084        }
085
086        /**
087         * Calculates the square {@code uDecimal^2 / scaleFactor} applying the
088         * specified rounding for truncated decimals. Overflows are silently
089         * truncated.
090         * 
091         * @param scaleMetrics
092         *            the scale metrics defining the scale
093         * @param rounding
094         *            the rounding to apply for truncated decimals
095         * @param uDecimal
096         *            the unscaled decimal value to square
097         * @return the square result with rounding
098         */
099        public static final long square(ScaleMetrics scaleMetrics, DecimalRounding rounding, long uDecimal) {
100                if (doesSquareFitInLong(uDecimal)) {
101                        // square fits in long, just do it
102                        return square32(scaleMetrics, rounding, uDecimal);
103                }
104                final int scale = scaleMetrics.getScale();
105                if (scale <= 9) {
106                        // use scale to split into 2 parts: i (integral) and f (fractional)
107                        // with this scale, the low order product f*f fits in a long
108                        final long i = scaleMetrics.divideByScaleFactor(uDecimal);
109                        final long f = uDecimal - scaleMetrics.multiplyByScaleFactor(i);
110                        final long fxf = f * f;
111                        final long fxfd = scaleMetrics.divideByScaleFactor(fxf);
112                        final long fxfr = fxf - scaleMetrics.multiplyByScaleFactor(fxfd);
113                        final long unrounded = scaleMetrics.multiplyByScaleFactor(i * i) + ((i * f) << 1) + fxfd;
114                        return unrounded
115                                        + Rounding.calculateRoundingIncrement(rounding, unrounded, fxfr, scaleMetrics.getScaleFactor());
116                } else {
117                        // use scale9 to split into 2 parts: h (high) and l (low)
118                        final ScaleMetrics scaleDiff09 = Scales.getScaleMetrics(scale - 9);
119                        final ScaleMetrics scaleDiff18 = Scales.getScaleMetrics(18 - scale);
120                        final long h = SCALE9F.divideByScaleFactor(uDecimal);
121                        final long l = uDecimal - SCALE9F.multiplyByScaleFactor(h);
122                        final long hxl = h * l;
123                        final long lxl = l * l;
124                        final long lxld = SCALE9F.divideByScaleFactor(lxl);
125                        final long hxld = scaleDiff09.divideByScaleFactor(hxl);
126                        final long hxlr = hxl - scaleDiff09.multiplyByScaleFactor(hxld);
127                        final long lxlr = lxl - SCALE9F.multiplyByScaleFactor(lxld);
128                        final long hxlx2_lxl = (hxlr << 1) + lxld;
129                        final long hxlx2_lxld = scaleDiff09.divideByScaleFactor(hxlx2_lxl);
130                        final long hxlx2_lxlr = hxlx2_lxl - scaleDiff09.multiplyByScaleFactor(hxlx2_lxld);
131                        final long unrounded = scaleDiff18.multiplyByScaleFactor(h * h) + (hxld << 1) + hxlx2_lxld;
132                        final long remainder = SCALE9F.multiplyByScaleFactor(hxlx2_lxlr) + lxlr;
133                        return unrounded + Rounding.calculateRoundingIncrement(rounding, unrounded, remainder,
134                                        scaleMetrics.getScaleFactor());
135                }
136        }
137
138        // PRECONDITION: uDecimal <= SQRT_MAX_VALUE
139        private static final long square32(ScaleMetrics scaleMetrics, DecimalRounding rounding, long uDecimal) {
140                final long u2 = uDecimal * uDecimal;
141                final long u2d = scaleMetrics.divideByScaleFactor(u2);
142                final long u2r = u2 - scaleMetrics.multiplyByScaleFactor(u2d);
143                return u2d + Rounding.calculateRoundingIncrement(rounding, u2d, u2r, scaleMetrics.getScaleFactor());
144        }
145
146        /**
147         * Calculates the square {@code uDecimal^2 / scaleFactor} truncating the
148         * result if necessary. Throws an exception if an overflow occurs.
149         * 
150         * @param arith
151         *            the arithmetic associated with the value
152         * @param uDecimal
153         *            the unscaled decimal value to square
154         * @return the square result without rounding
155         */
156        public static final long squareChecked(DecimalArithmetic arith, long uDecimal) {
157                final ScaleMetrics scaleMetrics = arith.getScaleMetrics();
158                if (doesSquareFitInLong(uDecimal)) {
159                        // square fits in long, just do it
160                        return scaleMetrics.divideByScaleFactor(uDecimal * uDecimal);
161                }
162                final int scale = scaleMetrics.getScale();
163                try {
164                        if (scale <= 9) {
165                                // use scale to split into 2 parts: i (integral) and f
166                                // (fractional)
167                                // with this scale, the low order product f*f fits in a long
168                                final long i = scaleMetrics.divideByScaleFactor(uDecimal);
169                                final long f = uDecimal - scaleMetrics.multiplyByScaleFactor(i);
170                                final long ixi = Checked.multiplyLong(i, i);// checked
171                                final long ixf = i * f;// cannot overflow
172                                final long fxf = scaleMetrics.divideByScaleFactor(f * f);// unchecked:ok
173                                // check whether we can multiply ixf by 2
174                                if (ixf < 0)
175                                        throw new ArithmeticException("Overflow: " + ixf + "<<1");
176                                final long ixfx2 = ixf << 1;
177                                // add it all up now, every operation checked
178                                long result = scaleMetrics.multiplyByScaleFactorExact(ixi);
179                                result = Checked.addLong(result, ixfx2);
180                                result = Checked.addLong(result, fxf);
181                                return result;
182                        } else {
183                                // use scale9 to split into 2 parts: h (high) and l (low)
184                                final ScaleMetrics scaleDiff09 = Scales.getScaleMetrics(scale - 9);
185                                final ScaleMetrics scaleDiff18 = Scales.getScaleMetrics(18 - scale);
186                                final long h = SCALE9F.divideByScaleFactor(uDecimal);
187                                final long l = uDecimal - SCALE9F.multiplyByScaleFactor(h);
188
189                                final long hxh = Checked.multiplyLong(h, h);// checked
190                                final long hxl = h * l;// cannot overflow
191                                final long lxld = SCALE9F.divideByScaleFactor(l * l);// unchecked:ok
192                                final long hxld = scaleDiff09.divideByScaleFactor(hxl);
193                                final long hxlr = hxl - scaleDiff09.multiplyByScaleFactor(hxld);
194                                // check whether we can multiply hxld by 2
195                                if (hxld < 0)
196                                        throw new ArithmeticException("Overflow: " + hxld + "<<1");
197                                final long hxldx2 = hxld << 1;
198                                // add it all up now, every operation checked
199                                long result = scaleDiff18.multiplyByScaleFactorExact(hxh);
200                                result = Checked.addLong(result, hxldx2);
201                                result = Checked.addLong(result, scaleDiff09.divideByScaleFactor((hxlr << 1) + lxld));
202                                return result;
203                        }
204                } catch (ArithmeticException e) {
205                        throw Exceptions.newArithmeticExceptionWithCause("Overflow: " + arith.toString(uDecimal) + "^2", e);
206                }
207        }
208
209        /**
210         * Calculates the square {@code uDecimal^2 / scaleFactor} applying the
211         * specified rounding for truncated decimals. Throws an exception if an
212         * overflow occurs.
213         * 
214         * @param arith
215         *            the arithmetic associated with the value
216         * @param rounding
217         *            the rounding to apply for truncated decimals
218         * @param uDecimal
219         *            the unscaled decimal value to square
220         * @return the square result with rounding
221         */
222        public static final long squareChecked(DecimalArithmetic arith, DecimalRounding rounding, long uDecimal) {
223                final ScaleMetrics scaleMetrics = arith.getScaleMetrics();
224                if (doesSquareFitInLong(uDecimal)) {
225                        // square fits in long, just do it
226                        return square32(scaleMetrics, rounding, uDecimal);
227                }
228                try {
229                        final int scale = scaleMetrics.getScale();
230                        if (scale <= 9) {
231                                // use scale to split into 2 parts: i (integral) and f
232                                // (fractional)
233                                final long i = scaleMetrics.divideByScaleFactor(uDecimal);
234                                final long f = uDecimal - scaleMetrics.multiplyByScaleFactor(i);
235
236                                final long ixi = Checked.multiplyLong(i, i);
237                                final long fxf = f * f;// low order product f*f fits in a long
238                                final long ixf = i * f;// cannot overflow
239                                // check whether we can multiply ixf by 2
240                                if (ixf < 0)
241                                        throw new ArithmeticException("Overflow: " + ixf + "<<1");
242                                final long ixfx2 = ixf << 1;
243
244                                final long fxfd = scaleMetrics.divideByScaleFactor(fxf);
245                                final long fxfr = fxf - scaleMetrics.multiplyByScaleFactor(fxfd);
246
247                                // add it all up now, every operation checked
248                                long unrounded = scaleMetrics.multiplyByScaleFactorExact(ixi);
249                                unrounded = Checked.addLong(unrounded, ixfx2);
250                                unrounded = Checked.addLong(unrounded, fxfd);
251                                return Checked.addLong(unrounded,
252                                                Rounding.calculateRoundingIncrement(rounding, unrounded, fxfr, scaleMetrics.getScaleFactor()));
253                        } else {
254                                // use scale9 to split into 2 parts: h (high) and l (low)
255                                final ScaleMetrics scaleDiff09 = Scales.getScaleMetrics(scale - 9);
256                                final ScaleMetrics scaleDiff18 = Scales.getScaleMetrics(18 - scale);
257                                final long h = SCALE9F.divideByScaleFactor(uDecimal);
258                                final long l = uDecimal - SCALE9F.multiplyByScaleFactor(h);
259
260                                final long hxh = Checked.multiplyLong(h, h);
261                                final long hxl = h * l;// cannot overflow
262
263                                final long hxld = scaleDiff09.divideByScaleFactor(hxl);
264                                final long hxlr = hxl - scaleDiff09.multiplyByScaleFactor(hxld);
265                                final long hxldx2 = hxld << 1;// cannot overflow
266
267                                final long lxl = l * l;// cannot overflow
268                                final long lxld = SCALE9F.divideByScaleFactor(lxl);
269                                final long lxlr = lxl - SCALE9F.multiplyByScaleFactor(lxld);
270
271                                final long hxlx2_lxl = (hxlr << 1) + lxld;// cannot overflow
272                                final long hxlx2_lxld = scaleDiff09.divideByScaleFactor(hxlx2_lxl);
273                                final long hxlx2_lxlr = hxlx2_lxl - scaleDiff09.multiplyByScaleFactor(hxlx2_lxld);
274
275                                // add it all up now, every operation checked
276                                long unrounded = scaleDiff18.multiplyByScaleFactorExact(hxh);
277                                unrounded = Checked.addLong(unrounded, hxldx2);
278                                unrounded = Checked.addLong(unrounded, hxlx2_lxld);
279                                final long remainder = SCALE9F.multiplyByScaleFactor(hxlx2_lxlr) + lxlr;// cannot
280                                                                                                                                                                                // overflow
281                                return Checked.addLong(unrounded, Rounding.calculateRoundingIncrement(rounding, unrounded, remainder,
282                                                scaleMetrics.getScaleFactor()));
283                        }
284                } catch (ArithmeticException e) {
285                        Exceptions.rethrowIfRoundingNecessary(e);
286                        throw Exceptions.newArithmeticExceptionWithCause("Overflow: " + arith.toString(uDecimal) + "^2", e);
287                }
288        }
289
290        // no instances
291        private Square() {
292        }
293}