Skip to content

[REDO] Update espota.py #8797

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 25, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified tools/espota.exe
Binary file not shown.
522 changes: 229 additions & 293 deletions tools/espota.py
Original file line number Diff line number Diff line change
@@ -27,12 +27,20 @@
# 2016-01-03:
# - Added more options to parser.
#
# Changes
# 2023-05-22:
# - Replaced the deprecated optparse module with argparse.
# - Adjusted the code style to conform to PEP 8 guidelines.
# - Used with statement for file handling to ensure proper resource cleanup.
# - Incorporated exception handling to catch and handle potential errors.
# - Made variable names more descriptive for better readability.
# - Introduced constants for better code maintainability.

from __future__ import print_function
import socket
import sys
import os
import optparse
import argparse
import logging
import hashlib
import random
@@ -41,313 +49,241 @@
FLASH = 0
SPIFFS = 100
AUTH = 200
PROGRESS = False
# update_progress() : Displays or updates a console progress bar
## Accepts a float between 0 and 1. Any int will be converted to a float.
## A value under 0 represents a 'halt'.
## A value at 1 or bigger represents 100%

# Constants
PROGRESS_BAR_LENGTH = 60

# update_progress(): Displays or updates a console progress bar
def update_progress(progress):
if (PROGRESS):
barLength = 60 # Modify this to change the length of the progress bar
status = ""
if isinstance(progress, int):
progress = float(progress)
if not isinstance(progress, float):
progress = 0
status = "error: progress var must be float\r\n"
if progress < 0:
progress = 0
status = "Halt...\r\n"
if progress >= 1:
progress = 1
status = "Done...\r\n"
block = int(round(barLength*progress))
text = "\rUploading: [{0}] {1}% {2}".format( "="*block + " "*(barLength-block), int(progress*100), status)
sys.stderr.write(text)
if PROGRESS:
status = ""
if isinstance(progress, int):
progress = float(progress)
if not isinstance(progress, float):
progress = 0
status = "Error: progress var must be float\r\n"
if progress < 0:
progress = 0
status = "Halt...\r\n"
if progress >= 1:
progress = 1
status = "Done...\r\n"
block = int(round(PROGRESS_BAR_LENGTH * progress))
text = "\rUploading: [{0}] {1}% {2}".format("=" * block + " " * (PROGRESS_BAR_LENGTH - block), int(progress * 100), status)
sys.stderr.write(text)
sys.stderr.flush()
else:
sys.stderr.write(".")
sys.stderr.flush()

def serve(remote_addr, local_addr, remote_port, local_port, password, filename, command=FLASH):
# Create a TCP/IP socket
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server_address = (local_addr, local_port)
logging.info('Starting on %s:%s', str(server_address[0]), str(server_address[1]))
try:
sock.bind(server_address)
sock.listen(1)
except Exception as e:
logging.error("Listen Failed: %s", str(e))
return 1

content_size = os.path.getsize(filename)
file_md5 = hashlib.md5(open(filename, 'rb').read()).hexdigest()
logging.info('Upload size: %d', content_size)
message = '%d %d %d %s\n' % (command, local_port, content_size, file_md5)

