Automated expression to SSE converter

As a by-product of some work I'm doing on automated circuit solving, I've hacked into the JUCE Expression class the ability to output expressions as SSE intrinsic calls, so the class can now be used to automatically convert infix scalar expressions to prefix SSE intrinsics.

http://www.cytomic.com/files/dsp/juce_ExpressionSse.zip

I've also added a function called "balance" which tries to balance the expression tree as much as possible for parallel execution, and also tries to remove divisions and replace them with multiplications. This is not a replacement for computer algebra simplification of complex expressions; it just handles a few basic cases and does the last step of getting regular C-style expressions into SSE.

There are probably errors but it's generating useful expressions for me right now so I thought I would share it with everyone. Here are some examples of what it does:

input  = a1 + a2 + a3 + a4
output = add_pd (add_pd (a1, a2), add_pd (a3, a4))

input  = a1 + a2 + a3 - a4
output = add_pd (add_pd (a1, a2), sub_pd (a3, a4))

input  = a1 - a2 - a3 + a4
output = sub_pd (add_pd (a1, a4), add_pd (a2, a3))

input  = -a1 - a2 - a3 - a4
output = neg_pd (add_pd (add_pd (a1, a2), add_pd (a3, a4)))

input  = a1 - (-a2 + a3 - (-a4))
output = sub_pd (add_pd (a1, a2), add_pd (a3, a4))

input  = a1 + a2 + a3 + a4 + a5 + a6 + a7
output = add_pd (add_pd (add_pd (a1, a2), add_pd (a3, a4)), add_pd (add_pd (a5, a6), a7))

input  = m1*m2*m3*m4
output = mul_pd (mul_pd (m1, m2), mul_pd (m3, m4))

input  = m1*m2*m3/m4
output = mul_pd (mul_pd (m1, m2), div_pd (m3, m4))

input  = m1/m2/m3*m4
output = div_pd (mul_pd (m1, m4), mul_pd (m2, m3))

input  = m1/m2/m3/m4
output = div_pd (m1, mul_pd (mul_pd (m2, m3), m4))

input  = a1/a2/(a3/a4)
output = div_pd (mul_pd (a1, a4), mul_pd (a2, a3))

And the code that generated the above:

    StringArray exs;

    exs.add ("a1 + a2 + a3 + a4");
    exs.add ("a1 + a2 + a3 - a4");
    exs.add ("a1 - a2 - a3 + a4");
    exs.add ("-a1 - a2 - a3 - a4"); 
    exs.add ("a1 - (-a2 + a3 - (-a4))");
    exs.add ("a1 + a2 + a3 + a4 + a5 + a6 + a7");
    exs.add ("m1*m2*m3*m4");
    exs.add ("m1*m2*m3/m4");
    exs.add ("m1/m2/m3*m4");
    exs.add ("m1/m2/m3/m4");
    exs.add ("a1/a2/(a3/a4)");

    ScopedPointer <Expression::SsePdFormat> format (new Expression::SsePdFormat);
    for (int i=0; i<exs.size (); i++)
    {
        std::cout << "input  = " << exs[i] << "\n";
        std::cout << "output = " << Expression (exs[i]).basicSimplify ().balance ().toString (format) << "\n\n";
    }

I've added two calls to the Expression class: basicSimplify and balance. basicSimplify removes some basic identities like:

        a*1=a, a/1=a, a*-1=-a, a/-1=-a, a*0=0, 0/a=0, a+0=a, a-0=a, a-(-b)=a+b, a+(-b)=a-b, -a-(-b)=b-a

And balance tries to balance out the left and right branches of an expression to make the evaluation as parallel as possible. This only pays off once the resulting string is copied and pasted into a C file and compiled with SSE intrinsics.

With a bit more work this could auto-convert code into SSE code. I've done a very basic converter that handles very limited cases as an example to get people up and running. If you want to handle more complicated code as the input you'll need to code it yourself and parse things properly, instead of just basic hacking based around a single equals sign and semicolon on the same line like I've done:

String basicCodeToSse (const String& code)
{
    String out;
    ScopedPointer <Expression::SsePdFormat> format (new Expression::SsePdFormat);
    StringArray lines;
    lines.addLines (code);
    for (int i=0; i<lines.size (); i++)
    {
        String line = lines[i];
        if (line.contains ("="))
        {
            String lhs = line.upToLastOccurrenceOf ("=", false, false);
            String rhs = line.substring (lhs.length ()+1).trim ();
            String toconvert = rhs.upToFirstOccurrenceOf (";", false, false);
            rhs = rhs.substring (toconvert.length ());
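            // trimEnd() returns a new String rather than modifying lhs, so use its result
            // for the operator check and keep lhs (trailing space included) for the output
            juce_wchar c = lhs.trimEnd ().getLastCharacter ();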
            // check for  += -= *= /=
            if (c == '+' || c == '-' || c == '*' || c == '/')
            {
                String variable = lhs.trim ();
                lhs = lhs.dropLastCharacters (1);
                toconvert = variable + "("+toconvert+")";
            }
            String converted = Expression (toconvert).basicSimplify ().balance ().toString (format);
            if (lhs.contains ("double"))
            {
                lhs = lhs.replace ("double", "vec2d");
            }
            lines.set (i, lhs + "= " + converted + rhs);
        }
        out << lines[i] + "\n";
    }
    return out;
}

And here is an example of calling this:

    String code =
    "    // can only handle end of line comments and single line expressions\n"
    "    // can only handle a single variable before any assignments\n"
    "    // no error checking at all\n"
    "\n"
    "    const double a1 = a2*a3 - a4;\n"
    "    double a5 = a1*a3 + a4*a5 + a6*a7 - a8;\n"
    "    a5 *= a3 + a5;\n"
    "    a5 -= a2 + a3;\n";

    std::cout << basicCodeToSse (code);

And the output:

    // can only handle end of line comments and single line expressions
    // can only handle a single variable before any assignments
    // no error checking at all

    const vec2d a1 = sub_pd (mul_pd (a2, a3), a4);
    vec2d a5 = add_pd (add_pd (mul_pd (a1, a3), mul_pd (a4, a5)), sub_pd (mul_pd (a6, a7), a8));
    a5 = mul_pd (a5, add_pd (a3, a5));
    a5 = sub_pd (a5, add_pd (a2, a3));

I use macros to make things a tiny bit more readable in SSE code by dropping the _mm_ prefix from everything. I use these by default, but you can change them to anything:

#include <emmintrin.h>   // SSE2 intrinsics
typedef __m128d vec2d;
typedef __m128  vec4f;
#define set_pd(x0,x1)    _mm_set_pd(x1, x0)
#define set1_pd(x0)      _mm_set1_pd(x0)
#define add_pd(x0,x1)    _mm_add_pd(x0,x1)
#define sub_pd(x0,x1)    _mm_sub_pd(x0,x1)
#define mul_pd(x0,x1)    _mm_mul_pd(x0,x1)
#define div_pd(x0,x1)    _mm_div_pd(x0,x1)
#define max_pd(x0,x1)    _mm_max_pd(x0,x1)
#define min_pd(x0,x1)    _mm_min_pd(x0,x1)
#define clip_pd(x0,lo,hi) _mm_min_pd(_mm_max_pd(x0,lo),hi)
#define cvtps_pd(x0)     _mm_cvtps_pd(x0)
#define andnot_pd(x0,x1) _mm_andnot_pd(x0,x1)
#define cmplt_pd(x0,x1)  _mm_cmplt_pd(x0,x1)
#define cmple_pd(x0,x1)  _mm_cmple_pd(x0,x1)
#define cmpgt_pd(x0,x1)  _mm_cmpgt_pd(x0,x1)
#define cmpge_pd(x0,x1)  _mm_cmpge_pd(x0,x1)
#define or_pd(x0,x1)     _mm_or_pd(x0,x1)
#define xor_pd(x0,x1)    _mm_xor_pd(x0,x1)
#define and_pd(x0,x1)    _mm_and_pd(x0,x1)
#define setzero_pd()     _mm_setzero_pd()
#define load_pd(x0)      _mm_load_pd(x0)
#define loadu_pd(x0)     _mm_loadu_pd(x0)
#define store_pd(x0,x1)  _mm_store_pd(x0,x1)
#define storeu_pd(x0,x1) _mm_storeu_pd(x0,x1)
#define cvtlps_pd(x0)    _mm_cvtps_pd(x0)
#define cvthps_pd(x0)    _mm_cvtps_pd(_mm_movehl_ps(x0,x0))
#define set_ps(x0,x1,x2,x3) _mm_set_ps(x3, x2, x1, x0)
#define set1_ps(x0)      _mm_set1_ps(x0)
#define add_ps(x0,x1)    _mm_add_ps(x0,x1)
#define sub_ps(x0,x1)    _mm_sub_ps(x0,x1)
#define mul_ps(x0,x1)    _mm_mul_ps(x0,x1)
#define div_ps(x0,x1)    _mm_div_ps(x0,x1)
#define rcp_ps(x0)       _mm_rcp_ps(x0)
#define max_ps(x0,x1)    _mm_max_ps(x0,x1)
#define min_ps(x0,x1)    _mm_min_ps(x0,x1)
#define clip_ps(x0,lo,hi) _mm_min_ps(_mm_max_ps(x0,lo),hi)
#define and_ps(x0,x1)    _mm_and_ps(x0,x1)
#define andnot_ps(x0,x1) _mm_andnot_ps(x0,x1)
#define or_ps(x0,x1)     _mm_or_ps(x0,x1)
#define xor_ps(x0,x1)    _mm_xor_ps(x0,x1)
#define movelh_ps(x0,x1) _mm_movelh_ps(x0,x1)
#define movehl_ps(x0,x1) _mm_movehl_ps(x0,x1)
#define cvtepi32_ps(x0)  _mm_cvtepi32_ps(x0)
#define cvtpd_ps(x0)     _mm_cvtpd_ps(x0)
#define cmplt_ps(x0,x1)  _mm_cmplt_ps(x0,x1)
#define cmple_ps(x0,x1)  _mm_cmple_ps(x0,x1)
#define cmpgt_ps(x0,x1)  _mm_cmpgt_ps(x0,x1)
#define cmpge_ps(x0,x1)  _mm_cmpge_ps(x0,x1)
#define setzero_ps()     _mm_setzero_ps()
#define load_ps(x0)      _mm_load_ps(x0)
#define loadu_ps(x0)     _mm_loadu_ps(x0)
#define store_ps(x0,x1)  _mm_store_ps(x0,x1)
#define storeu_ps(x0,x1) _mm_storeu_ps(x0,x1)
#define cvtpd2_ps(x0,x1) _mm_movelh_ps(_mm_cvtpd_ps(x0),_mm_cvtpd_ps(x1))
#define cvtps_epi32(x0)  _mm_cvtps_epi32(x0)
const vec2d zero_pd    = set1_pd ( 0.0);
const vec2d negzero_pd = set1_pd (-0.0);
const vec2d half_pd    = set1_pd ( 0.5);
const vec2d neghalf_pd = set1_pd (-0.5);
const vec2d one_pd     = set1_pd ( 1.0);
const vec2d negone_pd  = set1_pd (-1.0);
const vec2d two_pd     = set1_pd ( 2.0);
const vec2d negtwo_pd  = set1_pd (-2.0);
const vec2d pi_pd      = set1_pd (3.141592653589793);
const vec2d halfpi_pd  = set1_pd (0.5*3.141592653589793);
const vec2d twopi_pd   = set1_pd (2.0*3.141592653589793);
const vec4f zero_ps    = set1_ps ( 0.0f);
const vec4f negzero_ps = set1_ps (-0.0f);
const vec4f half_ps    = set1_ps ( 0.5f);
const vec4f one_ps     = set1_ps ( 1.0f);
const vec4f negone_ps  = set1_ps (-1.0f);
const vec4f two_ps     = set1_ps ( 2.0f);
const vec4f negtwo_ps  = set1_ps (-2.0f);
const vec4f pi_ps      = set1_ps (3.141592653589793f);
const vec4f halfpi_ps  = set1_ps (0.5f*3.141592653589793f);
const vec4f twopi_ps   = set1_ps (2.0f*3.141592653589793f);
#define abs_ps(x0) andnot_ps(negzero_ps, x0)
#define neg_ps(x0) xor_ps(negzero_ps, x0)
#define abs_pd(x0) andnot_pd(negzero_pd, x0)
#define neg_pd(x0) xor_pd(negzero_pd, x0)
__inline vec2d sqr_pd (vec2d x0)
{
    return mul_pd (x0, x0);
}
__inline vec4f sqr_ps (vec4f x0)
{
    return mul_ps (x0, x0);
}

Forgot to say: I've not tested the postfix code at all, so beware of using it (I don't even know of any practical use for it anyway!). Also, apologies to Jules for destroying his carefully constructed class by adding the setInput function and making all the operators explicitly and centrally defined, but I figured this was the easiest way to maintain things and also support multiple output formatters.

Interesting, is this for build-time code generation or for on-the-fly compilation?

Really interesting idea!

I use this for build-time code generation for circuit simulation so I can process stereo signals at no extra cost. Since the output is C code you can do what you want with it: on-the-fly just-in-time compilation, compiling some code to a dynamic library and then re-linking, anything really.

There is SSE2, which handles two doubles, and now AVX, which handles four doubles, so having a single input file of regular scalar code that gets automatically converted into both SSE2 and AVX will hopefully save a lot of time and prevent errors.

There are also multiple other DSP platforms, like the Avid HDX, which uses Texas Instruments TMS320C67x/C67x+ DSP chips, and the ARM Cortex stuff used by Apple, so being able to auto-generate specific intrinsics for all these platforms would be cool, and not too much of a reach from the work I've already started.

You might also want to check this:

http://nt2.metascale.fr/doc/html/the_boost_simd_library.html

HTH

I've updated my code and added some more powerful algebraic simplifications. I've also started a new thread, since being able to output as SSE is no longer really the focus; it's just a handy extra feature:

http://www.juce.com/forum/topic/basic-expression-simplification