OpenCores
URL https://opencores.org/ocsvn/keras_to_fpga/keras_to_fpga/trunk

Subversion Repositories keras_to_fpga

[/] [keras_to_fpga/] [trunk/] [src/] [xor/] [xor.py] - Blame information for rev 2

Details | Compare with Previous | View Log

Line No. Rev Author Line
1 2 qaztronic
from keras.models import Model
2
from tensorflow.keras.models import load_model
3
import numpy as np
4
import os
5
import struct
6
import shutil
7
 
8
# -------------------------------------------------------
9
# Load the trained network from xor.h5 (resolved relative to the current
# working directory) and print Keras' layer-by-layer summary.
model = load_model('xor.h5')
model.summary()

# -------------------------------------------------------
# The four two-bit input patterns of the XOR truth table. Only the disabled
# sanity-check prints below read this array.
x = np.array([
    [0, 0],
    [0, 1],
    [1, 0],
    [1, 1],
])
# Optional sanity check of the loaded model (disabled):
# print(model.predict(x).round())
# print(model.predict(x))
# print('-' * 60)
17
 
18
def float_to_hex(f):
    """Return the single-precision bit pattern of *f* as a lowercase hex string.

    *f* is packed as a 4-byte little-endian C ``float``. Those bytes are then
    read back as an unsigned little-endian integer, and the integer is formatted
    in hex with no ``0x`` prefix and no zero padding. For example, ``0.0`` gives
    ``'0'`` and ``1.0`` gives ``'3f800000'``.
    """
    packed = struct.pack('<f', f)
    bit_pattern = int.from_bytes(packed, 'little')
    return '{:x}'.format(bit_pattern)
21
 
22
# -------------------------------------------------------
23
# Export each layer's weights as hex text files, one file per output unit.
#
# For a layer whose get_weights() is [kernel, bias] with a 2-D kernel of shape
# (n_in, n_out) and a 1-D bias of shape (n_out,), this writes the file
# weights/<layer name>_<y>.txt for each column y. Each file holds n_in lines
# (kernel[0][y] .. kernel[n_in-1][y]) followed by one line for bias[y]. Every
# value is formatted by float_to_hex. The (n_in, n_out) layout is assumed to be
# the Keras Dense convention (inputs, units) -- TODO confirm for other layer
# types.
#
# The output directory is deleted and recreated on every run.
weights_dir = 'weights'  # not named `dir`, to avoid shadowing the builtin
if os.path.exists(weights_dir):
    shutil.rmtree(weights_dir)
os.makedirs(weights_dir)

for layer in model.layers:
    print(layer.name)
    layer_weights = layer.get_weights()

    # Parameter-free layers (e.g. an InputLayer in a functional model) return
    # an empty list; indexing it used to raise IndexError. Nothing to export.
    if not layer_weights:
        print('no weights - skipped')
        print('+' * 60)
        continue

    # Only the [2-D kernel, 1-D bias] layout is understood. Anything else (a
    # layer without bias, a conv kernel, ...) would crash or emit wrong files,
    # so fail with an explicit message instead.
    if (len(layer_weights) != 2
            or layer_weights[0].ndim != 2
            or layer_weights[1].ndim != 1
            or layer_weights[0].shape[1] != layer_weights[1].shape[0]):
        raise ValueError(
            'layer %r: unsupported weight layout %s; expected '
            '[kernel (n_in, n_out), bias (n_out,)]'
            % (layer.name, [w.shape for w in layer_weights]))

    kernel, bias = layer_weights
    print(kernel.shape, bias.shape)
    print('+' * 60)

    for unit in range(kernel.shape[1]):
        print('-' * 60)
        file_name = os.path.join(weights_dir, layer.name + '_' + str(unit) + '.txt')
        print(file_name)
        # `with` closes the file; no explicit close is needed.
        with open(file_name, "w") as text_file:
            # Column `unit` of the kernel, in input order, then its bias term.
            # Iterating the column (rather than using a bare `x` index) avoids
            # rebinding the module-level XOR input array `x`.
            for value in kernel[:, unit]:
                print(float_to_hex(value), file=text_file)
            print(float_to_hex(bias[unit]), file=text_file)
        print('^' * 60)
64
 
65
 
66
 

powered by: WebSVN 2.1.0

© copyright 1999-2024 OpenCores.org, equivalent to Oliscience, all rights reserved. OpenCores®, registered trademark.