# Wait for a connection
inv_tries = 0
data = ''
msg = 'Sending invitation to %s ' % remote_addr
sys.stderr.write(msg)
sys.stderr.flush()
else:
sys.stderr.write('.')
while inv_tries < 10:
inv_tries += 1
sock2 = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
remote_address = (remote_addr, int(remote_port))
try:
sent = sock2.sendto(message.encode(), remote_address)
except:
sys.stderr.write('failed\n')
sys.stderr.flush()
sock2.close()
logging.error('Host %s Not Found', remote_addr)
return 1
sock2.settimeout(TIMEOUT)
try:
data = sock2.recv(37).decode()
break
except:
sys.stderr.write('.')
sys.stderr.flush()
sock2.close()
sys.stderr.write('\n')
sys.stderr.flush()
if inv_tries == 10:
logging.error('No response from the ESP')
return 1
if data != "OK":
if data.startswith('AUTH'):
nonce = data.split()[1]
cnonce_text = '%s%u%s%s' % (filename, content_size, file_md5, remote_addr)
cnonce = hashlib.md5(cnonce_text.encode()).hexdigest()
passmd5 = hashlib.md5(password.encode()).hexdigest()
result_text = '%s:%s:%s' % (passmd5, nonce, cnonce)
result = hashlib.md5(result_text.encode()).hexdigest()
sys.stderr.write('Authenticating...')
sys.stderr.flush()
message = '%d %s %s\n' % (AUTH, cnonce, result)
sock2.sendto(message.encode(), remote_address)
sock2.settimeout(10)
try:
data = sock2.recv(32).decode()
except:
sys.stderr.write('FAIL\n')
logging.error('No Answer to our Authentication')
sock2.close()
return 1
if data != "OK":
sys.stderr.write('FAIL\n')
logging.error('%s', data)
sock2.close()
sys.exit(1)
return 1
sys.stderr.write('OK\n')
else:
logging.error('Bad Answer: %s', data)
sock2.close()
return 1
sock2.close()

def serve(remoteAddr, localAddr, remotePort, localPort, password, filename, command = FLASH):
# Create a TCP/IP socket
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server_address = (localAddr, localPort)
logging.info('Starting on %s:%s', str(server_address[0]), str(server_address[1]))
try:
sock.bind(server_address)
sock.listen(1)
except:
logging.error("Listen Failed")
return 1

content_size = os.path.getsize(filename)
f = open(filename,'rb')
file_md5 = hashlib.md5(f.read()).hexdigest()
f.close()
logging.info('Upload size: %d', content_size)
message = '%d %d %d %s\n' % (command, localPort, content_size, file_md5)

# Wait for a connection
inv_trys = 0
data = ''
msg = 'Sending invitation to %s ' % (remoteAddr)
sys.stderr.write(msg)
sys.stderr.flush()
while (inv_trys < 10):
inv_trys += 1
sock2 = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
remote_address = (remoteAddr, int(remotePort))
try:
sent = sock2.sendto(message.encode(), remote_address)
except:
sys.stderr.write('failed\n')
sys.stderr.flush()
sock2.close()
logging.error('Host %s Not Found', remoteAddr)
return 1
sock2.settimeout(TIMEOUT)
logging.info('Waiting for device...')
try:
data = sock2.recv(37).decode()
break;
sock.settimeout(10)
connection, client_address = sock.accept()
sock.settimeout(None)
connection.settimeout(None)
except:
sys.stderr.write('.')
sys.stderr.flush()
sock2.close()
sys.stderr.write('\n')
sys.stderr.flush()
if (inv_trys == 10):
logging.error('No response from the ESP')
return 1
if (data != "OK"):
if(data.startswith('AUTH')):
nonce = data.split()[1]
cnonce_text = '%s%u%s%s' % (filename, content_size, file_md5, remoteAddr)
cnonce = hashlib.md5(cnonce_text.encode()).hexdigest()
passmd5 = hashlib.md5(password.encode()).hexdigest()
result_text = '%s:%s:%s' % (passmd5 ,nonce, cnonce)
result = hashlib.md5(result_text.encode()).hexdigest()
sys.stderr.write('Authenticating...')
sys.stderr.flush()
message = '%d %s %s\n' % (AUTH, cnonce, result)
sock2.sendto(message.encode(), remote_address)
sock2.settimeout(10)
try:
data = sock2.recv(32).decode()
except:
sys.stderr.write('FAIL\n')
logging.error('No Answer to our Authentication')
sock2.close()
return 1
if (data != "OK"):
sys.stderr.write('FAIL\n')
logging.error('%s', data)
sock2.close()
sys.exit(1);
logging.error('No response from device')
sock.close()
return 1
sys.stderr.write('OK\n')
else:
logging.error('Bad Answer: %s', data)
sock2.close()
return 1
sock2.close()

