summaryrefslogtreecommitdiff
path: root/NN/Scripts/NNFunctions/table_gen.py
blob: 5db6d3e151f4e7b29996fbb7815cc7269371231d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
#!/usr/bin/python

import math

class Table(object):
  """Generate fixed-point sigmoid/tanh lookup tables as C source.

  With N = table_entry and R = table_range, the generated file contains,
  for each of sigmoid and tanh:

  * ``<func>Table_q7`` and ``<func>Table_q15`` -- N entries. Index ``i`` is
    read as a signed index ``s`` (``i`` for the lower half, ``i - N`` for
    the upper half) and holds ``f(R * s / N)``, so the inputs span
    ``[-R/2, R/2)``.
  * ``<func>LTable_q15`` -- N/2 entries sampling ``[-R/8, R/8)`` at twice
    the resolution of the unified table, laid out the same signed way.
  * ``<func>HTable_q15`` -- 3N/4 entries covering ``[R/8, R/2)`` followed
    by ``[-R/2, -R/8)``.

  The interval formulas assume N is a multiple of 8 and R a multiple of 4.
  Otherwise the integer (floor) divisions below truncate.
  Entries are emitted as unsigned two's-complement bit patterns
  (see fp2q7 / fp2q15).
  """

  # Preamble written at the top of the generated C file.
  # NOTE: the banner says "Table for sigmoid" although tanh tables follow
  # too. It is kept verbatim so regenerated output stays byte-identical.
  _FILE_HEADER = (
    "/*\n * Common tables for NN\n *\n *\n *\n *\n */\n\n"
    "#include \"arm_math.h\"\n#include \"NNCommonTable.h\"\n\n"
    "/*\n * Table for sigmoid\n */\n"
  )

  def __init__(self, table_entry=256, table_range=8):
    """Store the table geometry.

    Args:
      table_entry: number of entries in the unified tables (N).
      table_range: width of the unified tables' input interval (R).
    """
    self.table_entry = table_entry
    self.table_range = table_range

  def sigmoid(self, x):
    """Return the logistic function 1 / (1 + e**-x).

    math.exp raises OverflowError for very negative x. The table inputs
    are bounded by table_range / 2.
    """
    return 1 / (1 + math.exp(-x))

  def tanh(self, x):
    """Return tanh(x) computed as (e**2x - 1) / (e**2x + 1).

    The explicit formula is kept so generated tables do not change.
    math.exp raises OverflowError for very large x.
    """
    e2x = math.exp(2 * x)
    return (e2x - 1) / (e2x + 1)

  def _fp2q(self, x, frac_bits):
    """Quantize float x to a Q<frac_bits> value as an unsigned bit pattern.

    Rounds half up (floor(x * 2**frac_bits + 0.5)), then saturates to the
    signed range [-2**frac_bits, 2**frac_bits - 1]. Negative results are
    returned as their (frac_bits + 1)-bit two's-complement pattern.
    """
    scale = 2 ** frac_bits
    # int() keeps the result integral on Python 2, where math.floor
    # returns a float.
    x_int = int(math.floor(x * scale + 0.5))
    x_int = max(-scale, min(scale - 1, x_int))
    return x_int if x_int >= 0 else 2 * scale + x_int

  def fp2q7(self, x):
    """Quantize x to Q7, returned as an 8-bit pattern (0x00..0xff)."""
    return self._fp2q(x, 7)

  def fp2q15(self, x):
    """Quantize x to Q15, returned as a 16-bit pattern (0x0000..0xffff)."""
    return self._fp2q(x, 15)

  def _unified_inputs(self):
    """Inputs for the N-entry tables: R * s / N for signed index s."""
    n, r = self.table_entry, self.table_range
    return [float(r * (i if i < n // 2 else i - n)) / n for i in range(n)]

  def _low_inputs(self):
    """Inputs for the N/2-entry L tables, spanning [-R/8, R/8)."""
    n, r = self.table_entry, self.table_range
    half = n // 2
    inputs = []
    for i in range(half):
      s = i if i < n // 4 else i - half
      # Floor division reproduces the original Python 2 integer "/".
      inputs.append(float(r * s // 4) / half)
    return inputs

  def _high_inputs(self):
    """Inputs for the 3N/4-entry H tables: [R/8, R/2) then [-R/2, -R/8)."""
    n, r = self.table_entry, self.table_range
    eighth = n // 8
    inputs = []
    for i in range(3 * n // 4):
      s = i + eighth if i < 3 * n // 8 else i + eighth - n
      inputs.append(float(r * s) / n)
    return inputs

  def _format_table(self, name, data_type, inputs, act_func):
    """Return one C array definition as text.

    Args:
      name: C identifier prefix. "_q<data_type>" is appended to it.
      data_type: fractional bits, 7 (q7_t) or 15 (q15_t).
      inputs: float inputs, one per table entry.
      act_func: activation applied to each input before quantization.
    """
    # Look the quantizer up by name so subclass overrides are honoured.
    quantize = getattr(self, "fp2q%d" % data_type)
    entry_fmt = "0x%02x, " if data_type == 7 else "0x%04x, "
    parts = ["const q%d_t %s_q%d[%d] = {\n"
             % (data_type, name, data_type, len(inputs))]
    for idx, x in enumerate(inputs):
      parts.append(entry_fmt % quantize(act_func(x)))
      if idx % 8 == 7:  # eight entries per line
        parts.append("\n")
    parts.append("};\n\n")
    return "".join(parts)

  def table_source(self):
    """Return the full text of the generated C file as a string."""
    # The sample points do not depend on the function, so compute them once.
    unified = self._unified_inputs()
    low = self._low_inputs()
    high = self._high_inputs()
    parts = [self._FILE_HEADER]
    for function_type in ("sigmoid", "tanh"):
      act_func = getattr(self, function_type)
      for data_type in (7, 15):
        parts.append(self._format_table(
            function_type + "Table", data_type, unified, act_func))
      # The split low/high tables are only generated in Q15.
      parts.append(self._format_table(
          function_type + "LTable", 15, low, act_func))
      parts.append(self._format_table(
          function_type + "HTable", 15, high, act_func))
    return "".join(parts)

  def table_gen(self, filename="NNCommonTable.c"):
    """Write the generated tables to ``filename``.

    The text is built completely before the file is opened, so an error
    while computing a table does not leave a truncated file. It is
    written in binary mode as ASCII, which keeps LF line endings on every
    platform and works on both Python 2 and Python 3.
    """
    source = self.table_source()
    with open(filename, "wb") as outfile:
      outfile.write(source.encode("ascii"))
  
  
if __name__ == "__main__":
  # Regenerate NNCommonTable.c in the current working directory with
  # 256-entry tables. With table_range=16 the unified tables sample [-8, 8).
  # Guarded so that importing this module does not write a file.
  mytable = Table(table_entry=256, table_range=16)
  mytable.table_gen()