logging.info('Waiting for device...')
try:
sock.settimeout(10)
connection, client_address = sock.accept()
sock.settimeout(None)
connection.settimeout(None)
except:
logging.error('No response from device')
try:
with open(filename, "rb") as f:
if PROGRESS:
update_progress(0)
else:
sys.stderr.write('Uploading')
sys.stderr.flush()
offset = 0
while True:
chunk = f.read(1024)
if not chunk:
break
offset += len(chunk)
update_progress(offset / float(content_size))
connection.settimeout(10)
try:
connection.sendall(chunk)
res = connection.recv(10)
last_response_contained_ok = 'OK' in res.decode()
except Exception as e:
sys.stderr.write('\n')
logging.error('Error Uploading: %s', str(e))
connection.close()
return 1

if last_response_contained_ok:
logging.info('Success')
connection.close()
return 0

sys.stderr.write('\n')
logging.info('Waiting for result...')
count = 0
while count < 5:
count += 1
connection.settimeout(60)
try:
data = connection.recv(32).decode()
logging.info('Result: %s', data)

if "OK" in data:
logging.info('Success')
connection.close()
return 0

except Exception as e:
logging.error('Error receiving result: %s', str(e))
connection.close()
return 1

logging.error('Error response from device')
connection.close()
return 1

finally:
connection.close()

sock.close()
return 1
try:
f = open(filename, "rb")
if (PROGRESS):
update_progress(0)
else:
sys.stderr.write('Uploading')
sys.stderr.flush()
offset = 0
while True:
chunk = f.read(1024)
if not chunk: break
offset += len(chunk)
update_progress(offset/float(content_size))
connection.settimeout(10)
try:
connection.sendall(chunk)
res = connection.recv(10)
lastResponseContainedOK = 'OK' in res.decode()
except:
sys.stderr.write('\n')
logging.error('Error Uploading')
connection.close()
f.close()
sock.close()
return 1

if lastResponseContainedOK:
logging.info('Success')
connection.close()
f.close()
sock.close()
return 0

sys.stderr.write('\n')
logging.info('Waiting for result...')
try:
count = 0
while True:
count=count+1
connection.settimeout(60)
data = connection.recv(32).decode()
logging.info('Result: %s' ,data)

if "OK" in data:
logging.info('Success')
connection.close()
f.close()
sock.close()
return 0;
if count == 5:
logging.error('Error response from device')
connection.close()
f.close()
sock.close()
return 1
except e:
logging.error('No Result!')
connection.close()
f.close()
sock.close()
return 1

finally:
connection.close()
f.close()

sock.close()
return 1
# end serve


def parser(unparsed_args):
parser = optparse.OptionParser(
usage = "%prog [options]",
description = "Transmit image over the air to the esp32 module with OTA support."
)

# destination ip and port
group = optparse.OptionGroup(parser, "Destination")
group.add_option("-i", "--ip",
dest = "esp_ip",
action = "store",
help = "ESP32 IP Address.",
default = False
)
group.add_option("-I", "--host_ip",
dest = "host_ip",
action = "store",
help = "Host IP Address.",
default = "0.0.0.0"
)
group.add_option("-p", "--port",
dest = "esp_port",
type = "int",
help = "ESP32 ota Port. Default 3232",
default = 3232
)
group.add_option("-P", "--host_port",
dest = "host_port",
type = "int",
help = "Host server ota Port. Default random 10000-60000",
default = random.randint(10000,60000)
)
parser.add_option_group(group)

# auth
group = optparse.OptionGroup(parser, "Authentication")
group.add_option("-a", "--auth",
dest = "auth",
help = "Set authentication password.",
action = "store",
default = ""
)
parser.add_option_group(group)

# image
group = optparse.OptionGroup(parser, "Image")
group.add_option("-f", "--file",
dest = "image",
help = "Image file.",
metavar="FILE",
default = None
)
group.add_option("-s", "--spiffs",
dest = "spiffs",
action = "store_true",
help = "Use this option to transmit a SPIFFS image and do not flash the module.",
default = False
)
parser.add_option_group(group)

# output group
group = optparse.OptionGroup(parser, "Output")
group.add_option("-d", "--debug",
dest = "debug",
help = "Show debug output. And override loglevel with debug.",
action = "store_true",
default = False
)
group.add_option("-r", "--progress",
dest = "progress",
help = "Show progress output. Does not work for ArduinoIDE",
action = "store_true",
default = False
)
group.add_option("-t", "--timeout",
dest = "timeout",
type = "int",
help = "Timeout to wait for the ESP32 to accept invitation",
default = 10
)
parser.add_option_group(group)

(options, args) = parser.parse_args(unparsed_args)

return options
# end parser
def parse_args(unparsed_args):
parser = argparse.ArgumentParser(
description="Transmit image over the air to the ESP32 module with OTA support."
)

# destination ip and port
parser.add_argument("-i", "--ip", dest="esp_ip", action="store", help="ESP32 IP Address.", default=False)
parser.add_argument("-I", "--host_ip", dest="host_ip", action="store", help="Host IP Address.", default="0.0.0.0")
parser.add_argument("-p", "--port", dest="esp_port", type=int, help="ESP32 OTA Port. Default: 3232", default=3232)
parser.add_argument(
"-P", "--host_port", dest="host_port", type=int, help="Host server OTA Port. Default: random 10000-60000", default=random.randint(10000, 60000)
)

# authentication
parser.add_argument("-a", "--auth", dest="auth", help="Set authentication password.", action="store", default="")

# image
parser.add_argument("-f", "--file", dest="image", help="Image file.", metavar="FILE", default=None)
parser.add_argument("-s", "--spiffs", dest="spiffs", action="store_true", help="Transmit a SPIFFS image and do not flash the module.", default=False)

# output
parser.add_argument("-d", "--debug", dest="debug", action="store_true", help="Show debug output. Overrides loglevel with debug.", default=False)
parser.add_argument("-r", "--progress", dest="progress", action="store_true", help="Show progress output. Does not work for Arduino IDE.", default=False)
parser.add_argument("-t", "--timeout", dest="timeout", type=int, help="Timeout to wait for the ESP32 to accept invitation.", default=10)

return parser.parse_args(unparsed_args)


def main(args):
options = parser(args)
loglevel = logging.WARNING
if (options.debug):
loglevel = logging.DEBUG

logging.basicConfig(level = loglevel, format = '%(asctime)-8s [%(levelname)s]: %(message)s', datefmt = '%H:%M:%S')
logging.debug("Options: %s", str(options))

# check options
global PROGRESS
PROGRESS = options.progress

global TIMEOUT
TIMEOUT = options.timeout
if (not options.esp_ip or not options.image):
logging.critical("Not enough arguments.")
return 1
options = parse_args(args)
log_level = logging.WARNING
if options.debug:
log_level = logging.DEBUG

logging.basicConfig(level=log_level, format="%(asctime)-8s [%(levelname)s]: %(message)s", datefmt="%H:%M:%S")
logging.debug("Options: %s", str(options))

# check options
global PROGRESS
PROGRESS = options.progress

global TIMEOUT
TIMEOUT = options.timeout

if not options.esp_ip or not options.image:
logging.critical("Not enough arguments.")
return 1

command = FLASH
if (options.spiffs):
command = SPIFFS
command = FLASH
if options.spiffs:
command = SPIFFS

return serve(options.esp_ip, options.host_ip, options.esp_port, options.host_port, options.auth, options.image, command)
# end main
return serve(
options.esp_ip, options.host_ip, options.esp_port, options.host_port, options.auth, options.image, command
)


if __name__ == '__main__':
sys.exit(main(sys.argv))
if __name__ == "__main__":
sys.exit(main(sys.argv[1:]))